diff --git a/br/pkg/restore/snap_client/pipeline_items.go b/br/pkg/restore/snap_client/pipeline_items.go index 232847cbb31f4..0b6ded550e866 100644 --- a/br/pkg/restore/snap_client/pipeline_items.go +++ b/br/pkg/restore/snap_client/pipeline_items.go @@ -162,7 +162,7 @@ func (rc *SnapClient) replaceTables( return 0, errors.Trace(err) } - if err := notifyUpdateAllUsersPrivilege(renamedTables, rc.dom.NotifyUpdatePrivilege); err != nil { + if err := notifyUpdateAllUsersPrivilege(renamedTables, rc.dom.NotifyUpdateAllUsersPrivilege); err != nil { return 0, errors.Trace(err) } diff --git a/br/pkg/restore/snap_client/systable_restore.go b/br/pkg/restore/snap_client/systable_restore.go index 55c40354d7307..de43f20784025 100644 --- a/br/pkg/restore/snap_client/systable_restore.go +++ b/br/pkg/restore/snap_client/systable_restore.go @@ -388,7 +388,7 @@ func (rc *SnapClient) afterSystemTablesReplaced(ctx context.Context, db string, var err error for _, table := range tables { if table == "user" { - if serr := rc.dom.NotifyUpdatePrivilege(); serr != nil { + if serr := rc.dom.NotifyUpdateAllUsersPrivilege(); serr != nil { log.Warn("failed to flush privileges, please manually execute `FLUSH PRIVILEGES`") err = multierr.Append(err, berrors.ErrUnknown.Wrap(serr).GenWithStack("failed to flush privileges")) } else { diff --git a/br/pkg/restore/snap_client/systable_restore_test.go b/br/pkg/restore/snap_client/systable_restore_test.go index 5729ee1dd4e7c..45d1063389b98 100644 --- a/br/pkg/restore/snap_client/systable_restore_test.go +++ b/br/pkg/restore/snap_client/systable_restore_test.go @@ -393,7 +393,7 @@ func TestCheckPrivilegeTableRowsCollateCompatibility(t *testing.T) { // // The above variables are in the file br/pkg/restore/systable_restore.go func TestMonitorTheSystemTableIncremental(t *testing.T) { - require.Equal(t, int64(223), session.CurrentBootstrapVersion) + require.Equal(t, int64(224), session.CurrentBootstrapVersion) } func TestIsStatsTemporaryTable(t *testing.T) { diff --git a/pkg/ddl/executor.go b/pkg/ddl/executor.go index a891f7fbeb7eb..6a715f744f2ec 100644 --- a/pkg/ddl/executor.go +++ b/pkg/ddl/executor.go @@ -6267,7 +6267,7 @@ func (e *executor) DropResourceGroup(ctx sessionctx.Context, stmt *ast.DropResou if checker == nil { return errors.New("miss privilege checker") } - user, matched := checker.MatchUserResourceGroupName(groupName.L) + user, matched := checker.MatchUserResourceGroupName(ctx.GetRestrictedSQLExecutor(), groupName.L) if matched { err = errors.Errorf("user [%s] depends on the resource group to drop", user) return err diff --git a/pkg/domain/BUILD.bazel b/pkg/domain/BUILD.bazel index 340bd723c1ad9..5c485aa4f452d 100644 --- a/pkg/domain/BUILD.bazel +++ b/pkg/domain/BUILD.bazel @@ -95,6 +95,7 @@ go_library( "//pkg/util/printer", "//pkg/util/replayer", "//pkg/util/servermemorylimit", + "//pkg/util/size", "//pkg/util/sqlexec", "//pkg/util/sqlkiller", "//pkg/util/syncutil", diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index 2aa4e69648d0c..80b50a801279d 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -16,6 +16,7 @@ package domain import ( "context" + "encoding/json" "fmt" "math" "math/rand" @@ -98,6 +99,7 @@ import ( "github.com/pingcap/tidb/pkg/util/memoryusagealarm" "github.com/pingcap/tidb/pkg/util/replayer" "github.com/pingcap/tidb/pkg/util/servermemorylimit" + "github.com/pingcap/tidb/pkg/util/size" "github.com/pingcap/tidb/pkg/util/sqlkiller" "github.com/pingcap/tidb/pkg/util/syncutil" "github.com/tikv/client-go/v2/tikv" @@ -1885,6 +1887,36 @@ func (do *Domain) GetPDHTTPClient() pdhttp.Client { return nil } +func decodePrivilegeEvent(resp clientv3.WatchResponse) PrivilegeEvent { + var msg PrivilegeEvent + for _, event := range resp.Events { + if event.Kv != nil { + val := event.Kv.Value + if len(val) > 0 { + var tmp PrivilegeEvent + err := json.Unmarshal(val, &tmp) + if err != nil { + logutil.BgLogger().Warn("decodePrivilegeEvent unmarshal fail", zap.Error(err)) + break + } + if tmp.All { + msg.All = true + break + } + // duplicated users in list is ok. + msg.UserList = append(msg.UserList, tmp.UserList...) + } + } + } + + // In case old version triggers the event, the event value is empty, + // Then we fall back to the old way: reload all the users. + if len(msg.UserList) == 0 { + msg.All = true + } + return msg +} + // LoadPrivilegeLoop create a goroutine loads privilege tables in a loop, it // should be called only once in BootstrapSession. func (do *Domain) LoadPrivilegeLoop(sctx sessionctx.Context) error { @@ -1894,15 +1926,12 @@ func (do *Domain) LoadPrivilegeLoop(sctx sessionctx.Context) error { if err != nil { return err } - do.privHandle = privileges.NewHandle(sctx.GetRestrictedSQLExecutor()) - if err := do.privHandle.Update(); err != nil { - return errors.Trace(err) - } + do.privHandle = privileges.NewHandle(do.SysSessionPool(), sctx.GetSessionVars().GlobalVarsAccessor) var watchCh clientv3.WatchChan duration := 5 * time.Minute if do.etcdClient != nil { - watchCh = do.etcdClient.Watch(context.Background(), privilegeKey) + watchCh = do.etcdClient.Watch(do.ctx, privilegeKey) duration = 10 * time.Minute } @@ -1914,25 +1943,30 @@ func (do *Domain) LoadPrivilegeLoop(sctx sessionctx.Context) error { var count int for { - ok := true + var event PrivilegeEvent select { case <-do.exit: return - case _, ok = <-watchCh: - case <-time.After(duration): - } - if !ok { - logutil.BgLogger().Warn("load privilege loop watch channel closed") - watchCh = do.etcdClient.Watch(context.Background(), privilegeKey) - count++ - if count > 10 { - time.Sleep(time.Duration(count) * time.Second) + case resp, ok := <-watchCh: + if ok { + count = 0 + event = decodePrivilegeEvent(resp) + } else { + if do.ctx.Err() == nil { + logutil.BgLogger().Error("load privilege loop watch channel closed") + watchCh = do.etcdClient.Watch(do.ctx, privilegeKey) + count++ + if count > 10 { + time.Sleep(time.Duration(count) * time.Second) + } + continue + } } - continue + case <-time.After(duration): + event.All = true } - count = 0 - err := do.privHandle.Update() + err := privReloadEvent(do.privHandle, &event) metrics.LoadPrivilegeCounter.WithLabelValues(metrics.RetLabel(err)).Inc() if err != nil { logutil.BgLogger().Warn("load privilege failed", zap.Error(err)) @@ -1942,6 +1976,18 @@ func (do *Domain) LoadPrivilegeLoop(sctx sessionctx.Context) error { return nil } +func privReloadEvent(h *privileges.Handle, event *PrivilegeEvent) (err error) { + switch { + case !variable.AccelerateUserCreationUpdate.Load(): + err = h.UpdateAll() + case event.All: + err = h.UpdateAllActive() + default: + err = h.Update(event.UserList) + } + return +} + // LoadSysVarCacheLoop create a goroutine loads sysvar cache in a loop, // it should be called only once in BootstrapSession. func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { @@ -2923,15 +2969,39 @@ const ( tiflashComputeNodeKey = "/tiflash/new_tiflash_compute_nodes" ) +// PrivilegeEvent is the message definition for NotifyUpdatePrivilege(), encoded in json. +// TiDB old version do not use no such message. +type PrivilegeEvent struct { + All bool + UserList []string +} + +// NotifyUpdateAllUsersPrivilege updates privilege key in etcd, TiDB client that watches +// the key will get notification. +func (do *Domain) NotifyUpdateAllUsersPrivilege() error { + return do.notifyUpdatePrivilege(PrivilegeEvent{All: true}) +} + // NotifyUpdatePrivilege updates privilege key in etcd, TiDB client that watches // the key will get notification. -func (do *Domain) NotifyUpdatePrivilege() error { +func (do *Domain) NotifyUpdatePrivilege(userList []string) error { + return do.notifyUpdatePrivilege(PrivilegeEvent{UserList: userList}) +} + +func (do *Domain) notifyUpdatePrivilege(event PrivilegeEvent) error { // No matter skip-grant-table is configured or not, sending an etcd message is required. // Because we need to tell other TiDB instances to update privilege data, say, we're changing the // password using a special TiDB instance and want the new password to take effect. if do.etcdClient != nil { + data, err := json.Marshal(event) + if err != nil { + return errors.Trace(err) + } + if uint64(len(data)) > size.MB { + logutil.BgLogger().Warn("notify update privilege message too large", zap.ByteString("value", data)) + } row := do.etcdClient.KV - _, err := row.Put(context.Background(), privilegeKey, "") + _, err = row.Put(do.ctx, privilegeKey, string(data)) if err != nil { logutil.BgLogger().Warn("notify update privilege failed", zap.Error(err)) } @@ -2944,7 +3014,7 @@ func (do *Domain) NotifyUpdatePrivilege() error { return nil } - return do.PrivilegeHandle().Update() + return privReloadEvent(do.PrivilegeHandle(), &event) } // NotifyUpdateSysVarCache updates the sysvar cache key in etcd, which other TiDB diff --git a/pkg/executor/grant.go b/pkg/executor/grant.go index efe517b836685..9309be3224eb2 100644 --- a/pkg/executor/grant.go +++ b/pkg/executor/grant.go @@ -253,7 +253,8 @@ func (e *GrantExec) Next(ctx context.Context, _ *chunk.Chunk) error { return err } isCommit = true - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + users := userSpecToUserList(e.Users) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(users) } func containsNonDynamicPriv(privList []*ast.PrivElem) bool { diff --git a/pkg/executor/infoschema_reader_test.go b/pkg/executor/infoschema_reader_test.go index 0afb2b077458f..6ab3991c1371b 100644 --- a/pkg/executor/infoschema_reader_test.go +++ b/pkg/executor/infoschema_reader_test.go @@ -711,8 +711,6 @@ func TestIndexUsageTable(t *testing.T) { where TABLE_SCHEMA = 'test' and TABLE_NAME = 'idt2' and INDEX_NAME = 'idx_4';`).Check( testkit.RowsWithSep("|", "test|idt2|idx_4")) - tk.MustQuery(`select count(*) from information_schema.tidb_index_usage;`).Check( - testkit.RowsWithSep("|", "81")) tk.MustQuery(`select TABLE_SCHEMA, TABLE_NAME, INDEX_NAME from information_schema.tidb_index_usage where TABLE_SCHEMA = 'test1';`).Check(testkit.Rows()) diff --git a/pkg/executor/revoke.go b/pkg/executor/revoke.go index 79391c43bb7a8..ca631d08f6383 100644 --- a/pkg/executor/revoke.go +++ b/pkg/executor/revoke.go @@ -125,7 +125,16 @@ func (e *RevokeExec) Next(ctx context.Context, _ *chunk.Chunk) error { return err } isCommit = true - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + users := userSpecToUserList(e.Users) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(users) +} + +func userSpecToUserList(specs []*ast.UserSpec) []string { + users := make([]string, 0, len(specs)) + for _, user := range specs { + users = append(users, user.User.Username) + } + return users } // Checks that dynamic privileges are only of global scope. diff --git a/pkg/executor/show.go b/pkg/executor/show.go index 0cb952c5ae713..f0f56ea113675 100644 --- a/pkg/executor/show.go +++ b/pkg/executor/show.go @@ -207,7 +207,7 @@ func (e *ShowExec) fetchAll(ctx context.Context) error { case ast.ShowEngines: return e.fetchShowEngines(ctx) case ast.ShowGrants: - return e.fetchShowGrants() + return e.fetchShowGrants(ctx) case ast.ShowIndex: return e.fetchShowIndex() case ast.ShowProcedureStatus: @@ -1869,7 +1869,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { return nil } -func (e *ShowExec) fetchShowGrants() error { +func (e *ShowExec) fetchShowGrants(ctx context.Context) error { vars := e.Ctx().GetSessionVars() checker := privilege.GetPrivilegeManager(e.Ctx()) if checker == nil { @@ -1898,11 +1898,11 @@ func (e *ShowExec) fetchShowGrants() error { if r.Hostname == "" { r.Hostname = "%" } - if !checker.FindEdge(e.Ctx(), r, e.User) { + if !checker.FindEdge(ctx, e.Ctx(), r, e.User) { return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(r.String(), e.User.String()) } } - gs, err := checker.ShowGrants(e.Ctx(), e.User, e.Roles) + gs, err := checker.ShowGrants(ctx, e.Ctx(), e.User, e.Roles) if err != nil { return errors.Trace(err) } diff --git a/pkg/executor/simple.go b/pkg/executor/simple.go index 9fa8f6450722e..5188e10f826d7 100644 --- a/pkg/executor/simple.go +++ b/pkg/executor/simple.go @@ -176,7 +176,7 @@ func (e *SimpleExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { case *ast.DropStatsStmt: err = e.executeDropStats(ctx, x) case *ast.SetRoleStmt: - err = e.executeSetRole(x) + err = e.executeSetRole(ctx, x) case *ast.RevokeRoleStmt: err = e.executeRevokeRole(ctx, x) case *ast.SetDefaultRoleStmt: @@ -274,7 +274,7 @@ func (e *SimpleExec) setDefaultRoleRegular(ctx context.Context, s *ast.SetDefaul } for _, role := range s.RoleList { checker := privilege.GetPrivilegeManager(e.Ctx()) - ok := checker.FindEdge(e.Ctx(), role, user) + ok := checker.FindEdge(ctx, e.Ctx(), role, user) if !ok { if _, rollbackErr := sqlExecutor.ExecuteInternal(internalCtx, "rollback"); rollbackErr != nil { return rollbackErr @@ -388,7 +388,7 @@ func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (er if i > 0 { sqlescape.MustFormatSQL(sql, ",") } - ok := checker.FindEdge(e.Ctx(), role, user) + ok := checker.FindEdge(ctx, e.Ctx(), role, user) if !ok { return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) } @@ -409,6 +409,14 @@ func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (er return nil } +func userIdentityToUserList(specs []*auth.UserIdentity) []string { + users := make([]string, 0, len(specs)) + for _, user := range specs { + users = append(users, user.Username) + } + return users +} + func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaultRoleStmt) (err error) { sessionVars := e.Ctx().GetSessionVars() checker := privilege.GetPrivilegeManager(e.Ctx()) @@ -423,7 +431,8 @@ func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaul if err != nil { return err } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + users := userIdentityToUserList(s.UserList) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(users) } } @@ -445,10 +454,11 @@ func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaul if err != nil { return } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + users := userIdentityToUserList(s.UserList) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(users) } -func (e *SimpleExec) setRoleRegular(s *ast.SetRoleStmt) error { +func (e *SimpleExec) setRoleRegular(ctx context.Context, s *ast.SetRoleStmt) error { // Deal with SQL like `SET ROLE role1, role2;` checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList)) // Check whether RoleNameList contain duplicate role name. @@ -462,7 +472,7 @@ func (e *SimpleExec) setRoleRegular(s *ast.SetRoleStmt) error { } checker := privilege.GetPrivilegeManager(e.Ctx()) - ok, roleName := checker.ActiveRoles(e.Ctx(), roleList) + ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), roleList) if !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) @@ -470,12 +480,12 @@ func (e *SimpleExec) setRoleRegular(s *ast.SetRoleStmt) error { return nil } -func (e *SimpleExec) setRoleAll() error { +func (e *SimpleExec) setRoleAll(ctx context.Context) error { // Deal with SQL like `SET ROLE ALL;` checker := privilege.GetPrivilegeManager(e.Ctx()) user, host := e.Ctx().GetSessionVars().User.AuthUsername, e.Ctx().GetSessionVars().User.AuthHostname roles := checker.GetAllRoles(user, host) - ok, roleName := checker.ActiveRoles(e.Ctx(), roles) + ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), roles) if !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) @@ -483,7 +493,7 @@ func (e *SimpleExec) setRoleAll() error { return nil } -func (e *SimpleExec) setRoleAllExcept(s *ast.SetRoleStmt) error { +func (e *SimpleExec) setRoleAllExcept(ctx context.Context, s *ast.SetRoleStmt) error { // Deal with SQL like `SET ROLE ALL EXCEPT role1, role2;` for _, r := range s.RoleList { if r.Hostname == "" { @@ -514,7 +524,7 @@ func (e *SimpleExec) setRoleAllExcept(s *ast.SetRoleStmt) error { } afterExcept := filter(roles, banned) - ok, roleName := checker.ActiveRoles(e.Ctx(), afterExcept) + ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), afterExcept) if !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) @@ -522,12 +532,12 @@ func (e *SimpleExec) setRoleAllExcept(s *ast.SetRoleStmt) error { return nil } -func (e *SimpleExec) setRoleDefault() error { +func (e *SimpleExec) setRoleDefault(ctx context.Context) error { // Deal with SQL like `SET ROLE DEFAULT;` checker := privilege.GetPrivilegeManager(e.Ctx()) user, host := e.Ctx().GetSessionVars().User.AuthUsername, e.Ctx().GetSessionVars().User.AuthHostname roles := checker.GetDefaultRoles(user, host) - ok, roleName := checker.ActiveRoles(e.Ctx(), roles) + ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), roles) if !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) @@ -535,11 +545,11 @@ func (e *SimpleExec) setRoleDefault() error { return nil } -func (e *SimpleExec) setRoleNone() error { +func (e *SimpleExec) setRoleNone(ctx context.Context) error { // Deal with SQL like `SET ROLE NONE;` checker := privilege.GetPrivilegeManager(e.Ctx()) roles := make([]*auth.RoleIdentity, 0) - ok, roleName := checker.ActiveRoles(e.Ctx(), roles) + ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), roles) if !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) @@ -547,18 +557,18 @@ func (e *SimpleExec) setRoleNone() error { return nil } -func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error { +func (e *SimpleExec) executeSetRole(ctx context.Context, s *ast.SetRoleStmt) error { switch s.SetRoleOpt { case ast.SetRoleRegular: - return e.setRoleRegular(s) + return e.setRoleRegular(ctx, s) case ast.SetRoleAll: - return e.setRoleAll() + return e.setRoleAll(ctx) case ast.SetRoleAllExcept: - return e.setRoleAllExcept(s) + return e.setRoleAllExcept(ctx, s) case ast.SetRoleNone: - return e.setRoleNone() + return e.setRoleNone(ctx) case ast.SetRoleDefault: - return e.setRoleDefault() + return e.setRoleDefault(ctx) } return nil } @@ -754,11 +764,12 @@ func (e *SimpleExec) executeRevokeRole(ctx context.Context, s *ast.RevokeRoleStm if checker == nil { return errors.New("miss privilege checker") } - if ok, roleName := checker.ActiveRoles(e.Ctx(), activeRoles); !ok { + if ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), activeRoles); !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + userList := userIdentityToUserList(s.Users) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(userList) } func (e *SimpleExec) executeCommit() { @@ -1265,7 +1276,8 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if _, err := sqlExecutor.ExecuteInternal(internalCtx, "commit"); err != nil { return errors.Trace(err) } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + userList := userIdentityToUserList(users) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(userList) } func isRole(ctx context.Context, sqlExecutor sqlexec.SQLExecutor, name, host string) (bool, error) { @@ -1762,15 +1774,17 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) if !(hasCreateUserPriv || hasSystemSchemaPriv) { return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") } - if checker.RequestDynamicVerificationWithUser("SYSTEM_USER", false, spec.User) && !(hasSystemUserPriv || hasRestrictedUserPriv) { + if !(hasSystemUserPriv || hasRestrictedUserPriv) && checker.RequestDynamicVerificationWithUser("SYSTEM_USER", false, spec.User) { return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("SYSTEM_USER or SUPER") } - if sem.IsEnabled() && checker.RequestDynamicVerificationWithUser("RESTRICTED_USER_ADMIN", false, spec.User) && !hasRestrictedUserPriv { - return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("RESTRICTED_USER_ADMIN") + if sem.IsEnabled() { + if !hasRestrictedUserPriv && checker.RequestDynamicVerificationWithUser("RESTRICTED_USER_ADMIN", false, spec.User) { + return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("RESTRICTED_USER_ADMIN") + } } } - exists, err := userExistsInternal(ctx, sqlExecutor, spec.User.Username, spec.User.Hostname) + exists, currentAuthPlugin, err := userExistsInternal(ctx, sqlExecutor, spec.User.Username, spec.User.Hostname) if err != nil { return err } @@ -1791,10 +1805,6 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) RequireAuthTokenOptions ) authTokenOptionHandler := noNeedAuthTokenOptions - currentAuthPlugin, err := privilege.GetPrivilegeManager(e.Ctx()).GetAuthPlugin(spec.User.Username, spec.User.Hostname) - if err != nil { - return err - } if currentAuthPlugin == mysql.AuthTiDBAuthToken { authTokenOptionHandler = OptionalAuthTokenOptions } @@ -2025,7 +2035,8 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) if _, err := sqlExecutor.ExecuteInternal(ctx, "commit"); err != nil { return err } - if err = domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(); err != nil { + users := userSpecToUserList(s.Specs) + if err = domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(users); err != nil { return err } if disableSandBoxMode { @@ -2101,7 +2112,8 @@ func (e *SimpleExec) executeGrantRole(ctx context.Context, s *ast.GrantRoleStmt) if _, err := sqlExecutor.ExecuteInternal(internalCtx, "commit"); err != nil { return err } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + userList := userIdentityToUserList(s.Users) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(userList) } // Should cover same internal mysql.* tables as DROP USER, so this function is very similar @@ -2126,7 +2138,7 @@ func (e *SimpleExec) executeRenameUser(s *ast.RenameUserStmt) error { if len(newUser.Hostname) > auth.HostNameMaxLength { return exeerrors.ErrWrongStringLength.GenWithStackByArgs(newUser.Hostname, "host name", auth.HostNameMaxLength) } - exists, err := userExistsInternal(ctx, sqlExecutor, oldUser.Username, oldUser.Hostname) + exists, _, err := userExistsInternal(ctx, sqlExecutor, oldUser.Username, oldUser.Hostname) if err != nil { return err } @@ -2135,7 +2147,7 @@ func (e *SimpleExec) executeRenameUser(s *ast.RenameUserStmt) error { break } - exists, err = userExistsInternal(ctx, sqlExecutor, newUser.Username, newUser.Hostname) + exists, _, err = userExistsInternal(ctx, sqlExecutor, newUser.Username, newUser.Hostname) if err != nil { return err } @@ -2216,7 +2228,13 @@ func (e *SimpleExec) executeRenameUser(s *ast.RenameUserStmt) error { if _, err := sqlExecutor.ExecuteInternal(ctx, "commit"); err != nil { return err } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + + userList := make([]string, 0, len(s.UserToUsers)*2) + for _, users := range s.UserToUsers { + userList = append(userList, users.OldUser.Username) + userList = append(userList, users.NewUser.Username) + } + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(userList) } func renameUserHostInSystemTable(sqlExecutor sqlexec.SQLExecutor, tableName, usernameColumn, hostColumn string, users *ast.UserToUser) error { @@ -2287,7 +2305,7 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e // Because in TiDB SUPER can be used as a substitute for any dynamic privilege, this effectively means that // any user with SUPER requires a user with SUPER to be able to DROP the user. // We also allow RESTRICTED_USER_ADMIN to count for simplicity. - if checker.RequestDynamicVerificationWithUser("SYSTEM_USER", false, user) && !(hasSystemUserPriv || hasRestrictedUserPriv) { + if !(hasSystemUserPriv || hasRestrictedUserPriv) && checker.RequestDynamicVerificationWithUser("SYSTEM_USER", false, user) { if _, err := sqlExecutor.ExecuteInternal(internalCtx, "rollback"); err != nil { return err } @@ -2408,12 +2426,13 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e } if s.IsDropRole { // apply new activeRoles - if ok, roleName := checker.ActiveRoles(e.Ctx(), activeRoles); !ok { + if ok, roleName := checker.ActiveRoles(ctx, e.Ctx(), activeRoles); !ok { u := e.Ctx().GetSessionVars().User return exeerrors.ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) } } - return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + userList := userIdentityToUserList(s.UserList) + return domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege(userList) } func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) { @@ -2427,12 +2446,12 @@ func userExists(ctx context.Context, sctx sessionctx.Context, name string, host } // use the same internal executor to read within the same transaction, otherwise same as userExists -func userExistsInternal(ctx context.Context, sqlExecutor sqlexec.SQLExecutor, name string, host string) (bool, error) { +func userExistsInternal(ctx context.Context, sqlExecutor sqlexec.SQLExecutor, name string, host string) (bool, string, error) { sql := new(strings.Builder) sqlescape.MustFormatSQL(sql, `SELECT * FROM %n.%n WHERE User=%? AND Host=%? FOR UPDATE;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) recordSet, err := sqlExecutor.ExecuteInternal(ctx, sql.String()) if err != nil { - return false, err + return false, "", err } req := recordSet.NewChunk(nil) err = recordSet.Next(ctx, req) @@ -2440,11 +2459,27 @@ func userExistsInternal(ctx context.Context, sqlExecutor sqlexec.SQLExecutor, na if err == nil { rows = req.NumRows() } + + var authPlugin string + colIdx := -1 + for i, f := range recordSet.Fields() { + if f.ColumnAsName.L == "plugin" { + colIdx = i + } + } + if rows == 1 { + // rows can only be 0 or 1 + // When user + host does not exist, the rows is 0 + // When user + host exists, the rows is 1 because user + host is primary key of the table. + row := req.GetRow(0) + authPlugin = row.GetString(colIdx) + } + errClose := recordSet.Close() if errClose != nil { - return false, errClose + return false, "", errClose } - return rows > 0, err + return rows > 0, authPlugin, err } func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error { @@ -2487,7 +2522,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error return exeerrors.ErrDBaccessDenied.GenWithStackByArgs(currUser.Username, currUser.Hostname, "mysql") } } - exists, err := userExistsInternal(ctx, sqlExecutor, u, h) + exists, authplugin, err := userExistsInternal(ctx, sqlExecutor, u, h) if err != nil { return err } @@ -2501,11 +2536,6 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error } disableSandboxMode = true } - - authplugin, err := privilege.GetPrivilegeManager(e.Ctx()).GetAuthPlugin(u, h) - if err != nil { - return err - } if e.isValidatePasswordEnabled() { if err := pwdValidator.ValidatePassword(e.Ctx().GetSessionVars(), s.Password); err != nil { return err @@ -2566,7 +2596,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if _, err := sqlExecutor.ExecuteInternal(ctx, "commit"); err != nil { return err } - err = domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege() + err = domain.GetDomain(e.Ctx()).NotifyUpdatePrivilege([]string{u}) if err != nil { return err } @@ -2698,7 +2728,7 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error { } case ast.FlushPrivileges: dom := domain.GetDomain(e.Ctx()) - return dom.NotifyUpdatePrivilege() + return dom.NotifyUpdateAllUsersPrivilege() case ast.FlushTiDBPlugin: dom := domain.GetDomain(e.Ctx()) for _, pluginName := range s.Plugins { diff --git a/pkg/executor/test/passwordtest/password_management_test.go b/pkg/executor/test/passwordtest/password_management_test.go index 1a2e04dde5b45..1d66178004d54 100644 --- a/pkg/executor/test/passwordtest/password_management_test.go +++ b/pkg/executor/test/passwordtest/password_management_test.go @@ -218,7 +218,7 @@ func TestPasswordManagement(t *testing.T) { rootTK.MustExec(`set global validate_password.enable = OFF`) rootTK.MustExec(`update mysql.user set Password_last_changed = date_sub(Password_last_changed,interval '3 0:0:1' DAY_SECOND) where user = 'u2' and host = '%'`) - err = domain.GetDomain(rootTK.Session()).NotifyUpdatePrivilege() + err = domain.GetDomain(rootTK.Session()).NotifyUpdateAllUsersPrivilege() require.NoError(t, err) // Password expires and takes effect. err = tk.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "%"}, sha1Password("Uu3@22222"), nil, nil) @@ -723,7 +723,7 @@ func TestFailedLoginTrackingAlterUser(t *testing.T) { "JSON_EXTRACT(user_attributes, '$.Password_locking.failed_login_count')," + "JSON_EXTRACT(user_attributes, '$.Password_locking.password_lock_time_days')," + "JSON_EXTRACT(user_attributes, '$.metadata')from mysql.user where user= %? and host = %?" - err := domain.GetDomain(rootTK.Session()).NotifyUpdatePrivilege() + err := domain.GetDomain(rootTK.Session()).NotifyUpdateAllUsersPrivilege() require.NoError(t, err) rootTK.MustExec(`CREATE USER test1 IDENTIFIED BY '1234' FAILED_LOGIN_ATTEMPTS 3 PASSWORD_LOCK_TIME 3 COMMENT 'test'`) err = tk.Session().Auth(&auth.UserIdentity{Username: "test1", Hostname: "%"}, sha1Password("1234"), nil, nil) @@ -924,7 +924,7 @@ func changeAutoLockedLastChanged(tk *testkit.TestKit, ds, user string) { changeTime := time.Now().Add(d).Format(time.UnixDate) SQL = fmt.Sprintf(SQL, changeTime, user) tk.MustExec(SQL) - domain.GetDomain(tk.Session()).NotifyUpdatePrivilege() + domain.GetDomain(tk.Session()).NotifyUpdateAllUsersPrivilege() } func checkUserUserAttributes(tk *testkit.TestKit, user, host, row string) { diff --git a/pkg/extension/auth_test.go b/pkg/extension/auth_test.go index 1aa2917ba6545..4098eae81923d 100644 --- a/pkg/extension/auth_test.go +++ b/pkg/extension/auth_test.go @@ -235,7 +235,7 @@ func TestAuthPlugin(t *testing.T) { // Should authenticate using plugin impl. p.AssertNumberOfCalls(t, "AuthenticateUser", 2) p.AssertCalled(t, "ValidateAuthString", "encodedpassword") - p.AssertNumberOfCalls(t, "ValidateAuthString", 4) + p.AssertNumberOfCalls(t, "ValidateAuthString", 3) // Change password should work using ALTER USER statement. tk.MustExec("alter user 'u2'@'localhost' identified with 'authentication_test_plugin' by 'anotherrawpassword'") diff --git a/pkg/metrics/grafana/tidb.json b/pkg/metrics/grafana/tidb.json index 551dcfab90fd3..e9929096373dc 100644 --- a/pkg/metrics/grafana/tidb.json +++ b/pkg/metrics/grafana/tidb.json @@ -4314,6 +4314,116 @@ "points": false, "stack": false, "steppedLine": false + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The total count of active users.", + "editable": true, + "error": false, + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "fill": 0, + "fillGradient": 0, + "grid": {}, + "gridPos": { + "h": 7, + "w": 12, + "x": 0, + "y": 73 + }, + "hiddenSeries": false, + "id": 23763574014, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": false, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": null, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.5.17", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "exemplar": true, + "expr": "tidb_server_active_users{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", job=\"tidb\"}", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "{{instance}}", + "refId": "H" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Active Users", + "tooltip": { + "msResolution": true, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:264", + "format": "short", + "label": "", + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "$$hashKey": "object:265", + "format": "short", + "label": "", + "logBase": 1, + "max": null, + "min": null, + "show": false + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, @@ -25563,4 +25673,4 @@ "title": "Test-Cluster-TiDB", "uid": "000000011", "version": 1 -} \ No newline at end of file +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index c17859e256af8..2db78870d8da1 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -299,6 +299,7 @@ func RegisterMetrics() { prometheus.MustRegister(BindingCacheMemLimit) prometheus.MustRegister(BindingCacheNumBindings) prometheus.MustRegister(InternalSessions) + prometheus.MustRegister(ActiveUser) tikvmetrics.InitMetrics(TiDB, TiKVClient) tikvmetrics.RegisterMetrics() diff --git a/pkg/metrics/server.go b/pkg/metrics/server.go index 346f21ea54a16..be0795a97c83e 100644 --- a/pkg/metrics/server.go +++ b/pkg/metrics/server.go @@ -73,6 +73,7 @@ var ( RCCheckTSWriteConfilictCounter *prometheus.CounterVec MemoryLimit prometheus.Gauge InternalSessions prometheus.Gauge + ActiveUser prometheus.Gauge ) // InitServerMetrics initializes server metrics. @@ -402,6 +403,14 @@ func InitServerMetrics() { Name: "internal_sessions", Help: "The total count of internal sessions.", }) + + ActiveUser = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "server", + Name: "active_users", + Help: "The total count of active user.", + }) } // ExecuteErrorToLabel converts an execute error to label. diff --git a/pkg/privilege/BUILD.bazel b/pkg/privilege/BUILD.bazel index 2a8e6e17cba93..3e07babb15194 100644 --- a/pkg/privilege/BUILD.bazel +++ b/pkg/privilege/BUILD.bazel @@ -12,5 +12,6 @@ go_library( "//pkg/sessionctx", "//pkg/sessionctx/variable", "//pkg/types", + "//pkg/util/sqlexec", ], ) diff --git a/pkg/privilege/privilege.go b/pkg/privilege/privilege.go index cba7c122a7419..5452f3ee5a5a1 100644 --- a/pkg/privilege/privilege.go +++ b/pkg/privilege/privilege.go @@ -15,6 +15,7 @@ package privilege import ( + "context" "fmt" "github.com/pingcap/tidb/pkg/parser/auth" @@ -23,6 +24,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/sqlexec" ) type keyType int @@ -44,7 +46,7 @@ type VerificationInfo struct { // Manager is the interface for providing privilege related operations. type Manager interface { // ShowGrants shows granted privileges for user. - ShowGrants(ctx sessionctx.Context, user *auth.UserIdentity, roles []*auth.RoleIdentity) ([]string, error) + ShowGrants(ctx context.Context, sctx sessionctx.Context, user *auth.UserIdentity, roles []*auth.RoleIdentity) ([]string, error) // GetEncodedPassword shows the encoded password for user. GetEncodedPassword(user, host string) string @@ -91,7 +93,7 @@ type Manager interface { MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool) // MatchUserResourceGroupName matches a user with specified resource group name - MatchUserResourceGroupName(resourceGroupName string) (string, bool) + MatchUserResourceGroupName(exec sqlexec.RestrictedSQLExecutor, resourceGroupName string) (string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool @@ -101,10 +103,10 @@ type Manager interface { // ActiveRoles active roles for current session. // The first illegal role will be returned. - ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) + ActiveRoles(ctx context.Context, sctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) // FindEdge find if there is an edge between role and user. - FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool + FindEdge(ctx context.Context, sctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool // GetDefaultRoles returns all default roles for certain user. GetDefaultRoles(user, host string) []*auth.RoleIdentity diff --git a/pkg/privilege/privileges/BUILD.bazel b/pkg/privilege/privileges/BUILD.bazel index b08fc60a8adac..dd6aae40fc875 100644 --- a/pkg/privilege/privileges/BUILD.bazel +++ b/pkg/privilege/privileges/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//pkg/extension", "//pkg/infoschema", "//pkg/kv", + "//pkg/metrics", "//pkg/parser/auth", "//pkg/parser/mysql", "//pkg/parser/terror", @@ -35,6 +36,7 @@ go_library( "//pkg/util/sqlescape", "//pkg/util/sqlexec", "//pkg/util/stringutil", + "@com_github_google_btree//:btree", "@com_github_lestrrat_go_jwx_v2//jwk", "@com_github_lestrrat_go_jwx_v2//jws", "@com_github_lestrrat_go_jwx_v2//jwt", @@ -58,6 +60,7 @@ go_test( shard_count = 50, deps = [ "//pkg/config", + "//pkg/domain", "//pkg/errno", "//pkg/kv", "//pkg/parser/auth", diff --git a/pkg/privilege/privileges/cache.go b/pkg/privilege/privileges/cache.go index b7f4492534c5f..fb52476cb6496 100644 --- a/pkg/privilege/privileges/cache.go +++ b/pkg/privilege/privileges/cache.go @@ -27,8 +27,10 @@ import ( "sync/atomic" "time" + "github.com/google/btree" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" @@ -247,7 +249,7 @@ type defaultRoleRecord struct { // roleGraphEdgesTable is used to cache relationship between and role. type roleGraphEdgesTable struct { - roleList map[string]*auth.RoleIdentity + roleList map[auth.RoleIdentity]*auth.RoleIdentity } // Find method is used to find role from table @@ -255,7 +257,7 @@ func (g roleGraphEdgesTable) Find(user, host string) bool { if host == "" { host = "%" } - key := user + "@" + host + key := auth.RoleIdentity{Username: user, Hostname: host} if g.roleList == nil { return false } @@ -275,16 +277,85 @@ type immutable struct { roleGraph map[string]roleGraphEdgesTable } -type extended struct { - UserMap map[string][]UserRecord // Accelerate User searching - Global map[string][]globalPrivRecord - Dynamic map[string][]dynamicPrivRecord - DBMap map[string][]dbRecord // Accelerate DB searching - TablesPrivMap map[string][]tablesPrivRecord // Accelerate TablesPriv searching +type itemUser struct { + username string + data []UserRecord +} + +func compareItemUser(a, b itemUser) bool { + return a.username < b.username +} + +type itemDB struct { + username string + data []dbRecord +} + +func compareItemDB(a, b itemDB) bool { + return a.username < b.username +} + +type itemTablesPriv struct { + username string + data []tablesPrivRecord +} + +func compareItemTablesPriv(a, b itemTablesPriv) bool { + return a.username < b.username +} + +type itemColumnsPriv struct { + username string + data []columnsPrivRecord +} + +func compareItemColumnsPriv(a, b itemColumnsPriv) bool { + return a.username < b.username +} + +type itemDefaultRole struct { + username string + data []defaultRoleRecord +} + +func compareItemDefaultRole(a, b itemDefaultRole) bool { + return a.username < b.username +} + +type itemGlobalPriv struct { + username string + data []globalPrivRecord +} + +func compareItemGlobalPriv(a, b itemGlobalPriv) bool { + return a.username < b.username +} + +type itemDynamicPriv struct { + username string + data []dynamicPrivRecord +} + +func compareItemDynamicPriv(a, b itemDynamicPriv) bool { + return a.username < b.username +} + +type bTree[T any] struct { + *btree.BTreeG[T] + sync.Mutex +} + +// Clone provides the concurrent-safe operation by wraping the original Clone. +func (bt *bTree[T]) Clone() *btree.BTreeG[T] { + bt.Lock() + defer bt.Unlock() + return bt.BTreeG.Clone() } // MySQLPrivilege is the in-memory cache of mysql privilege tables. type MySQLPrivilege struct { + globalVars variable.GlobalVarAccessor + // In MySQL, a user identity consists of a user + host. // Either portion of user or host can contain wildcards, // requiring the privileges system to use a list-like @@ -294,9 +365,28 @@ type MySQLPrivilege struct { // which is that usernames can not contain wildcards. // This means that DB-records are organized in both a // slice (p.DB) and a Map (p.DBMap). - immutable - extended + user bTree[itemUser] + db bTree[itemDB] + tablesPriv bTree[itemTablesPriv] + columnsPriv bTree[itemColumnsPriv] + defaultRoles bTree[itemDefaultRole] + + globalPriv bTree[itemGlobalPriv] + dynamicPriv bTree[itemDynamicPriv] + roleGraph map[auth.RoleIdentity]roleGraphEdgesTable +} + +func newMySQLPrivilege() *MySQLPrivilege { + var p MySQLPrivilege + p.user = bTree[itemUser]{BTreeG: btree.NewG(8, compareItemUser)} + p.db = bTree[itemDB]{BTreeG: btree.NewG(8, compareItemDB)} + p.tablesPriv = bTree[itemTablesPriv]{BTreeG: btree.NewG(8, compareItemTablesPriv)} + p.columnsPriv = bTree[itemColumnsPriv]{BTreeG: btree.NewG(8, compareItemColumnsPriv)} + p.defaultRoles = bTree[itemDefaultRole]{BTreeG: btree.NewG(8, compareItemDefaultRole)} + p.globalPriv = bTree[itemGlobalPriv]{BTreeG: btree.NewG(8, compareItemGlobalPriv)} + p.dynamicPriv = bTree[itemDynamicPriv]{BTreeG: btree.NewG(8, compareItemDynamicPriv)} + return &p } // FindAllUserEffectiveRoles is used to find all effective roles grant to this user. @@ -322,7 +412,7 @@ func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.R if _, ok := visited[role.String()]; !ok { visited[role.String()] = true ret = append(ret, role) - key := role.Username + "@" + role.Hostname + key := auth.RoleIdentity{Username: role.Username, Hostname: role.Hostname} if edgeTable, ok := p.roleGraph[key]; ok { for _, v := range edgeTable.roleList { if _, ok := visited[v.String()]; !ok { @@ -341,19 +431,29 @@ func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdent rec := p.matchUser(user, host) r := p.matchUser(role.Username, role.Hostname) if rec != nil && r != nil { - key := rec.User + "@" + rec.Host + key := auth.RoleIdentity{Username: rec.User, Hostname: rec.Host} return p.roleGraph[key].Find(role.Username, role.Hostname) } return false } +func findRole(ctx context.Context, h *Handle, user string, host string, role *auth.RoleIdentity) bool { + terror.Log(h.ensureActiveUser(ctx, user)) + terror.Log(h.ensureActiveUser(ctx, role.Username)) + mysqlPrivilege := h.Get() + return mysqlPrivilege.FindRole(user, host, role) +} + // LoadAll loads the tables from database to memory. -func (p *MySQLPrivilege) LoadAll(ctx sqlexec.RestrictedSQLExecutor) error { +func (p *MySQLPrivilege) LoadAll(ctx sqlexec.SQLExecutor) error { err := p.LoadUserTable(ctx) if err != nil { logutil.BgLogger().Warn("load mysql.user fail", zap.Error(err)) return errLoadPrivilege.FastGen("mysql.user") } + if l := p.user.Len(); l > 1024 { + logutil.BgLogger().Warn("load all called and user list is long, suggest enabling @@global.tidb_accelerate_user_creation_update", zap.Int("len", l)) + } err = p.LoadGlobalPrivTable(ctx) if err != nil { @@ -412,44 +512,67 @@ func (p *MySQLPrivilege) LoadAll(ctx sqlexec.RestrictedSQLExecutor) error { return nil } -func (p *immutable) loadSomeUsers(ctx sqlexec.RestrictedSQLExecutor, userList ...string) error { - err := p.loadTable(ctx, sqlLoadUserTable, p.decodeUserTableRow, userList...) - if err != nil { - return errors.Trace(err) +func findUserAndAllRoles(userList []string, roleGraph map[auth.RoleIdentity]roleGraphEdgesTable) map[string]struct{} { + // Including the user list and also their roles + all := make(map[string]struct{}, len(userList)) + queue := make([]string, 0, len(userList)) + + // Initialize the queue with the initial user list + for _, user := range userList { + all[user] = struct{}{} + queue = append(queue, user) + } + + // Process the queue using BFS + for len(queue) > 0 { + user := queue[0] + queue = queue[1:] + for userHost, value := range roleGraph { + if userHost.Username == user { + for _, role := range value.roleList { + if _, ok := all[role.Username]; !ok { + all[role.Username] = struct{}{} + queue = append(queue, role.Username) + } + } + } + } } + return all +} - err = p.loadTable(ctx, sqlLoadGlobalPrivTable, p.decodeGlobalPrivTableRow, userList...) +func (p *MySQLPrivilege) loadSomeUsers(ctx sqlexec.SQLExecutor, userList map[string]struct{}) error { + err := loadTable(ctx, addUserFilterCondition(sqlLoadUserTable, userList), p.decodeUserTableRow(userList)) if err != nil { return errors.Trace(err) } - err = p.loadTable(ctx, sqlLoadGlobalGrantsTable, p.decodeGlobalGrantsTableRow, userList...) + err = loadTable(ctx, addUserFilterCondition(sqlLoadGlobalPrivTable, userList), p.decodeGlobalPrivTableRow(userList)) if err != nil { return errors.Trace(err) } - err = p.loadTable(ctx, sqlLoadDBTable, p.decodeDBTableRow, userList...) + err = loadTable(ctx, addUserFilterCondition(sqlLoadGlobalGrantsTable, userList), p.decodeGlobalGrantsTableRow(userList)) if err != nil { return errors.Trace(err) } - err = p.loadTable(ctx, sqlLoadTablePrivTable, p.decodeTablesPrivTableRow, userList...) + err = loadTable(ctx, addUserFilterCondition(sqlLoadDBTable, userList), p.decodeDBTableRow(userList)) if err != nil { return errors.Trace(err) } - err = p.loadTable(ctx, sqlLoadDefaultRoles, p.decodeDefaultRoleTableRow, userList...) + err = loadTable(ctx, addUserFilterCondition(sqlLoadTablePrivTable, userList), p.decodeTablesPrivTableRow(userList)) if err != nil { return errors.Trace(err) } - err = p.loadTable(ctx, sqlLoadColumnsPrivTable, p.decodeColumnsPrivTableRow, userList...) + err = loadTable(ctx, addUserFilterCondition(sqlLoadDefaultRoles, userList), p.decodeDefaultRoleTableRow(userList)) if err != nil { return errors.Trace(err) } - p.roleGraph = make(map[string]roleGraphEdgesTable) - err = p.loadTable(ctx, sqlLoadRoleGraph, p.decodeRoleEdgesTable) + err = loadTable(ctx, addUserFilterCondition(sqlLoadColumnsPrivTable, userList), p.decodeColumnsPrivTableRow(userList)) if err != nil { return errors.Trace(err) } @@ -457,74 +580,95 @@ func (p *immutable) loadSomeUsers(ctx sqlexec.RestrictedSQLExecutor, userList .. return nil } -func dedupSortedKeepLast[S ~[]E, E any](s S, eq func(a, b E) bool) S { - skip := 0 - for i := 1; i < len(s); i++ { - if eq(s[i], s[i-1]) { - skip++ +// merge construct a new MySQLPrivilege by merging the data of the two objects. +func (p *MySQLPrivilege) merge(diff *MySQLPrivilege, userList map[string]struct{}) *MySQLPrivilege { + ret := newMySQLPrivilege() + user := p.user.Clone() + for u := range userList { + itm, ok := diff.user.Get(itemUser{username: u}) + if !ok { + user.Delete(itemUser{username: u}) + } else { + slices.SortFunc(itm.data, compareUserRecord) + user.ReplaceOrInsert(itm) } - s[i-skip] = s[i] } - s = s[:len(s)-skip] - return s -} - -// merge construct a new MySQLPrivilege by merging the data of the two objects;. -func (p *MySQLPrivilege) merge(diff *immutable) *MySQLPrivilege { - var ret MySQLPrivilege - ret.user = make([]UserRecord, 0, len(p.user)+len(diff.user)) - ret.user = append(ret.user, p.user...) - ret.user = append(ret.user, diff.user...) + ret.user.BTreeG = user - // sort and dedup - slices.SortStableFunc(ret.user, compareUserRecord) - ret.user = dedupSortedKeepLast(ret.user, func(x, y UserRecord) bool { return x.User == y.User && x.Host == y.Host }) - ret.buildUserMap() + db := p.db.Clone() + for u := range userList { + itm, ok := diff.db.Get(itemDB{username: u}) + if !ok { + db.Delete(itemDB{username: u}) + } else { + slices.SortFunc(itm.data, compareDBRecord) + db.ReplaceOrInsert(itm) + } + } + ret.db.BTreeG = db - ret.db = make([]dbRecord, 0, len(p.db)+len(diff.db)) - ret.db = append(ret.db, p.db...) - ret.db = append(ret.db, diff.db...) - ret.buildDBMap() + tablesPriv := p.tablesPriv.Clone() + for u := range userList { + itm, ok := diff.tablesPriv.Get(itemTablesPriv{username: u}) + if !ok { + tablesPriv.Delete(itemTablesPriv{username: u}) + } else { + slices.SortFunc(itm.data, compareTablesPrivRecord) + tablesPriv.ReplaceOrInsert(itm) + } + } + ret.tablesPriv.BTreeG = tablesPriv - ret.tablesPriv = make([]tablesPrivRecord, 0, len(p.tablesPriv)+len(diff.tablesPriv)) - ret.tablesPriv = append(ret.tablesPriv, p.tablesPriv...) - ret.tablesPriv = append(ret.tablesPriv, diff.tablesPriv...) - ret.buildTablesPrivMap() + columnsPriv := p.columnsPriv.Clone() + for u := range userList { + itm, ok := diff.columnsPriv.Get(itemColumnsPriv{username: u}) + if !ok { + columnsPriv.Delete(itemColumnsPriv{username: u}) + } else { + slices.SortFunc(itm.data, compareColumnsPrivRecord) + columnsPriv.ReplaceOrInsert(itm) + } + } + ret.columnsPriv.BTreeG = columnsPriv - ret.columnsPriv = make([]columnsPrivRecord, 0, len(p.columnsPriv)+len(diff.columnsPriv)) - ret.columnsPriv = append(ret.columnsPriv, p.columnsPriv...) - ret.columnsPriv = append(ret.columnsPriv, diff.columnsPriv...) - slices.SortStableFunc(ret.columnsPriv, compareColumnsPrivRecord) - ret.columnsPriv = dedupSortedKeepLast(ret.columnsPriv, func(x, y columnsPrivRecord) bool { - return x.Host == y.Host && x.User == y.User && - x.DB == y.DB && x.TableName == y.TableName && x.ColumnName == y.ColumnName - }) + defaultRoles := p.defaultRoles.Clone() + for u := range userList { + itm, ok := diff.defaultRoles.Get(itemDefaultRole{username: u}) + if !ok { + defaultRoles.Delete(itemDefaultRole{username: u}) + } else { + slices.SortFunc(itm.data, compareDefaultRoleRecord) + defaultRoles.ReplaceOrInsert(itm) + } + } + ret.defaultRoles.BTreeG = defaultRoles - ret.defaultRoles = make([]defaultRoleRecord, 0, len(p.defaultRoles)+len(diff.defaultRoles)) - ret.defaultRoles = append(ret.defaultRoles, p.defaultRoles...) - ret.defaultRoles = append(ret.defaultRoles, diff.defaultRoles...) - slices.SortStableFunc(ret.defaultRoles, compareDefaultRoleRecord) - ret.defaultRoles = dedupSortedKeepLast(ret.defaultRoles, func(x, y defaultRoleRecord) bool { - return x.Host == y.Host && x.User == y.User - }) + dynamicPriv := p.dynamicPriv.Clone() + for u := range userList { + itm, ok := diff.dynamicPriv.Get(itemDynamicPriv{username: u}) + if !ok { + dynamicPriv.Delete(itemDynamicPriv{username: u}) + } else { + slices.SortFunc(itm.data, compareDynamicPrivRecord) + dynamicPriv.ReplaceOrInsert(itm) + } + } + ret.dynamicPriv.BTreeG = dynamicPriv - ret.dynamicPriv = make([]dynamicPrivRecord, 0, len(p.dynamicPriv)+len(diff.dynamicPriv)) - ret.dynamicPriv = append(ret.dynamicPriv, p.dynamicPriv...) - ret.dynamicPriv = append(ret.dynamicPriv, diff.dynamicPriv...) - ret.buildDynamicMap() - - ret.globalPriv = make([]globalPrivRecord, 0, len(p.globalPriv)+len(diff.globalPriv)) - ret.globalPriv = append(ret.globalPriv, p.globalPriv...) - ret.globalPriv = append(ret.globalPriv, diff.globalPriv...) - slices.SortStableFunc(ret.globalPriv, compareGlobalPrivRecord) - ret.globalPriv = dedupSortedKeepLast(ret.globalPriv, func(x, y globalPrivRecord) bool { - return x.Host == y.Host && x.User == y.User - }) - ret.buildGlobalMap() + globalPriv := p.globalPriv.Clone() + for u := range userList { + itm, ok := diff.globalPriv.Get(itemGlobalPriv{username: u}) + if !ok { + globalPriv.Delete(itemGlobalPriv{username: u}) + } else { + slices.SortFunc(itm.data, compareGlobalPrivRecord) + globalPriv.ReplaceOrInsert(itm) + } + } + ret.globalPriv.BTreeG = globalPriv ret.roleGraph = diff.roleGraph - - return &ret + return ret } func noSuchTable(err error) bool { @@ -538,9 +682,9 @@ func noSuchTable(err error) bool { } // LoadRoleGraph loads the mysql.role_edges table from database. -func (p *MySQLPrivilege) LoadRoleGraph(ctx sqlexec.RestrictedSQLExecutor) error { - p.roleGraph = make(map[string]roleGraphEdgesTable) - err := p.loadTable(ctx, sqlLoadRoleGraph, p.decodeRoleEdgesTable) +func (p *MySQLPrivilege) LoadRoleGraph(exec sqlexec.SQLExecutor) error { + p.roleGraph = make(map[auth.RoleIdentity]roleGraphEdgesTable) + err := loadTable(exec, sqlLoadRoleGraph, p.decodeRoleEdgesTable) if err != nil { return errors.Trace(err) } @@ -548,8 +692,8 @@ func (p *MySQLPrivilege) LoadRoleGraph(ctx sqlexec.RestrictedSQLExecutor) error } // LoadUserTable loads the mysql.user table from database. -func (p *MySQLPrivilege) LoadUserTable(ctx sqlexec.RestrictedSQLExecutor) error { - err := p.loadTable(ctx, sqlLoadUserTable, p.decodeUserTableRow) +func (p *MySQLPrivilege) LoadUserTable(exec sqlexec.SQLExecutor) error { + err := loadTable(exec, sqlLoadUserTable, p.decodeUserTableRow(nil)) if err != nil { return errors.Trace(err) } @@ -560,18 +704,9 @@ func (p *MySQLPrivilege) LoadUserTable(ctx sqlexec.RestrictedSQLExecutor) error // 3. The server uses the first row that matches the client host name and user name. // The server uses sorting rules that order rows with the most-specific Host values first. p.SortUserTable() - p.buildUserMap() return nil } -func (p *MySQLPrivilege) buildUserMap() { - userMap := make(map[string][]UserRecord, len(p.user)) - for _, record := range p.user { - userMap[record.User] = append(userMap[record.User], record) - } - p.UserMap = userMap -} - func compareBaseRecord(x, y *baseRecord) int { // Compare two item by user's host first. c1 := compareHost(x.Host, y.Host) @@ -594,6 +729,10 @@ func compareGlobalPrivRecord(x, y globalPrivRecord) int { return compareBaseRecord(&x.baseRecord, &y.baseRecord) } +func compareDynamicPrivRecord(x, y dynamicPrivRecord) int { + return compareBaseRecord(&x.baseRecord, &y.baseRecord) +} + func compareColumnsPrivRecord(x, y columnsPrivRecord) int { cmp := compareBaseRecord(&x.baseRecord, &y.baseRecord) if cmp != 0 { @@ -670,133 +809,136 @@ func compareHost(x, y string) int { } // SortUserTable sorts p.User in the MySQLPrivilege struct. -func (p MySQLPrivilege) SortUserTable() { - slices.SortFunc(p.user, compareUserRecord) -} - -func (p *MySQLPrivilege) buildGlobalMap() { - global := make(map[string][]globalPrivRecord) - for _, value := range p.globalPriv { - global[value.User] = append(global[value.User], value) - } - p.Global = global +func (p *MySQLPrivilege) SortUserTable() { + p.user.Ascend(func(itm itemUser) bool { + slices.SortFunc(itm.data, compareUserRecord) + return true + }) } // LoadGlobalPrivTable loads the mysql.global_priv table from database. -func (p *MySQLPrivilege) LoadGlobalPrivTable(ctx sqlexec.RestrictedSQLExecutor) error { - if err := p.loadTable(ctx, sqlLoadGlobalPrivTable, p.decodeGlobalPrivTableRow); err != nil { +func (p *MySQLPrivilege) LoadGlobalPrivTable(exec sqlexec.SQLExecutor) error { + if err := loadTable(exec, sqlLoadGlobalPrivTable, p.decodeGlobalPrivTableRow(nil)); err != nil { return errors.Trace(err) } - p.buildGlobalMap() return nil } // LoadGlobalGrantsTable loads the mysql.global_priv table from database. -func (p *MySQLPrivilege) LoadGlobalGrantsTable(ctx sqlexec.RestrictedSQLExecutor) error { - if err := p.loadTable(ctx, sqlLoadGlobalGrantsTable, p.decodeGlobalGrantsTableRow); err != nil { +func (p *MySQLPrivilege) LoadGlobalGrantsTable(exec sqlexec.SQLExecutor) error { + if err := loadTable(exec, sqlLoadGlobalGrantsTable, p.decodeGlobalGrantsTableRow(nil)); err != nil { return errors.Trace(err) } - p.buildDynamicMap() return nil } // LoadDBTable loads the mysql.db table from database. -func (p *MySQLPrivilege) LoadDBTable(ctx sqlexec.RestrictedSQLExecutor) error { - err := p.loadTable(ctx, sqlLoadDBTable, p.decodeDBTableRow) +func (p *MySQLPrivilege) LoadDBTable(exec sqlexec.SQLExecutor) error { + err := loadTable(exec, sqlLoadDBTable, p.decodeDBTableRow(nil)) if err != nil { return err } - p.buildDBMap() + p.db.Ascend(func(itm itemDB) bool { + slices.SortFunc(itm.data, compareDBRecord) + return true + }) return nil } func compareDBRecord(x, y dbRecord) int { - return compareBaseRecord(&x.baseRecord, &y.baseRecord) + ret := compareBaseRecord(&x.baseRecord, &y.baseRecord) + if ret != 0 { + return ret + } + + return strings.Compare(x.DB, y.DB) } -func (p *MySQLPrivilege) buildDBMap() { - dbMap := make(map[string][]dbRecord, len(p.db)) - for _, record := range p.db { - dbMap[record.User] = append(dbMap[record.User], record) +func compareTablesPrivRecord(x, y tablesPrivRecord) int { + ret := compareBaseRecord(&x.baseRecord, &y.baseRecord) + if ret != 0 { + return ret } - // Sort the records to make the matching rule work. - for _, records := range dbMap { - slices.SortFunc(records, compareDBRecord) + ret = strings.Compare(x.DB, y.DB) + if ret != 0 { + return ret } - p.DBMap = dbMap -} -func (p *MySQLPrivilege) buildDynamicMap() { - dynamic := make(map[string][]dynamicPrivRecord) - for _, value := range p.dynamicPriv { - dynamic[value.User] = append(dynamic[value.User], value) - } - p.Dynamic = dynamic + return strings.Compare(x.TableName, y.TableName) } // LoadTablesPrivTable loads the mysql.tables_priv table from database. -func (p *MySQLPrivilege) LoadTablesPrivTable(ctx sqlexec.RestrictedSQLExecutor) error { - err := p.loadTable(ctx, sqlLoadTablePrivTable, p.decodeTablesPrivTableRow) +func (p *MySQLPrivilege) LoadTablesPrivTable(exec sqlexec.SQLExecutor) error { + err := loadTable(exec, sqlLoadTablePrivTable, p.decodeTablesPrivTableRow(nil)) if err != nil { return err } - p.buildTablesPrivMap() return nil } -func (p *MySQLPrivilege) buildTablesPrivMap() { - tablesPrivMap := make(map[string][]tablesPrivRecord, len(p.tablesPriv)) - for _, record := range p.tablesPriv { - tablesPrivMap[record.User] = append(tablesPrivMap[record.User], record) - } - p.TablesPrivMap = tablesPrivMap -} - // LoadColumnsPrivTable loads the mysql.columns_priv table from database. -func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx sqlexec.RestrictedSQLExecutor) error { - return p.loadTable(ctx, sqlLoadColumnsPrivTable, p.decodeColumnsPrivTableRow) +func (p *MySQLPrivilege) LoadColumnsPrivTable(exec sqlexec.SQLExecutor) error { + return loadTable(exec, sqlLoadColumnsPrivTable, p.decodeColumnsPrivTableRow(nil)) } // LoadDefaultRoles loads the mysql.columns_priv table from database. -func (p *MySQLPrivilege) LoadDefaultRoles(ctx sqlexec.RestrictedSQLExecutor) error { - return p.loadTable(ctx, sqlLoadDefaultRoles, p.decodeDefaultRoleTableRow) +func (p *MySQLPrivilege) LoadDefaultRoles(exec sqlexec.SQLExecutor) error { + return loadTable(exec, sqlLoadDefaultRoles, p.decodeDefaultRoleTableRow(nil)) } -func addUserFilterCondition(sql string, userList []string) string { - if len(userList) == 0 { +func addUserFilterCondition(sql string, userList map[string]struct{}) string { + if len(userList) == 0 || len(userList) > 1024 { return sql } var b strings.Builder b.WriteString(sql) b.WriteString(" WHERE ") - for i, user := range userList { - if i > 0 { + first := true + for user := range userList { + if !first { b.WriteString(" OR ") + } else { + first = false } fmt.Fprintf(&b, "USER = '%s'", sqlescape.EscapeString(user)) } return b.String() } -func (p *immutable) loadTable(sctx sqlexec.RestrictedSQLExecutor, sql string, - decodeTableRow func(chunk.Row, []*resolve.ResultField) error, userList ...string) error { +func loadTable(exec sqlexec.SQLExecutor, sql string, + decodeTableRow func(chunk.Row, []*resolve.ResultField) error) error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - sql = addUserFilterCondition(sql, userList) - rows, fs, err := sctx.ExecRestrictedSQL(ctx, nil, sql) + // Do not use sctx.ExecRestrictedSQL() here deliberately. + // The result set can be extremely large, so this streaming API is important to + // reduce memory cost. + rs, err := exec.ExecuteInternal(ctx, sql) if err != nil { return errors.Trace(err) } - for _, row := range rows { - // NOTE: decodeTableRow decodes data from a chunk Row, that is a shallow copy. - // The result will reference memory in the chunk, so the chunk must not be reused - // here, otherwise some werid bug will happen! - err = decodeTableRow(row, fs) + defer terror.Call(rs.Close) + fs := rs.Fields() + req := rs.NewChunk(nil) + for { + err = rs.Next(ctx, req) if err != nil { return errors.Trace(err) } + if req.NumRows() == 0 { + return nil + } + it := chunk.NewIterator4Chunk(req) + for row := it.Begin(); row != it.End(); row = it.Next() { + err = decodeTableRow(row, fs) + if err != nil { + return errors.Trace(err) + } + } + // NOTE: decodeTableRow decodes data from a chunk Row, that is a shallow copy. + // The result will reference memory in the chunk, so the chunk must not be reused + // here, otherwise some werid bug will happen! + req = chunk.Renew(req, 1024) } - return nil } // parseHostIPNet parses an IPv4 address and its subnet mask (e.g. `127.0.0.0/255.255.255.0`), @@ -832,261 +974,356 @@ func parseHostIPNet(s string) *net.IPNet { func (record *baseRecord) assignUserOrHost(row chunk.Row, i int, f *resolve.ResultField) { switch f.ColumnAsName.L { case "user": - record.User = row.GetString(i) + record.User = strings.Clone(row.GetString(i)) case "host": - record.Host = row.GetString(i) + record.Host = strings.Clone(row.GetString(i)) record.patChars, record.patTypes = stringutil.CompilePatternBinary(record.Host, '\\') record.hostIPNet = parseHostIPNet(record.Host) } } -func (p *immutable) decodeUserTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value UserRecord - for i, f := range fs { - switch { - case f.ColumnAsName.L == "authentication_string": - value.AuthenticationString = row.GetString(i) - case f.ColumnAsName.L == "account_locked": - if row.GetEnum(i).String() == "Y" { - value.AccountLocked = true - } - case f.ColumnAsName.L == "plugin": - if row.GetString(i) != "" { - value.AuthPlugin = row.GetString(i) - } else { - value.AuthPlugin = mysql.AuthNativePassword - } - case f.ColumnAsName.L == "token_issuer": - value.AuthTokenIssuer = row.GetString(i) - case f.ColumnAsName.L == "user_attributes": - if row.IsNull(i) { - continue +func (p *MySQLPrivilege) decodeUserTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value UserRecord + defaultAuthPlugin := "" + if p.globalVars != nil { + val, err := p.globalVars.GetGlobalSysVar(variable.DefaultAuthPlugin) + if err == nil { + defaultAuthPlugin = val } - bj := row.GetJSON(i) - pathExpr, err := types.ParseJSONPathExpr("$.metadata.email") - if err != nil { - return err - } - if emailBJ, found := bj.Extract([]types.JSONPathExpression{pathExpr}); found { - email, err := emailBJ.Unquote() + } + if defaultAuthPlugin == "" { + defaultAuthPlugin = mysql.AuthNativePassword + } + for i, f := range fs { + switch { + case f.ColumnAsName.L == "authentication_string": + value.AuthenticationString = strings.Clone(row.GetString(i)) + case f.ColumnAsName.L == "account_locked": + if row.GetEnum(i).String() == "Y" { + value.AccountLocked = true + } + case f.ColumnAsName.L == "plugin": + if row.GetString(i) != "" { + value.AuthPlugin = strings.Clone(row.GetString(i)) + } else { + value.AuthPlugin = defaultAuthPlugin + } + case f.ColumnAsName.L == "token_issuer": + value.AuthTokenIssuer = strings.Clone(row.GetString(i)) + case f.ColumnAsName.L == "user_attributes": + if row.IsNull(i) { + continue + } + bj := row.GetJSON(i) + pathExpr, err := types.ParseJSONPathExpr("$.metadata.email") if err != nil { return err } - value.Email = email - } - pathExpr, err = types.ParseJSONPathExpr("$.resource_group") - if err != nil { - return err - } - if resourceGroup, found := bj.Extract([]types.JSONPathExpression{pathExpr}); found { - resourceGroup, err := resourceGroup.Unquote() + if emailBJ, found := bj.Extract([]types.JSONPathExpression{pathExpr}); found { + email, err := emailBJ.Unquote() + if err != nil { + return err + } + value.Email = strings.Clone(email) + } + pathExpr, err = types.ParseJSONPathExpr("$.resource_group") if err != nil { return err } - value.ResourceGroup = resourceGroup - } - passwordLocking := PasswordLocking{} - if err := passwordLocking.ParseJSON(bj); err != nil { - return err - } - value.FailedLoginAttempts = passwordLocking.FailedLoginAttempts - value.PasswordLockTimeDays = passwordLocking.PasswordLockTimeDays - value.FailedLoginCount = passwordLocking.FailedLoginCount - value.AutoLockedLastChanged = passwordLocking.AutoLockedLastChanged - value.AutoAccountLocked = passwordLocking.AutoAccountLocked - case f.ColumnAsName.L == "password_expired": - if row.GetEnum(i).String() == "Y" { - value.PasswordExpired = true - } - case f.ColumnAsName.L == "password_last_changed": - t := row.GetTime(i) - gotime, err := t.GoTime(time.Local) - if err != nil { - return err - } - value.PasswordLastChanged = gotime - case f.ColumnAsName.L == "password_lifetime": - if row.IsNull(i) { - value.PasswordLifeTime = -1 - continue - } - value.PasswordLifeTime = row.GetInt64(i) - case f.Column.GetType() == mysql.TypeEnum: - if row.GetEnum(i).String() != "Y" { - continue - } - priv, ok := mysql.Col2PrivType[f.ColumnAsName.O] - if !ok { - return errInvalidPrivilegeType.GenWithStack(f.ColumnAsName.O) + if resourceGroup, found := bj.Extract([]types.JSONPathExpression{pathExpr}); found { + resourceGroup, err := resourceGroup.Unquote() + if err != nil { + return err + } + value.ResourceGroup = strings.Clone(resourceGroup) + } + passwordLocking := PasswordLocking{} + if err := passwordLocking.ParseJSON(bj); err != nil { + return err + } + value.FailedLoginAttempts = passwordLocking.FailedLoginAttempts + value.PasswordLockTimeDays = passwordLocking.PasswordLockTimeDays + value.FailedLoginCount = passwordLocking.FailedLoginCount + value.AutoLockedLastChanged = passwordLocking.AutoLockedLastChanged + value.AutoAccountLocked = passwordLocking.AutoAccountLocked + case f.ColumnAsName.L == "password_expired": + if row.GetEnum(i).String() == "Y" { + value.PasswordExpired = true + } + case f.ColumnAsName.L == "password_last_changed": + t := row.GetTime(i) + gotime, err := t.GoTime(time.Local) + if err != nil { + return err + } + value.PasswordLastChanged = gotime + case f.ColumnAsName.L == "password_lifetime": + if row.IsNull(i) { + value.PasswordLifeTime = -1 + continue + } + value.PasswordLifeTime = row.GetInt64(i) + case f.Column.GetType() == mysql.TypeEnum: + if row.GetEnum(i).String() != "Y" { + continue + } + priv, ok := mysql.Col2PrivType[f.ColumnAsName.O] + if !ok { + return errInvalidPrivilegeType.GenWithStack(f.ColumnAsName.O) + } + value.Privileges |= priv + default: + value.assignUserOrHost(row, i, f) } - value.Privileges |= priv - default: - value.assignUserOrHost(row, i, f) } + old, ok := p.user.Get(itemUser{username: value.User}) + if !ok { + old.username = value.User + } + old.data = append(old.data, value) + p.user.ReplaceOrInsert(old) + return nil } - p.user = append(p.user, value) - return nil } -func (p *immutable) decodeGlobalPrivTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value globalPrivRecord - for i, f := range fs { - if f.ColumnAsName.L == "priv" { - privData := row.GetString(i) - if len(privData) > 0 { - var privValue GlobalPrivValue - err := json.Unmarshal(hack.Slice(privData), &privValue) - if err != nil { - logutil.BgLogger().Error("one user global priv data is broken, forbidden login until data be fixed", - zap.String("user", value.User), zap.String("host", value.Host)) - value.Broken = true - } else { - value.Priv.SSLType = privValue.SSLType - value.Priv.SSLCipher = privValue.SSLCipher - value.Priv.X509Issuer = privValue.X509Issuer - value.Priv.X509Subject = privValue.X509Subject - value.Priv.SAN = privValue.SAN - if len(value.Priv.SAN) > 0 { - value.Priv.SANs, err = util.ParseAndCheckSAN(value.Priv.SAN) - if err != nil { - value.Broken = true +func (p *MySQLPrivilege) decodeGlobalPrivTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value globalPrivRecord + for i, f := range fs { + if f.ColumnAsName.L == "priv" { + privData := row.GetString(i) + if len(privData) > 0 { + var privValue GlobalPrivValue + err := json.Unmarshal(hack.Slice(privData), &privValue) + if err != nil { + logutil.BgLogger().Error("one user global priv data is broken, forbidden login until data be fixed", + zap.String("user", value.User), zap.String("host", value.Host)) + value.Broken = true + } else { + value.Priv.SSLType = privValue.SSLType + value.Priv.SSLCipher = strings.Clone(privValue.SSLCipher) + value.Priv.X509Issuer = strings.Clone(privValue.X509Issuer) + value.Priv.X509Subject = strings.Clone(privValue.X509Subject) + value.Priv.SAN = strings.Clone(privValue.SAN) + if len(value.Priv.SAN) > 0 { + value.Priv.SANs, err = util.ParseAndCheckSAN(value.Priv.SAN) + if err != nil { + value.Broken = true + } } } } + } else { + value.assignUserOrHost(row, i, f) + } + } + if userList != nil { + if _, ok := userList[value.User]; !ok { + return nil } - } else { - value.assignUserOrHost(row, i, f) } + + old, ok := p.globalPriv.Get(itemGlobalPriv{username: value.User}) + if !ok { + old.username = value.User + } + old.data = append(old.data, value) + p.globalPriv.ReplaceOrInsert(old) + return nil } - p.globalPriv = append(p.globalPriv, value) - return nil } -func (p *immutable) decodeGlobalGrantsTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value dynamicPrivRecord - for i, f := range fs { - switch f.ColumnAsName.L { - case "priv": - value.PrivilegeName = strings.ToUpper(row.GetString(i)) - case "with_grant_option": - value.GrantOption = row.GetEnum(i).String() == "Y" - default: - value.assignUserOrHost(row, i, f) +func (p *MySQLPrivilege) decodeGlobalGrantsTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value dynamicPrivRecord + for i, f := range fs { + switch f.ColumnAsName.L { + case "priv": + value.PrivilegeName = strings.ToUpper(row.GetString(i)) + case "with_grant_option": + value.GrantOption = row.GetEnum(i).String() == "Y" + default: + value.assignUserOrHost(row, i, f) + } } + if userList != nil { + if _, ok := userList[value.User]; !ok { + return nil + } + } + + old, ok := p.dynamicPriv.Get(itemDynamicPriv{username: value.User}) + if !ok { + old.username = value.User + } + old.data = append(old.data, value) + p.dynamicPriv.ReplaceOrInsert(old) + return nil } - p.dynamicPriv = append(p.dynamicPriv, value) - return nil } -func (p *immutable) decodeDBTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value dbRecord - for i, f := range fs { - switch { - case f.ColumnAsName.L == "db": - value.DB = row.GetString(i) - value.dbPatChars, value.dbPatTypes = stringutil.CompilePatternBinary(strings.ToUpper(value.DB), '\\') - case f.Column.GetType() == mysql.TypeEnum: - if row.GetEnum(i).String() != "Y" { - continue +func (p *MySQLPrivilege) decodeDBTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value dbRecord + for i, f := range fs { + switch { + case f.ColumnAsName.L == "db": + value.DB = row.GetString(i) + value.dbPatChars, value.dbPatTypes = stringutil.CompilePatternBinary(strings.ToUpper(value.DB), '\\') + case f.Column.GetType() == mysql.TypeEnum: + if row.GetEnum(i).String() != "Y" { + continue + } + priv, ok := mysql.Col2PrivType[f.ColumnAsName.O] + if !ok { + return errInvalidPrivilegeType.GenWithStack("Unknown Privilege Type!") + } + value.Privileges |= priv + default: + value.assignUserOrHost(row, i, f) } - priv, ok := mysql.Col2PrivType[f.ColumnAsName.O] - if !ok { - return errInvalidPrivilegeType.GenWithStack("Unknown Privilege Type!") + } + if userList != nil { + if _, ok := userList[value.User]; !ok { + return nil } - value.Privileges |= priv - default: - value.assignUserOrHost(row, i, f) } + + old, ok := p.db.Get(itemDB{username: value.User}) + if !ok { + old.username = value.User + } + old.data = append(old.data, value) + p.db.ReplaceOrInsert(old) + return nil } - p.db = append(p.db, value) - return nil } -func (p *immutable) decodeTablesPrivTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value tablesPrivRecord - for i, f := range fs { - switch f.ColumnAsName.L { - case "db": - value.DB = row.GetString(i) - case "table_name": - value.TableName = row.GetString(i) - case "table_priv": - value.TablePriv = decodeSetToPrivilege(row.GetSet(i)) - case "column_priv": - value.ColumnPriv = decodeSetToPrivilege(row.GetSet(i)) - default: - value.assignUserOrHost(row, i, f) - } - } - p.tablesPriv = append(p.tablesPriv, value) - return nil +func (p *MySQLPrivilege) decodeTablesPrivTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value tablesPrivRecord + for i, f := range fs { + switch f.ColumnAsName.L { + case "db": + value.DB = row.GetString(i) + case "table_name": + value.TableName = row.GetString(i) + case "table_priv": + value.TablePriv = decodeSetToPrivilege(row.GetSet(i)) + case "column_priv": + value.ColumnPriv = decodeSetToPrivilege(row.GetSet(i)) + default: + value.assignUserOrHost(row, i, f) + } + } + if userList != nil { + if _, ok := userList[value.User]; !ok { + return nil + } + } + + old, ok := p.tablesPriv.Get(itemTablesPriv{username: value.User}) + if !ok { + old.username = value.User + } + old.data = append(old.data, value) + p.tablesPriv.ReplaceOrInsert(old) + return nil + } } -func (p *immutable) decodeRoleEdgesTable(row chunk.Row, fs []*resolve.ResultField) error { +func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*resolve.ResultField) error { var fromUser, fromHost, toHost, toUser string for i, f := range fs { switch f.ColumnAsName.L { case "from_host": - fromHost = row.GetString(i) + fromHost = strings.Clone(row.GetString(i)) case "from_user": - fromUser = row.GetString(i) + fromUser = strings.Clone(row.GetString(i)) case "to_host": - toHost = row.GetString(i) + toHost = strings.Clone(row.GetString(i)) case "to_user": - toUser = row.GetString(i) + toUser = strings.Clone(row.GetString(i)) } } - fromKey := fromUser + "@" + fromHost - toKey := toUser + "@" + toHost + fromKey := auth.RoleIdentity{Username: fromUser, Hostname: fromHost} + toKey := auth.RoleIdentity{Username: toUser, Hostname: toHost} roleGraph, ok := p.roleGraph[toKey] if !ok { - roleGraph = roleGraphEdgesTable{roleList: make(map[string]*auth.RoleIdentity)} + roleGraph = roleGraphEdgesTable{roleList: make(map[auth.RoleIdentity]*auth.RoleIdentity)} p.roleGraph[toKey] = roleGraph } roleGraph.roleList[fromKey] = &auth.RoleIdentity{Username: fromUser, Hostname: fromHost} return nil } -func (p *immutable) decodeDefaultRoleTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value defaultRoleRecord - for i, f := range fs { - switch f.ColumnAsName.L { - case "default_role_host": - value.DefaultRoleHost = row.GetString(i) - case "default_role_user": - value.DefaultRoleUser = row.GetString(i) - default: - value.assignUserOrHost(row, i, f) +func (p *MySQLPrivilege) decodeDefaultRoleTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value defaultRoleRecord + for i, f := range fs { + switch f.ColumnAsName.L { + case "default_role_host": + value.DefaultRoleHost = row.GetString(i) + case "default_role_user": + value.DefaultRoleUser = row.GetString(i) + default: + value.assignUserOrHost(row, i, f) + } + } + if userList != nil { + if _, ok := userList[value.User]; !ok { + return nil + } + } + + old, ok := p.defaultRoles.Get(itemDefaultRole{username: value.User}) + if !ok { + old.username = value.User } + old.data = append(old.data, value) + p.defaultRoles.ReplaceOrInsert(old) + return nil } - p.defaultRoles = append(p.defaultRoles, value) - return nil } -func (p *immutable) decodeColumnsPrivTableRow(row chunk.Row, fs []*resolve.ResultField) error { - var value columnsPrivRecord - for i, f := range fs { - switch f.ColumnAsName.L { - case "db": - value.DB = row.GetString(i) - case "table_name": - value.TableName = row.GetString(i) - case "column_name": - value.ColumnName = row.GetString(i) - case "timestamp": - var err error - value.Timestamp, err = row.GetTime(i).GoTime(time.Local) - if err != nil { - return errors.Trace(err) +func (p *MySQLPrivilege) decodeColumnsPrivTableRow(userList map[string]struct{}) func(chunk.Row, []*resolve.ResultField) error { + return func(row chunk.Row, fs []*resolve.ResultField) error { + var value columnsPrivRecord + for i, f := range fs { + switch f.ColumnAsName.L { + case "db": + value.DB = row.GetString(i) + case "table_name": + value.TableName = row.GetString(i) + case "column_name": + value.ColumnName = row.GetString(i) + case "timestamp": + var err error + value.Timestamp, err = row.GetTime(i).GoTime(time.Local) + if err != nil { + return errors.Trace(err) + } + case "column_priv": + value.ColumnPriv = decodeSetToPrivilege(row.GetSet(i)) + default: + value.assignUserOrHost(row, i, f) + } + } + if userList != nil { + if _, ok := userList[value.User]; !ok { + return nil } - case "column_priv": - value.ColumnPriv = decodeSetToPrivilege(row.GetSet(i)) - default: - value.assignUserOrHost(row, i, f) } + + old, ok := p.columnsPriv.Get(itemColumnsPriv{username: value.User}) + if !ok { + old.username = value.User + } + old.data = append(old.data, value) + p.columnsPriv.ReplaceOrInsert(old) + return nil } - p.columnsPriv = append(p.columnsPriv, value) - return nil } func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { @@ -1159,9 +1396,14 @@ func patternMatch(str string, patChars, patTypes []byte) bool { // matchIdentity finds an identity to match a user + host // using the correct rules according to MySQL. -func (p *MySQLPrivilege) matchIdentity(sctx sqlexec.RestrictedSQLExecutor, user, host string, skipNameResolve bool) *UserRecord { - for i := 0; i < len(p.user); i++ { - record := &p.user[i] +func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord { + item, ok := p.user.Get(itemUser{username: user}) + if !ok { + return nil + } + + for i := 0; i < len(item.data); i++ { + record := &item.data[i] if record.match(user, host) { return record } @@ -1181,8 +1423,8 @@ func (p *MySQLPrivilege) matchIdentity(sctx sqlexec.RestrictedSQLExecutor, user, return nil } for _, addr := range addrs { - for i := 0; i < len(p.user); i++ { - record := &p.user[i] + for i := 0; i < len(item.data); i++ { + record := &item.data[i] if record.match(user, addr) { return record } @@ -1192,25 +1434,14 @@ func (p *MySQLPrivilege) matchIdentity(sctx sqlexec.RestrictedSQLExecutor, user, return nil } -// matchResoureGroup finds an identity to match resource group. -func (p *MySQLPrivilege) matchResoureGroup(resourceGroupName string) *UserRecord { - for i := 0; i < len(p.user); i++ { - record := &p.user[i] - if record.ResourceGroup == resourceGroupName { - return record - } - } - return nil -} - // connectionVerification verifies the username + hostname according to exact // match from the mysql.user privilege table. call matchIdentity() first if you // do not have an exact match yet. func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { - records, exists := p.UserMap[user] + records, exists := p.user.Get(itemUser{username: user}) if exists { - for i := 0; i < len(records); i++ { - record := &records[i] + for i := 0; i < len(records.data); i++ { + record := &records.data[i] if record.Host == host { // exact match return record } @@ -1220,10 +1451,11 @@ func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { } func (p *MySQLPrivilege) matchGlobalPriv(user, host string) *globalPrivRecord { - uGlobal, exists := p.Global[user] + item, exists := p.globalPriv.Get(itemGlobalPriv{username: user}) if !exists { return nil } + uGlobal := item.data for i := 0; i < len(uGlobal); i++ { record := &uGlobal[i] if record.match(user, host) { @@ -1234,8 +1466,9 @@ func (p *MySQLPrivilege) matchGlobalPriv(user, host string) *globalPrivRecord { } func (p *MySQLPrivilege) matchUser(user, host string) *UserRecord { - records, exists := p.UserMap[user] + item, exists := p.user.Get(itemUser{username: user}) if exists { + records := item.data for i := 0; i < len(records); i++ { record := &records[i] if record.match(user, host) { @@ -1247,8 +1480,9 @@ func (p *MySQLPrivilege) matchUser(user, host string) *UserRecord { } func (p *MySQLPrivilege) matchDB(user, host, db string) *dbRecord { - records, exists := p.DBMap[user] + item, exists := p.db.Get(itemDB{username: user}) if exists { + records := item.data for i := 0; i < len(records); i++ { record := &records[i] if record.match(user, host, db) { @@ -1260,8 +1494,9 @@ func (p *MySQLPrivilege) matchDB(user, host, db string) *dbRecord { } func (p *MySQLPrivilege) matchTables(user, host, db, table string) *tablesPrivRecord { - records, exists := p.TablesPrivMap[user] + item, exists := p.tablesPriv.Get(itemTablesPriv{username: user}) if exists { + records := item.data for i := 0; i < len(records); i++ { record := &records[i] if record.match(user, host, db, table) { @@ -1273,10 +1508,13 @@ func (p *MySQLPrivilege) matchTables(user, host, db, table string) *tablesPrivRe } func (p *MySQLPrivilege) matchColumns(user, host, db, table, column string) *columnsPrivRecord { - for i := 0; i < len(p.columnsPriv); i++ { - record := &p.columnsPriv[i] - if record.match(user, host, db, table, column) { - return record + item, exists := p.columnsPriv.Get(itemColumnsPriv{username: user}) + if exists { + for i := 0; i < len(item.data); i++ { + record := &item.data[i] + if record.match(user, host, db, table, column) { + return record + } } } return nil @@ -1293,13 +1531,16 @@ func (p *MySQLPrivilege) HasExplicitlyGrantedDynamicPrivilege(activeRoles []*aut for _, r := range roleList { u := r.Username h := r.Hostname - for _, record := range p.Dynamic[u] { - if record.match(u, h) { - if withGrant && !record.GrantOption { - continue - } - if record.PrivilegeName == privName { - return true + item, exists := p.dynamicPriv.Get(itemDynamicPriv{username: u}) + if exists { + for _, record := range item.data { + if record.match(u, h) { + if withGrant && !record.GrantOption { + continue + } + if record.PrivilegeName == privName { + return true + } } } } @@ -1409,20 +1650,24 @@ func (p *MySQLPrivilege) DBIsVisible(user, host, db string) bool { } } - for _, record := range p.tablesPriv { - if record.baseRecord.match(user, host) && - strings.EqualFold(record.DB, db) { - if record.TablePriv != 0 || record.ColumnPriv != 0 { - return true + if item, exists := p.tablesPriv.Get(itemTablesPriv{username: user}); exists { + for _, record := range item.data { + if record.baseRecord.match(user, host) && + strings.EqualFold(record.DB, db) { + if record.TablePriv != 0 || record.ColumnPriv != 0 { + return true + } } } } - for _, record := range p.columnsPriv { - if record.baseRecord.match(user, host) && - strings.EqualFold(record.DB, db) { - if record.ColumnPriv != 0 { - return true + if item, exists := p.columnsPriv.Get(itemColumnsPriv{username: user}); exists { + for _, record := range item.data { + if record.baseRecord.match(user, host) && + strings.EqualFold(record.DB, db) { + if record.ColumnPriv != 0 { + return true + } } } } @@ -1441,10 +1686,12 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r var currentPriv mysql.PrivilegeType var userExists = false // Check whether user exists. - if userList, ok := p.UserMap[user]; ok { - for _, record := range userList { + if userList, ok := p.user.Get(itemUser{username: user}); ok { + for _, record := range userList.data { if record.fullyMatch(user, host) { userExists = true + hasGlobalGrant = true + currentPriv |= record.Privileges break } } @@ -1452,21 +1699,18 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r return gs } } - var g string - for _, record := range p.user { - if record.fullyMatch(user, host) { - hasGlobalGrant = true - currentPriv |= record.Privileges - } else { - for _, r := range allRoles { - if record.baseRecord.match(r.Username, r.Hostname) { + + for _, r := range allRoles { + if userList, ok := p.user.Get(itemUser{username: r.Username}); ok { + for _, record := range userList.data { + if record.fullyMatch(r.Username, r.Hostname) { hasGlobalGrant = true currentPriv |= record.Privileges } } } } - g = userPrivToString(currentPriv) + g := userPrivToString(currentPriv) if len(g) > 0 { var s string if (currentPriv & mysql.GrantPriv) > 0 { @@ -1491,17 +1735,20 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r // Show db scope grants. sortFromIdx = len(gs) dbPrivTable := make(map[string]mysql.PrivilegeType) - for _, record := range p.db { - if record.fullyMatch(user, host) { - dbPrivTable[record.DB] |= record.Privileges - } else { - for _, r := range allRoles { - if record.baseRecord.match(r.Username, r.Hostname) { - dbPrivTable[record.DB] |= record.Privileges + p.db.Ascend(func(itm itemDB) bool { + for _, record := range itm.data { + if record.fullyMatch(user, host) { + dbPrivTable[record.DB] |= record.Privileges + } else { + for _, r := range allRoles { + if record.baseRecord.match(r.Username, r.Hostname) { + dbPrivTable[record.DB] |= record.Privileges + } } } } - } + return true + }) sqlMode := ctx.GetSessionVars().SQLMode for dbName, priv := range dbPrivTable { @@ -1527,18 +1774,21 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r // Show table scope grants. sortFromIdx = len(gs) tablePrivTable := make(map[string]mysql.PrivilegeType) - for _, record := range p.tablesPriv { - recordKey := stringutil.Escape(record.DB, sqlMode) + "." + stringutil.Escape(record.TableName, sqlMode) - if user == record.User && host == record.Host { - tablePrivTable[recordKey] |= record.TablePriv - } else { - for _, r := range allRoles { - if record.baseRecord.match(r.Username, r.Hostname) { - tablePrivTable[recordKey] |= record.TablePriv + p.tablesPriv.Ascend(func(itm itemTablesPriv) bool { + for _, record := range itm.data { + recordKey := stringutil.Escape(record.DB, sqlMode) + "." + stringutil.Escape(record.TableName, sqlMode) + if user == record.User && host == record.Host { + tablePrivTable[recordKey] |= record.TablePriv + } else { + for _, r := range allRoles { + if record.baseRecord.match(r.Username, r.Hostname) { + tablePrivTable[recordKey] |= record.TablePriv + } } } } - } + return true + }) for k, priv := range tablePrivTable { g := tablePrivToString(priv) if len(g) > 0 { @@ -1562,14 +1812,16 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r // A map of "DB.Table" => Priv(col1, col2 ...) sortFromIdx = len(gs) columnPrivTable := make(map[string]privOnColumns) - for i := range p.columnsPriv { - record := p.columnsPriv[i] - if !collectColumnGrant(&record, user, host, columnPrivTable, sqlMode) { - for _, r := range allRoles { - collectColumnGrant(&record, r.Username, r.Hostname, columnPrivTable, sqlMode) + p.columnsPriv.Ascend(func(itm itemColumnsPriv) bool { + for _, record := range itm.data { + if !collectColumnGrant(&record, user, host, columnPrivTable, sqlMode) { + for _, r := range allRoles { + collectColumnGrant(&record, r.Username, r.Hostname, columnPrivTable, sqlMode) + } } } - } + return true + }) for k, v := range columnPrivTable { privCols := privOnColumnsToString(v) s := fmt.Sprintf(`GRANT %s ON %s TO '%s'@'%s'`, privCols, k, user, host) @@ -1578,15 +1830,13 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r slices.Sort(gs[sortFromIdx:]) // Show role grants. - graphKey := user + "@" + host + graphKey := auth.RoleIdentity{Username: user, Hostname: host} edgeTable, ok := p.roleGraph[graphKey] g = "" if ok { sortedRes := make([]string, 0, 10) for k := range edgeTable.roleList { - role := strings.Split(k, "@") - roleName, roleHost := role[0], role[1] - tmp := fmt.Sprintf("'%s'@'%s'", roleName, roleHost) + tmp := fmt.Sprintf("'%s'@'%s'", k.Username, k.Hostname) sortedRes = append(sortedRes, tmp) } slices.Sort(sortedRes) @@ -1604,21 +1854,25 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r // The convention is to merge the Dynamic privileges assigned to the user with // inherited dynamic privileges from those roles dynamicPrivsMap := make(map[string]bool) // privName, grantable - for _, record := range p.Dynamic[user] { - if record.fullyMatch(user, host) { - dynamicPrivsMap[record.PrivilegeName] = record.GrantOption + if item, exists := p.dynamicPriv.Get(itemDynamicPriv{username: user}); exists { + for _, record := range item.data { + if record.fullyMatch(user, host) { + dynamicPrivsMap[record.PrivilegeName] = record.GrantOption + } } } for _, r := range allRoles { - for _, record := range p.Dynamic[r.Username] { - if record.fullyMatch(r.Username, r.Hostname) { - // If the record already exists in the map and it's grantable - // skip doing anything, because we might inherit a non-grantable permission - // from a role, and don't want to clobber the existing privilege. - if grantable, ok := dynamicPrivsMap[record.PrivilegeName]; ok && grantable { - continue + if item, exists := p.dynamicPriv.Get(itemDynamicPriv{username: r.Username}); exists { + for _, record := range item.data { + if record.fullyMatch(r.Username, r.Hostname) { + // If the record already exists in the map and it's grantable + // skip doing anything, because we might inherit a non-grantable permission + // from a role, and don't want to clobber the existing privilege. + if grantable, ok := dynamicPrivsMap[record.PrivilegeName]; ok && grantable { + continue + } + dynamicPrivsMap[record.PrivilegeName] = record.GrantOption } - dynamicPrivsMap[record.PrivilegeName] = record.GrantOption } } } @@ -1739,18 +1993,22 @@ func (p *MySQLPrivilege) UserPrivilegesTable(activeRoles []*auth.RoleIdentity, u // This is verified against MySQL. showOtherUsers := p.RequestVerification(activeRoles, user, host, mysql.SystemDB, "", "", mysql.SelectPriv) var rows [][]types.Datum - for _, u := range p.user { - if showOtherUsers || u.match(user, host) { - rows = appendUserPrivilegesTableRow(rows, u) + p.user.Ascend(func(itm itemUser) bool { + for _, u := range itm.data { + if showOtherUsers || u.match(user, host) { + rows = appendUserPrivilegesTableRow(rows, u) + } } - } - for _, dynamicPrivs := range p.Dynamic { - for _, dynamicPriv := range dynamicPrivs { + return true + }) + p.dynamicPriv.Ascend(func(itm itemDynamicPriv) bool { + for _, dynamicPriv := range itm.data { if showOtherUsers || dynamicPriv.match(user, host) { rows = appendDynamicPrivRecord(rows, dynamicPriv) } } - } + return true + }) return rows } @@ -1794,16 +2052,18 @@ func appendUserPrivilegesTableRow(rows [][]types.Datum, user UserRecord) [][]typ func (p *MySQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity { ret := make([]*auth.RoleIdentity, 0) - for _, r := range p.defaultRoles { - if r.match(user, host) { - ret = append(ret, &auth.RoleIdentity{Username: r.DefaultRoleUser, Hostname: r.DefaultRoleHost}) + if item, exists := p.defaultRoles.Get(itemDefaultRole{username: user}); exists { + for _, r := range item.data { + if r.match(user, host) { + ret = append(ret, &auth.RoleIdentity{Username: r.DefaultRoleUser, Hostname: r.DefaultRoleHost}) + } } } return ret } func (p *MySQLPrivilege) getAllRoles(user, host string) []*auth.RoleIdentity { - key := user + "@" + host + key := auth.RoleIdentity{Username: user, Hostname: host} edgeTable, ok := p.roleGraph[key] ret := make([]*auth.RoleIdentity, 0, len(edgeTable.roleList)) if ok { @@ -1814,27 +2074,62 @@ func (p *MySQLPrivilege) getAllRoles(user, host string) []*auth.RoleIdentity { return ret } +// SetGlobalVarsAccessor is only used for test. +func (p *MySQLPrivilege) SetGlobalVarsAccessor(globalVars variable.GlobalVarAccessor) { + p.globalVars = globalVars +} + // Handle wraps MySQLPrivilege providing thread safe access. type Handle struct { - sctx sqlexec.RestrictedSQLExecutor + sctx util.SessionPool priv atomic.Pointer[MySQLPrivilege] // Only load the active user's data to save memory // username => struct{} activeUsers sync.Map + fullData atomic.Bool + globalVars variable.GlobalVarAccessor } // NewHandle returns a Handle. -func NewHandle(sctx sqlexec.RestrictedSQLExecutor) *Handle { - var priv MySQLPrivilege +func NewHandle(sctx util.SessionPool, globalVars variable.GlobalVarAccessor) *Handle { + priv := newMySQLPrivilege() ret := &Handle{} ret.sctx = sctx - ret.priv.Store(&priv) + ret.globalVars = globalVars + ret.priv.Store(priv) return ret } // ensureActiveUser ensure that the specific user data is loaded in-memory. -func (h *Handle) ensureActiveUser(user string) error { - return nil +func (h *Handle) ensureActiveUser(ctx context.Context, user string) error { + if p := ctx.Value("mock"); p != nil { + visited := p.(*bool) + *visited = true + } + + if h.fullData.Load() { + // All users data are in-memory, nothing to do + return nil + } + + _, exist := h.activeUsers.Load(user) + if exist { + return nil + } + return h.updateUsers([]string{user}) +} + +func (h *Handle) merge(data *MySQLPrivilege, userList map[string]struct{}) { + for { + old := h.Get() + swapped := h.priv.CompareAndSwap(old, old.merge(data, userList)) + if swapped { + break + } + } + for user := range userList { + h.activeUsers.Store(user, struct{}{}) + } } // Get the MySQLPrivilege for read. @@ -1842,14 +2137,81 @@ func (h *Handle) Get() *MySQLPrivilege { return h.priv.Load() } -// Update loads all the privilege info from kv storage. -func (h *Handle) Update() error { - var priv MySQLPrivilege - err := priv.LoadAll(h.sctx) +// UpdateAll loads all the users' privilege info from kv storage. +func (h *Handle) UpdateAll() error { + priv := newMySQLPrivilege() + priv.globalVars = h.globalVars + res, err := h.sctx.Get() if err != nil { - return err + return errors.Trace(err) + } + defer h.sctx.Put(res) + exec := res.(sqlexec.SQLExecutor) + + err = priv.LoadAll(exec) + if err != nil { + return errors.Trace(err) } + h.priv.Store(priv) + h.fullData.Store(true) + return nil +} - h.priv.Store(&priv) +// UpdateAllActive loads all the active users' privilege info from kv storage. +func (h *Handle) UpdateAllActive() error { + h.fullData.Store(false) + userList := make([]string, 0, 20) + h.activeUsers.Range(func(key, _ any) bool { + userList = append(userList, key.(string)) + return true + }) + metrics.ActiveUser.Set(float64(len(userList))) + return h.updateUsers(userList) +} + +// Update loads the privilege info from kv storage for the list of users. +func (h *Handle) Update(userList []string) error { + h.fullData.Store(false) + if len(userList) > 100 { + logutil.BgLogger().Warn("update user list is long", zap.Int("len", len(userList))) + } + needReload := false + for _, user := range userList { + if _, ok := h.activeUsers.Load(user); ok { + needReload = true + break + } + } + if !needReload { + return nil + } + + return h.updateUsers(userList) +} + +func (h *Handle) updateUsers(userList []string) error { + res, err := h.sctx.Get() + if err != nil { + return errors.Trace(err) + } + defer h.sctx.Put(res) + exec := res.(sqlexec.SQLExecutor) + + p := newMySQLPrivilege() + p.globalVars = h.globalVars + // Load the full role edge table first. + p.roleGraph = make(map[auth.RoleIdentity]roleGraphEdgesTable) + err = loadTable(exec, sqlLoadRoleGraph, p.decodeRoleEdgesTable) + if err != nil { + return errors.Trace(err) + } + + // Including the user and also their roles + userAndRoles := findUserAndAllRoles(userList, p.roleGraph) + err = p.loadSomeUsers(exec, userAndRoles) + if err != nil { + return err + } + h.merge(p, userAndRoles) return nil } diff --git a/pkg/privilege/privileges/cache_test.go b/pkg/privilege/privileges/cache_test.go index 3527d347aac55..99613084d3533 100644 --- a/pkg/privilege/privileges/cache_test.go +++ b/pkg/privilege/privileges/cache_test.go @@ -15,6 +15,7 @@ package privileges_test import ( + "context" "fmt" "testing" "time" @@ -22,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/privilege/privileges" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util" "github.com/stretchr/testify/require" @@ -34,9 +36,9 @@ func TestLoadUserTable(t *testing.T) { tk.MustExec("use mysql;") tk.MustExec("truncate table user;") - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) require.Len(t, p.User(), 0) // Host | User | authentication_string | Select_priv | Insert_priv | Update_priv | Delete_priv | Create_priv | Drop_priv | Process_priv | Grant_priv | References_priv | Alter_priv | Show_db_priv | Super_priv | Execute_priv | Index_priv | Create_user_priv | Trigger_priv @@ -48,9 +50,9 @@ func TestLoadUserTable(t *testing.T) { tk.MustExec(`INSERT INTO mysql.user (Host, User, password_expired, password_last_changed, password_lifetime) VALUES ("%", "root2", "Y", "2022-10-10 12:00:00", 3)`) tk.MustExec(`INSERT INTO mysql.user (Host, User, password_expired, password_last_changed) VALUES ("%", "root3", "N", "2022-10-10 12:00:00")`) - p = privileges.MySQLPrivilege{} - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) - require.Len(t, p.User(), len(p.UserMap)) + p = privileges.NewMySQLPrivilege() + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) + // require.Len(t, p.User(), len(p.UserMap)) user := p.User() require.Equal(t, "root", user[0].User) @@ -66,6 +68,15 @@ func TestLoadUserTable(t *testing.T) { require.Equal(t, false, user[6].PasswordExpired) require.Equal(t, time.Date(2022, 10, 10, 12, 0, 0, 0, time.Local), user[6].PasswordLastChanged) require.Equal(t, int64(-1), user[6].PasswordLifeTime) + + // test switching default auth plugin + for _, plugin := range []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} { + p = privileges.NewMySQLPrivilege() + p.SetGlobalVarsAccessor(se.GetSessionVars().GlobalVarsAccessor) + require.NoError(t, se.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.DefaultAuthPlugin, plugin)) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) + require.Equal(t, plugin, p.User()[0].AuthPlugin) + } } func TestLoadGlobalPrivTable(t *testing.T) { @@ -78,18 +89,19 @@ func TestLoadGlobalPrivTable(t *testing.T) { tk.MustExec(`INSERT INTO mysql.global_priv VALUES ("%", "tu", "{\"access\":0,\"plugin\":\"mysql_native_password\",\"ssl_type\":3, \"ssl_cipher\":\"cipher\",\"x509_subject\":\"\C=ZH1\", \"x509_issuer\":\"\C=ZH2\", \"san\":\"\IP:127.0.0.1, IP:1.1.1.1, DNS:pingcap.com, URI:spiffe://mesh.pingcap.com/ns/timesh/sa/me1\", \"password_last_changed\":1}")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadGlobalPrivTable(se.GetRestrictedSQLExecutor())) - require.Equal(t, `%`, p.Global["tu"][0].Host) - require.Equal(t, `tu`, p.Global["tu"][0].User) - require.Equal(t, privileges.SslTypeSpecified, p.Global["tu"][0].Priv.SSLType) - require.Equal(t, "C=ZH2", p.Global["tu"][0].Priv.X509Issuer) - require.Equal(t, "C=ZH1", p.Global["tu"][0].Priv.X509Subject) - require.Equal(t, "IP:127.0.0.1, IP:1.1.1.1, DNS:pingcap.com, URI:spiffe://mesh.pingcap.com/ns/timesh/sa/me1", p.Global["tu"][0].Priv.SAN) - require.Len(t, p.Global["tu"][0].Priv.SANs[util.IP], 2) - require.Equal(t, "pingcap.com", p.Global["tu"][0].Priv.SANs[util.DNS][0]) - require.Equal(t, "spiffe://mesh.pingcap.com/ns/timesh/sa/me1", p.Global["tu"][0].Priv.SANs[util.URI][0]) + require.NoError(t, p.LoadGlobalPrivTable(se.GetSQLExecutor())) + val := p.GlobalPriv("tu")[0] + require.Equal(t, `%`, val.Host) + require.Equal(t, `tu`, val.User) + require.Equal(t, privileges.SslTypeSpecified, val.Priv.SSLType) + require.Equal(t, "C=ZH2", val.Priv.X509Issuer) + require.Equal(t, "C=ZH1", val.Priv.X509Subject) + require.Equal(t, "IP:127.0.0.1, IP:1.1.1.1, DNS:pingcap.com, URI:spiffe://mesh.pingcap.com/ns/timesh/sa/me1", val.Priv.SAN) + require.Len(t, val.Priv.SANs[util.IP], 2) + require.Equal(t, "pingcap.com", val.Priv.SANs[util.DNS][0]) + require.Equal(t, "spiffe://mesh.pingcap.com/ns/timesh/sa/me1", val.Priv.SANs[util.URI][0]) } func TestLoadDBTable(t *testing.T) { @@ -102,10 +114,10 @@ func TestLoadDBTable(t *testing.T) { tk.MustExec(`INSERT INTO mysql.db (Host, DB, User, Select_priv, Insert_priv, Update_priv, Delete_priv, Create_priv) VALUES ("%", "information_schema", "root", "Y", "Y", "Y", "Y", "Y")`) tk.MustExec(`INSERT INTO mysql.db (Host, DB, User, Drop_priv, Grant_priv, Index_priv, Alter_priv, Create_view_priv, Show_view_priv, Execute_priv) VALUES ("%", "mysql", "root1", "Y", "Y", "Y", "Y", "Y", "Y", "Y")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadDBTable(se.GetRestrictedSQLExecutor())) - require.Len(t, p.DB(), len(p.DBMap)) + require.NoError(t, p.LoadDBTable(se.GetSQLExecutor())) + // require.Len(t, p.DB(), len(p.DBMap)) require.Equal(t, mysql.SelectPriv|mysql.InsertPriv|mysql.UpdatePriv|mysql.DeletePriv|mysql.CreatePriv, p.DB()[0].Privileges) require.Equal(t, mysql.DropPriv|mysql.GrantPriv|mysql.IndexPriv|mysql.AlterPriv|mysql.CreateViewPriv|mysql.ShowViewPriv|mysql.ExecutePriv, p.DB()[1].Privileges) @@ -120,11 +132,11 @@ func TestLoadTablesPrivTable(t *testing.T) { tk.MustExec(`INSERT INTO mysql.tables_priv VALUES ("%", "db", "user", "table", "grantor", "2017-01-04 16:33:42.235831", "Grant,Index,Alter", "Insert,Update")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadTablesPrivTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadTablesPrivTable(se.GetSQLExecutor())) tablesPriv := p.TablesPriv() - require.Len(t, tablesPriv, len(p.TablesPrivMap)) + // require.Len(t, tablesPriv, len(p.TablesPrivMap)) require.Equal(t, `%`, tablesPriv[0].Host) require.Equal(t, "db", tablesPriv[0].DB) @@ -144,9 +156,9 @@ func TestLoadColumnsPrivTable(t *testing.T) { tk.MustExec(`INSERT INTO mysql.columns_priv VALUES ("%", "db", "user", "table", "column", "2017-01-04 16:33:42.235831", "Insert,Update")`) tk.MustExec(`INSERT INTO mysql.columns_priv VALUES ("127.0.0.1", "db", "user", "table", "column", "2017-01-04 16:33:42.235831", "Select")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadColumnsPrivTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadColumnsPrivTable(se.GetSQLExecutor())) columnsPriv := p.ColumnsPriv() require.Equal(t, `%`, columnsPriv[0].Host) require.Equal(t, "db", columnsPriv[0].DB) @@ -166,9 +178,9 @@ func TestLoadDefaultRoleTable(t *testing.T) { tk.MustExec(`INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_1")`) tk.MustExec(`INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_2")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadDefaultRoles(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadDefaultRoles(se.GetSQLExecutor())) require.Equal(t, `%`, p.DefaultRoles()[0].Host) require.Equal(t, "test_default_roles", p.DefaultRoles()[0].User) require.Equal(t, "localhost", p.DefaultRoles()[0].DefaultRoleHost) @@ -185,9 +197,9 @@ func TestPatternMatch(t *testing.T) { tk.MustExec("USE MYSQL;") tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (HOST, USER, Select_priv, Shutdown_priv) VALUES ("10.0.%", "root", "Y", "Y")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) require.True(t, p.RequestVerification(activeRoles, "root", "10.0.1", "test", "", "", mysql.SelectPriv)) require.True(t, p.RequestVerification(activeRoles, "root", "10.0.1.118", "test", "", "", mysql.SelectPriv)) require.False(t, p.RequestVerification(activeRoles, "root", "localhost", "test", "", "", mysql.SelectPriv)) @@ -198,8 +210,8 @@ func TestPatternMatch(t *testing.T) { tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (HOST, USER, Select_priv, Shutdown_priv) VALUES ("", "root", "Y", "N")`) - p = privileges.MySQLPrivilege{} - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + p = privileges.NewMySQLPrivilege() + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) require.True(t, p.RequestVerification(activeRoles, "root", "", "test", "", "", mysql.SelectPriv)) require.False(t, p.RequestVerification(activeRoles, "root", "notnull", "test", "", "", mysql.SelectPriv)) require.False(t, p.RequestVerification(activeRoles, "root", "", "test", "", "", mysql.ShutdownPriv)) @@ -208,7 +220,7 @@ func TestPatternMatch(t *testing.T) { tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec("TRUNCATE TABLE mysql.db") tk.MustExec(`INSERT INTO mysql.db (user,host,db,select_priv) values ('genius', '%', 'te%', 'Y')`) - require.NoError(t, p.LoadDBTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadDBTable(se.GetSQLExecutor())) require.True(t, p.RequestVerification(activeRoles, "genius", "127.0.0.1", "test", "", "", mysql.SelectPriv)) } @@ -222,9 +234,9 @@ func TestHostMatch(t *testing.T) { tk.MustExec("USE MYSQL;") tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (HOST, USER, authentication_string, Select_priv, Shutdown_priv) VALUES ("172.0.0.0/255.0.0.0", "root", "", "Y", "Y")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) require.True(t, p.RequestVerification(activeRoles, "root", "172.0.0.1", "test", "", "", mysql.SelectPriv)) require.True(t, p.RequestVerification(activeRoles, "root", "172.1.1.1", "test", "", "", mysql.SelectPriv)) require.False(t, p.RequestVerification(activeRoles, "root", "localhost", "test", "", "", mysql.SelectPriv)) @@ -248,9 +260,9 @@ func TestHostMatch(t *testing.T) { for _, IPMask := range cases { sql := fmt.Sprintf(`INSERT INTO mysql.user (HOST, USER, Select_priv, Shutdown_priv) VALUES ("%s", "root", "Y", "Y")`, IPMask) tk.MustExec(sql) - p = privileges.MySQLPrivilege{} + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) require.False(t, p.RequestVerification(activeRoles, "root", "127.0.0.1", "test", "", "", mysql.SelectPriv), fmt.Sprintf("test case: %s", IPMask)) require.False(t, p.RequestVerification(activeRoles, "root", "127.0.0.0", "test", "", "", mysql.SelectPriv), fmt.Sprintf("test case: %s", IPMask)) require.False(t, p.RequestVerification(activeRoles, "root", "localhost", "test", "", "", mysql.ShutdownPriv), fmt.Sprintf("test case: %s", IPMask)) @@ -258,8 +270,8 @@ func TestHostMatch(t *testing.T) { // Netmask notation cannot be used for IPv6 addresses. tk.MustExec(`INSERT INTO mysql.user (HOST, USER, Select_priv, Shutdown_priv) VALUES ("2001:db8::/ffff:ffff::", "root", "Y", "Y")`) - p = privileges.MySQLPrivilege{} - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + p = privileges.NewMySQLPrivilege() + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) require.False(t, p.RequestVerification(activeRoles, "root", "2001:db8::1234", "test", "", "", mysql.SelectPriv)) require.False(t, p.RequestVerification(activeRoles, "root", "2001:db8::", "test", "", "", mysql.SelectPriv)) require.False(t, p.RequestVerification(activeRoles, "root", "localhost", "test", "", "", mysql.ShutdownPriv)) @@ -274,9 +286,9 @@ func TestCaseInsensitive(t *testing.T) { tk.MustExec("CREATE TABLE TCTrain.TCTrainOrder (id int);") tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.db VALUES ("127.0.0.1", "TCTrain", "genius", "Y", "Y", "Y", "Y", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N")`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadDBTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadDBTable(se.GetSQLExecutor())) // DB and Table names are case-insensitive in MySQL. require.True(t, p.RequestVerification(activeRoles, "genius", "127.0.0.1", "TCTrain", "TCTrainOrder", "", mysql.SelectPriv)) require.True(t, p.RequestVerification(activeRoles, "genius", "127.0.0.1", "TCTRAIN", "TCTRAINORDER", "", mysql.SelectPriv)) @@ -290,9 +302,9 @@ func TestLoadRoleGraph(t *testing.T) { tk.MustExec("use mysql;") tk.MustExec("truncate table user;") - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadDBTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadDBTable(se.GetSQLExecutor())) require.Len(t, p.User(), 0) tk.MustExec(`INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_1", "%", "user2")`) @@ -300,16 +312,16 @@ func TestLoadRoleGraph(t *testing.T) { tk.MustExec(`INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_3", "%", "user1")`) tk.MustExec(`INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_4", "%", "root")`) - p = privileges.MySQLPrivilege{} - require.NoError(t, p.LoadRoleGraph(se.GetRestrictedSQLExecutor())) + p = privileges.NewMySQLPrivilege() + require.NoError(t, p.LoadRoleGraph(se.GetSQLExecutor())) graph := p.RoleGraph() - require.True(t, graph["root@%"].Find("r_2", "%")) - require.True(t, graph["root@%"].Find("r_4", "%")) - require.True(t, graph["user2@%"].Find("r_1", "%")) - require.True(t, graph["user1@%"].Find("r_3", "%")) - _, ok := graph["illedal"] + require.True(t, graph[auth.RoleIdentity{Username: "root", Hostname: "%"}].Find("r_2", "%")) + require.True(t, graph[auth.RoleIdentity{Username: "root", Hostname: "%"}].Find("r_4", "%")) + require.True(t, graph[auth.RoleIdentity{Username: "user2", Hostname: "%"}].Find("r_1", "%")) + require.True(t, graph[auth.RoleIdentity{Username: "user1", Hostname: "%"}].Find("r_3", "%")) + _, ok := graph[auth.RoleIdentity{Username: "illedal"}] require.False(t, ok) - require.False(t, graph["root@%"].Find("r_1", "%")) + require.False(t, graph[auth.RoleIdentity{Username: "root", Hostname: "%"}].Find("r_1", "%")) } func TestRoleGraphBFS(t *testing.T) { @@ -323,9 +335,9 @@ func TestRoleGraphBFS(t *testing.T) { tk.MustExec(`GRANT r_1 TO r_4;`) tk.MustExec(`GRANT r_5 TO r_3, r_6;`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadRoleGraph(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadRoleGraph(se.GetSQLExecutor())) activeRoles := make([]*auth.RoleIdentity, 0) ret := p.FindAllRole(activeRoles) @@ -357,9 +369,9 @@ func TestFindAllUserEffectiveRoles(t *testing.T) { tk.MustExec(`GRANT r_1 to u1`) tk.MustExec(`GRANT r_2 to u1`) - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadAll(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadAll(se.GetSQLExecutor())) ret := p.FindAllUserEffectiveRoles("u1", "%", []*auth.RoleIdentity{ {Username: "r_1", Hostname: "%"}, {Username: "r_2", Hostname: "%"}, @@ -371,7 +383,7 @@ func TestFindAllUserEffectiveRoles(t *testing.T) { require.Equal(t, "r_4", ret[3].Username) tk.MustExec(`REVOKE r_2 from u1`) - require.NoError(t, p.LoadAll(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadAll(se.GetSQLExecutor())) ret = p.FindAllUserEffectiveRoles("u1", "%", []*auth.RoleIdentity{ {Username: "r_1", Hostname: "%"}, {Username: "r_2", Hostname: "%"}, @@ -382,7 +394,7 @@ func TestFindAllUserEffectiveRoles(t *testing.T) { } func TestSortUserTable(t *testing.T) { - var p privileges.MySQLPrivilege + p := privileges.NewMySQLPrivilege() p.SetUser([]privileges.UserRecord{ privileges.NewUserRecord(`%`, "root"), privileges.NewUserRecord(`%`, "jeffrey"), @@ -391,8 +403,8 @@ func TestSortUserTable(t *testing.T) { }) p.SortUserTable() result := []privileges.UserRecord{ - privileges.NewUserRecord("localhost", "root"), privileges.NewUserRecord("localhost", ""), + privileges.NewUserRecord("localhost", "root"), privileges.NewUserRecord(`%`, "jeffrey"), privileges.NewUserRecord(`%`, "root"), } @@ -453,60 +465,60 @@ func TestDBIsVisible(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("create database visdb") - p := privileges.MySQLPrivilege{} + p := privileges.NewMySQLPrivilege() se := tk.Session() - require.NoError(t, p.LoadAll(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadAll(se.GetSQLExecutor())) tk.MustExec(`INSERT INTO mysql.user (Host, User, Create_role_priv, Super_priv) VALUES ("%", "testvisdb", "Y", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible := p.DBIsVisible("testvisdb", "%", "visdb") require.False(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Select_priv) VALUES ("%", "testvisdb2", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb2", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Create_priv) VALUES ("%", "testvisdb3", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb3", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Insert_priv) VALUES ("%", "testvisdb4", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb4", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Update_priv) VALUES ("%", "testvisdb5", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb5", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Create_view_priv) VALUES ("%", "testvisdb6", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb6", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Trigger_priv) VALUES ("%", "testvisdb7", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb7", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, References_priv) VALUES ("%", "testvisdb8", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb8", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") tk.MustExec(`INSERT INTO mysql.user (Host, User, Execute_priv) VALUES ("%", "testvisdb9", "Y")`) - require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.NoError(t, p.LoadUserTable(se.GetSQLExecutor())) isVisible = p.DBIsVisible("testvisdb9", "%", "visdb") require.True(t, isVisible) tk.MustExec("TRUNCATE TABLE mysql.user") diff --git a/pkg/privilege/privileges/privileges.go b/pkg/privilege/privileges/privileges.go index 5a607e56aa27c..5d52dbe64142e 100644 --- a/pkg/privilege/privileges/privileges.go +++ b/pkg/privilege/privileges/privileges.go @@ -15,6 +15,7 @@ package privileges import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -29,6 +30,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt/openid" "github.com/pingcap/tidb/pkg/extension" "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" @@ -43,6 +45,7 @@ import ( "github.com/pingcap/tidb/pkg/util/hack" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/sqlexec" "go.uber.org/zap" ) @@ -103,6 +106,7 @@ func (p *UserPrivileges) RequestDynamicVerificationWithUser(privName string, gra return false } + terror.Log(p.Handle.ensureActiveUser(context.Background(), user.Username)) mysqlPriv := p.Handle.Get() roles := mysqlPriv.getDefaultRoles(user.Username, user.Hostname) return mysqlPriv.RequestDynamicVerification(roles, user.Username, user.Hostname, privName, grantable) @@ -219,7 +223,7 @@ func (p *UserPrivileges) RequestVerificationWithUser(db, table, column string, p return true } - terror.Log(p.Handle.ensureActiveUser(user.Username)) + terror.Log(p.Handle.ensureActiveUser(context.Background(), user.Username)) mysqlPriv := p.Handle.Get() roles := mysqlPriv.getDefaultRoles(user.Username, user.Hostname) return mysqlPriv.RequestVerification(roles, user.Username, user.Hostname, db, table, column, priv) @@ -315,6 +319,7 @@ func (p *UserPrivileges) isValidHash(record *UserRecord) bool { // GetEncodedPassword implements the Manager interface. func (p *UserPrivileges) GetEncodedPassword(user, host string) string { + terror.Log(p.Handle.ensureActiveUser(context.Background(), user)) mysqlPriv := p.Handle.Get() record := mysqlPriv.connectionVerification(user, host) if record == nil { @@ -334,6 +339,7 @@ func (p *UserPrivileges) GetAuthPluginForConnection(user, host string) (string, return mysql.AuthNativePassword, nil } + terror.Log(p.Handle.ensureActiveUser(context.Background(), user)) mysqlPriv := p.Handle.Get() record := mysqlPriv.connectionVerification(user, host) if record == nil { @@ -364,6 +370,8 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) { if SkipWithGrant { return mysql.AuthNativePassword, nil } + + terror.Log(p.Handle.ensureActiveUser(context.Background(), user)) mysqlPriv := p.Handle.Get() record := mysqlPriv.connectionVerification(user, host) if record == nil { @@ -380,12 +388,12 @@ func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) if SkipWithGrant { return user, host, true } - if err := p.Handle.ensureActiveUser(user); err != nil { + if err := p.Handle.ensureActiveUser(context.Background(), user); err != nil { logutil.BgLogger().Error("ensure user data fail", zap.String("user", user)) } mysqlPriv := p.Handle.Get() - record := mysqlPriv.matchIdentity(p.Handle.sctx, user, host, skipNameResolve) + record := mysqlPriv.matchIdentity(user, host, skipNameResolve) if record != nil { return record.User, record.Host, true } @@ -393,11 +401,16 @@ func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) } // MatchUserResourceGroupName implements the Manager interface. -func (p *UserPrivileges) MatchUserResourceGroupName(resourceGroupName string) (u string, success bool) { - mysqlPriv := p.Handle.Get() - record := mysqlPriv.matchResoureGroup(resourceGroupName) - if record != nil { - return record.User, true +func (p *UserPrivileges) MatchUserResourceGroupName(exec sqlexec.RestrictedSQLExecutor, resourceGroupName string) (u string, success bool) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + sql := "SELECT user FROM mysql.user WHERE json_extract(user_attributes, '$.resource_group') = %? LIMIT 1" + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql, resourceGroupName) + if err != nil { + logutil.BgLogger().Error("execute sql error", zap.String("sql", sql), zap.Error(err)) + return "", false + } + if len(rows) > 0 { + return rows[0].GetString(0), true } return "", false } @@ -907,7 +920,7 @@ func (p *UserPrivileges) UserPrivilegesTable(activeRoles []*auth.RoleIdentity, u } // ShowGrants implements privilege.Manager ShowGrants interface. -func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdentity, roles []*auth.RoleIdentity) (grants []string, err error) { +func (p *UserPrivileges) ShowGrants(ctx context.Context, sctx sessionctx.Context, user *auth.UserIdentity, roles []*auth.RoleIdentity) (grants []string, err error) { if SkipWithGrant { return nil, ErrNonexistingGrant.GenWithStackByArgs("root", "%") } @@ -917,12 +930,11 @@ func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdent u = user.AuthUsername h = user.AuthHostname } - if err := p.Handle.ensureActiveUser(u); err != nil { + if err := p.Handle.ensureActiveUser(ctx, u); err != nil { return nil, err } mysqlPrivilege := p.Handle.Get() - - grants = mysqlPrivilege.showGrants(ctx, u, h, roles) + grants = mysqlPrivilege.showGrants(sctx, u, h, roles) if len(grants) == 0 { err = ErrNonexistingGrant.GenWithStackByArgs(u, h) } @@ -931,31 +943,29 @@ func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdent } // ActiveRoles implements privilege.Manager ActiveRoles interface. -func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) { +func (p *UserPrivileges) ActiveRoles(ctx context.Context, sctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) { if SkipWithGrant { return true, "" } - mysqlPrivilege := p.Handle.Get() u := p.user h := p.host for _, r := range roleList { - ok := mysqlPrivilege.FindRole(u, h, r) + ok := findRole(ctx, p.Handle, u, h, r) if !ok { logutil.BgLogger().Error("find role failed", zap.Stringer("role", r)) return false, r.String() } } - ctx.GetSessionVars().ActiveRoles = roleList + sctx.GetSessionVars().ActiveRoles = roleList return true, "" } // FindEdge implements privilege.Manager FindRelationship interface. -func (p *UserPrivileges) FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool { +func (p *UserPrivileges) FindEdge(ctx context.Context, sctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool { if SkipWithGrant { return false } - mysqlPrivilege := p.Handle.Get() - ok := mysqlPrivilege.FindRole(user.Username, user.Hostname, role) + ok := findRole(ctx, p.Handle, user.Username, user.Hostname, role) if !ok { logutil.BgLogger().Error("find role failed", zap.Stringer("role", role)) return false @@ -968,6 +978,7 @@ func (p *UserPrivileges) GetDefaultRoles(user, host string) []*auth.RoleIdentity if SkipWithGrant { return make([]*auth.RoleIdentity, 0, 10) } + terror.Log(p.Handle.ensureActiveUser(context.Background(), user)) mysqlPrivilege := p.Handle.Get() ret := mysqlPrivilege.getDefaultRoles(user, host) return ret diff --git a/pkg/privilege/privileges/privileges_test.go b/pkg/privilege/privileges/privileges_test.go index e483690852c26..8ff62916d57cf 100644 --- a/pkg/privilege/privileges/privileges_test.go +++ b/pkg/privilege/privileges/privileges_test.go @@ -30,6 +30,7 @@ import ( "time" "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/auth" @@ -1904,7 +1905,7 @@ func TestCheckPasswordExpired(t *testing.T) { sessionVars := variable.NewSessionVars(nil) sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() record := privileges.NewUserRecord("%", "root") - userPrivilege := privileges.NewUserPrivileges(privileges.NewHandle(nil), nil) + userPrivilege := privileges.NewUserPrivileges(privileges.NewHandle(nil, nil), nil) record.PasswordExpired = true _, err := userPrivilege.CheckPasswordExpired(sessionVars, &record) @@ -2091,7 +2092,7 @@ func TestNilHandleInConnectionVerification(t *testing.T) { func testShowGrantsSQLMode(t *testing.T, tk *testkit.TestKit, expected []string) { pc := privilege.GetPrivilegeManager(tk.Session()) - gs, err := pc.ShowGrants(tk.Session(), &auth.UserIdentity{Username: "show_sql_mode", Hostname: "localhost"}, nil) + gs, err := pc.ShowGrants(context.Background(), tk.Session(), &auth.UserIdentity{Username: "show_sql_mode", Hostname: "localhost"}, nil) require.NoError(t, err) require.Len(t, gs, 2) require.True(t, testutil.CompareUnorderedStringSlice(gs, expected), fmt.Sprintf("gs: %v, expected: %v", gs, expected)) @@ -2116,3 +2117,77 @@ func TestShowGrantsSQLMode(t *testing.T) { "GRANT SELECT ON \"test\".* TO 'show_sql_mode'@'localhost'", }) } + +func TestEnsureActiveUserCoverage(t *testing.T) { + store := createStoreAndPrepareDB(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create user 'test'") + tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) + + cases := []struct { + sql string + visited bool + }{ + {"drop user if exists 'test1'", false}, + {"alter user test identified by 'test1'", false}, + {"set password for test = 'test2'", false}, + {"show create user test", false}, + {"create user test1", false}, + {"grant select on test.* to test1", false}, + {"show grants", true}, + {"show grants for 'test'@'%'", true}, + } + + for ith, c := range cases { + var visited bool + ctx := context.WithValue(context.Background(), "mock", &visited) + rs, err := tk.ExecWithContext(ctx, c.sql) + require.NoError(t, err) + + comment := fmt.Sprintf("testcase %d failed", ith) + if rs != nil { + tk.ResultSetToResultWithCtx(ctx, rs, comment) + } + require.Equal(t, c.visited, visited, comment) + } +} + +func TestSQLVariableAccelerateUserCreationUpdate(t *testing.T) { + store := createStoreAndPrepareDB(t) + tk := testkit.NewTestKit(t, store) + dom := domain.GetDomain(tk.Session()) + // 1. check the default variable value + tk.MustQuery("select @@global.tidb_accelerate_user_creation_update").Check(testkit.Rows("0")) + // trigger priv reload + tk.MustExec("create user aaa") + handle := dom.PrivilegeHandle() + handle.CheckFullData(t, true) + priv := handle.Get() + require.False(t, priv.RequestVerification(nil, "bbb", "%", "test", "", "", mysql.SelectPriv)) + + // 2. change the variable and check + tk.MustExec("set @@global.tidb_accelerate_user_creation_update = on") + tk.MustQuery("select @@global.tidb_accelerate_user_creation_update").Check(testkit.Rows("1")) + require.True(t, variable.AccelerateUserCreationUpdate.Load()) + tk.MustExec("create user bbb") + handle.CheckFullData(t, false) + // trigger priv reload, but data for bbb is not really loaded + tk.MustExec("grant select on test.* to bbb") + priv = handle.Get() + // data for bbb is not loaded, because that user is not active + // So this is **counterintuitive**, but it's still the expected behavior. + require.False(t, priv.RequestVerification(nil, "bbb", "%", "test", "", "", mysql.SelectPriv)) + tk1 := testkit.NewTestKit(t, store) + // if user bbb login, everything works as expected + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "bbb", Hostname: "localhost"}, nil, nil, nil)) + priv = handle.Get() + require.True(t, priv.RequestVerification(nil, "bbb", "%", "test", "", "", mysql.SelectPriv)) + + // 3. change the variable and check again + tk.MustExec("set @@global.tidb_accelerate_user_creation_update = off") + tk.MustQuery("select @@global.tidb_accelerate_user_creation_update").Check(testkit.Rows("0")) + tk.MustExec("drop user aaa") + handle.CheckFullData(t, true) + priv = handle.Get() + require.True(t, priv.RequestVerification(nil, "bbb", "%", "test", "", "", mysql.SelectPriv)) +} diff --git a/pkg/privilege/privileges/tidb_auth_token_test.go b/pkg/privilege/privileges/tidb_auth_token_test.go index d03190bc7a15f..1612d52774015 100644 --- a/pkg/privilege/privileges/tidb_auth_token_test.go +++ b/pkg/privilege/privileges/tidb_auth_token_test.go @@ -20,6 +20,7 @@ import ( "fmt" "log" "os" + "slices" "strings" "testing" "time" @@ -29,6 +30,7 @@ import ( jwsRepo "github.com/lestrrat-go/jwx/v2/jws" jwtRepo "github.com/lestrrat-go/jwx/v2/jwt" "github.com/lestrrat-go/jwx/v2/jwt/openid" + "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/util/hack" "github.com/stretchr/testify/require" ) @@ -418,30 +420,75 @@ func TestJWKSImpl(t *testing.T) { require.Error(t, err) } -func (p *immutable) User() []UserRecord { - return p.user +func (p *MySQLPrivilege) User() []UserRecord { + var ret []UserRecord + p.user.Ascend(func(itm itemUser) bool { + ret = append(ret, itm.data...) + return true + }) + slices.SortStableFunc(ret, compareUserRecord) + return ret } -func (p *immutable) SetUser(user []UserRecord) { - p.user = user +func (p *MySQLPrivilege) SetUser(user []UserRecord) { + p.user.Clear(false) + for _, u := range user { + old, exists := p.user.Get(itemUser{username: u.User}) + if !exists { + old.username = u.User + } + old.data = append(old.data, u) + p.user.ReplaceOrInsert(old) + } +} + +func (p *MySQLPrivilege) DB() []dbRecord { + var ret []dbRecord + p.db.Ascend(func(itm itemDB) bool { + ret = append(ret, itm.data...) + return true + }) + return ret } -func (p *immutable) DB() []dbRecord { - return p.db +func (p *MySQLPrivilege) TablesPriv() []tablesPrivRecord { + var ret []tablesPrivRecord + p.tablesPriv.Ascend(func(itm itemTablesPriv) bool { + ret = append(ret, itm.data...) + return true + }) + return ret } -func (p *immutable) TablesPriv() []tablesPrivRecord { - return p.tablesPriv +func (p *MySQLPrivilege) ColumnsPriv() []columnsPrivRecord { + var ret []columnsPrivRecord + p.columnsPriv.Ascend(func(itm itemColumnsPriv) bool { + ret = append(ret, itm.data...) + return true + }) + return ret } -func (p *immutable) ColumnsPriv() []columnsPrivRecord { - return p.columnsPriv +func (p *MySQLPrivilege) DefaultRoles() []defaultRoleRecord { + var ret []defaultRoleRecord + p.defaultRoles.Ascend(func(itm itemDefaultRole) bool { + ret = append(ret, itm.data...) + return true + }) + return ret } -func (p *immutable) DefaultRoles() []defaultRoleRecord { - return p.defaultRoles +func (p *MySQLPrivilege) GlobalPriv(user string) []globalPrivRecord { + ret, _ := p.globalPriv.Get(itemGlobalPriv{username: user}) + return ret.data } -func (p *immutable) RoleGraph() map[string]roleGraphEdgesTable { +func (p *MySQLPrivilege) RoleGraph() map[auth.RoleIdentity]roleGraphEdgesTable { return p.roleGraph } + +func (h *Handle) CheckFullData(t *testing.T, value bool) { + require.True(t, h.fullData.Load() == value) +} + +var NewMySQLPrivilege = newMySQLPrivilege diff --git a/pkg/server/tests/commontest/tidb_test.go b/pkg/server/tests/commontest/tidb_test.go index 3e8e28446ec64..8bc38e57af1d1 100644 --- a/pkg/server/tests/commontest/tidb_test.go +++ b/pkg/server/tests/commontest/tidb_test.go @@ -2666,7 +2666,7 @@ func TestSandBoxMode(t *testing.T) { require.NoError(t, err) _, err = Execute(context.Background(), qctx, "create user testuser;") require.NoError(t, err) - qctx.Session.GetSessionVars().User = &auth.UserIdentity{Username: "testuser", AuthUsername: "testuser", AuthHostname: "%"} + qctx.Session.Auth(&auth.UserIdentity{Username: "testuser", AuthUsername: "testuser", AuthHostname: "%"}, nil, nil, nil) alterPwdStmts := []string{ "set password = '1234';", diff --git a/pkg/session/bootstrap.go b/pkg/session/bootstrap.go index f68458ee18b3f..f26f2c211c701 100644 --- a/pkg/session/bootstrap.go +++ b/pkg/session/bootstrap.go @@ -115,14 +115,15 @@ const ( Password_expired ENUM('N','Y') NOT NULL DEFAULT 'N', Password_last_changed TIMESTAMP DEFAULT CURRENT_TIMESTAMP(), Password_lifetime SMALLINT UNSIGNED DEFAULT NULL, - PRIMARY KEY (Host, User));` + PRIMARY KEY (Host, User), + KEY i_user (User));` // CreateGlobalPrivTable is the SQL statement creates Global scope privilege table in system db. CreateGlobalPrivTable = "CREATE TABLE IF NOT EXISTS mysql.global_priv (" + "Host CHAR(255) NOT NULL DEFAULT ''," + "User CHAR(80) NOT NULL DEFAULT ''," + "Priv LONGTEXT NOT NULL DEFAULT ''," + - "PRIMARY KEY (Host, User)" + - ")" + "PRIMARY KEY (Host, User)," + + "KEY i_user (User))" // For `mysql.db`, `mysql.tables_priv` and `mysql.columns_priv` table, we have a slight different // schema definition with MySQL: columns `DB`/`Table_name`/`Column_name` are defined with case-insensitive @@ -160,7 +161,8 @@ const ( Execute_priv ENUM('N','Y') NOT NULL DEFAULT 'N', Event_priv ENUM('N','Y') NOT NULL DEFAULT 'N', Trigger_priv ENUM('N','Y') NOT NULL DEFAULT 'N', - PRIMARY KEY (Host, DB, User));` + PRIMARY KEY (Host, DB, User), + KEY i_user (User));` // CreateTablePrivTable is the SQL statement creates table scope privilege table in system db. CreateTablePrivTable = `CREATE TABLE IF NOT EXISTS mysql.tables_priv ( Host CHAR(255), @@ -171,7 +173,8 @@ const ( Timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, Table_priv SET('Select','Insert','Update','Delete','Create','Drop','Grant','Index','Alter','Create View','Show View','Trigger','References'), Column_priv SET('Select','Insert','Update','References'), - PRIMARY KEY (Host, DB, User, Table_name));` + PRIMARY KEY (Host, DB, User, Table_name), + KEY i_user (User));` // CreateColumnPrivTable is the SQL statement creates column scope privilege table in system db. CreateColumnPrivTable = `CREATE TABLE IF NOT EXISTS mysql.columns_priv( Host CHAR(255), @@ -181,7 +184,8 @@ const ( Column_name CHAR(64) CHARSET utf8mb4 COLLATE utf8mb4_general_ci, Timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, Column_priv SET('Select','Insert','Update','References'), - PRIMARY KEY (Host, DB, User, Table_name, Column_name));` + PRIMARY KEY (Host, DB, User, Table_name, Column_name), + KEY i_user (User));` // CreateGlobalVariablesTable is the SQL statement creates global variable table in system db. // TODO: MySQL puts GLOBAL_VARIABLES table in INFORMATION_SCHEMA db. // INFORMATION_SCHEMA is a virtual db in TiDB. So we put this table in system db. @@ -317,8 +321,8 @@ const ( USER CHAR(32) COLLATE utf8_bin NOT NULL DEFAULT '', DEFAULT_ROLE_HOST CHAR(60) COLLATE utf8_bin NOT NULL DEFAULT '%', DEFAULT_ROLE_USER CHAR(32) COLLATE utf8_bin NOT NULL DEFAULT '', - PRIMARY KEY (HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) - )` + PRIMARY KEY (HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER), + KEY i_user (USER))` // CreateStatsTopNTable stores topn data of a cmsketch with top n. CreateStatsTopNTable = `CREATE TABLE IF NOT EXISTS mysql.stats_top_n ( @@ -380,7 +384,8 @@ const ( HOST char(255) NOT NULL DEFAULT '', PRIV char(32) NOT NULL DEFAULT '', WITH_GRANT_OPTION enum('N','Y') NOT NULL DEFAULT 'N', - PRIMARY KEY (USER,HOST,PRIV) + PRIMARY KEY (USER,HOST,PRIV), + KEY i_user (USER) );` // CreateCapturePlanBaselinesBlacklist stores the baseline capture filter rules. CreateCapturePlanBaselinesBlacklist = `CREATE TABLE IF NOT EXISTS mysql.capture_plan_baselines_blacklist ( @@ -1238,16 +1243,17 @@ const ( // add modify_params to tidb_global_task and tidb_global_task_history. version223 = 223 + // Add index on user field for some mysql tables. + version224 = 224 + // ... // [version223, version238] is the version range reserved for patches of 8.5.x // ... - - // next version should start with 239 ) // currentBootstrapVersion is defined as a variable, so we can modify its value for testing. // please make sure this is the largest version -var currentBootstrapVersion int64 = version223 +var currentBootstrapVersion int64 = version224 // DDL owner key's expired time is ManagerSessionTTL seconds, we should wait the time and give more time to have a chance to finish it. var internalSQLTimeout = owner.ManagerSessionTTL + 15 @@ -1426,6 +1432,7 @@ var ( upgradeToVer221, upgradeToVer222, upgradeToVer223, + upgradeToVer224, } ) @@ -3302,6 +3309,19 @@ func upgradeToVer223(s sessiontypes.Session, ver int64) { doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task_history ADD COLUMN modify_params json AFTER `error`;", infoschema.ErrColumnExists) } +func upgradeToVer224(s sessiontypes.Session, ver int64) { + if ver >= version224 { + return + } + doReentrantDDL(s, "ALTER TABLE mysql.user ADD INDEX i_user (user)", dbterror.ErrDupKeyName) + doReentrantDDL(s, "ALTER TABLE mysql.global_priv ADD INDEX i_user (user)", dbterror.ErrDupKeyName) + doReentrantDDL(s, "ALTER TABLE mysql.db ADD INDEX i_user (user)", dbterror.ErrDupKeyName) + doReentrantDDL(s, "ALTER TABLE mysql.tables_priv ADD INDEX i_user (user)", dbterror.ErrDupKeyName) + doReentrantDDL(s, "ALTER TABLE mysql.columns_priv ADD INDEX i_user (user)", dbterror.ErrDupKeyName) + doReentrantDDL(s, "ALTER TABLE mysql.global_grants ADD INDEX i_user (user)", dbterror.ErrDupKeyName) + doReentrantDDL(s, "ALTER TABLE mysql.default_roles ADD INDEX i_user (user)", dbterror.ErrDupKeyName) +} + // initGlobalVariableIfNotExists initialize a global variable with specific val if it does not exist. func initGlobalVariableIfNotExists(s sessiontypes.Session, name string, val any) { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) diff --git a/pkg/session/bootstraptest/bootstrap_upgrade_test.go b/pkg/session/bootstraptest/bootstrap_upgrade_test.go index 3070aeee5097a..d60d341966210 100644 --- a/pkg/session/bootstraptest/bootstrap_upgrade_test.go +++ b/pkg/session/bootstraptest/bootstrap_upgrade_test.go @@ -460,7 +460,7 @@ func TestUpgradeVersionForPausedJob(t *testing.T) { // checkDDLJobExecSucc is used to make sure the DDL operation is successful. func checkDDLJobExecSucc(t *testing.T, se sessiontypes.Session, jobID int64) { - sql := fmt.Sprintf(" admin show ddl jobs where job_id=%d", jobID) + sql := fmt.Sprintf(" admin show ddl jobs 20 where job_id=%d", jobID) suc := false for i := 0; i < 20; i++ { rows, err := execute(context.Background(), se, sql) diff --git a/pkg/session/session.go b/pkg/session/session.go index a511d8c5e7ae1..2ec07249815aa 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2798,7 +2798,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication, salt []byte, aut } if lockStatusChanged { // Notification auto unlock. - err = domain.GetDomain(s).NotifyUpdatePrivilege() + err = domain.GetDomain(s).NotifyUpdatePrivilege([]string{authUser.Username}) if err != nil { return err } @@ -2972,7 +2972,7 @@ func authFailedTracking(s *session, user string, host string) (bool, *privileges func autolockAction(s *session, passwordLocking *privileges.PasswordLocking, user, host string) error { // Don't want to update the cache frequently, and only trigger the update cache when the lock status is updated. - err := domain.GetDomain(s).NotifyUpdatePrivilege() + err := domain.GetDomain(s).NotifyUpdatePrivilege([]string{user}) if err != nil { return err } diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index 3d691f9fede86..238dde11f99d2 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -3584,6 +3584,12 @@ var defaultSysVars = []*SysVar{ return nil }, }, + {Scope: ScopeGlobal, Name: TiDBAccelerateUserCreationUpdate, Value: BoolToOnOff(DefTiDBAccelerateUserCreationUpdate), Type: TypeBool, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + AccelerateUserCreationUpdate.Store(TiDBOptOn(val)) + return nil + }, + }, } // GlobalSystemVariableInitialValue gets the default value for a system variable including ones that are dynamically set (e.g. based on the store) diff --git a/pkg/sessionctx/variable/tidb_vars.go b/pkg/sessionctx/variable/tidb_vars.go index 3c8006acea7e0..a9f7669ef77bf 100644 --- a/pkg/sessionctx/variable/tidb_vars.go +++ b/pkg/sessionctx/variable/tidb_vars.go @@ -1257,6 +1257,9 @@ const ( // TiDBCircuitBreakerPDMetadataErrorRateThresholdRatio variable is used to set ratio of errors to trip the circuit breaker for get region calls to PD // https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md TiDBCircuitBreakerPDMetadataErrorRateThresholdRatio = "tidb_cb_pd_metadata_error_rate_threshold_ratio" + + // TiDBAccelerateUserCreationUpdate decides whether tidb will load & update the whole user's data in-memory. + TiDBAccelerateUserCreationUpdate = "tidb_accelerate_user_creation_update" ) // TiDB intentional limits @@ -1631,6 +1634,7 @@ const ( DefTiDBAdvancerCheckPointLagLimit = 48 * time.Hour DefTiDBIndexLookUpPushDownPolicy = IndexLookUpPushDownPolicyHintOnly DefTiDBCircuitBreakerPDMetaErrorRateRatio = 0.0 + DefTiDBAccelerateUserCreationUpdate = false ) // Process global variables. @@ -1756,6 +1760,7 @@ var ( AdvancerCheckPointLagLimit = atomic.NewDuration(DefTiDBAdvancerCheckPointLagLimit) CircuitBreakerPDMetadataErrorRateThresholdRatio = atomic.NewFloat64(0.0) + AccelerateUserCreationUpdate = atomic.NewBool(DefTiDBAccelerateUserCreationUpdate) ) var ( diff --git a/tests/integrationtest/r/explain.result b/tests/integrationtest/r/explain.result index 6532a554d4fa0..520871ac7c81b 100644 --- a/tests/integrationtest/r/explain.result +++ b/tests/integrationtest/r/explain.result @@ -57,3 +57,28 @@ create view v as select cast(replace(substring_index(substring_index("",',',1),' desc v; Field Type Null Key Default Extra event_id varchar(32) NO NULL +explain format = 'brief' select * from mysql.user where user = 'xxx'; +id estRows task access object operator info +IndexLookUp 10.00 root +├─IndexRangeScan(Build) 10.00 cop[tikv] table:user, index:i_user(User) range:["xxx","xxx"], keep order:false, stats:pseudo +└─TableRowIDScan(Probe) 10.00 cop[tikv] table:user keep order:false, stats:pseudo +explain format = 'brief' select * from mysql.user where user = 'xxx' or user = 'yyy'; +id estRows task access object operator info +IndexLookUp 20.00 root +├─IndexRangeScan(Build) 20.00 cop[tikv] table:user, index:i_user(User) range:["xxx","xxx"], ["yyy","yyy"], keep order:false, stats:pseudo +└─TableRowIDScan(Probe) 20.00 cop[tikv] table:user keep order:false, stats:pseudo +explain format = 'brief' select * from mysql.global_priv where user = 'xxx'; +id estRows task access object operator info +IndexLookUp 10.00 root +├─IndexRangeScan(Build) 10.00 cop[tikv] table:global_priv, index:i_user(User) range:["xxx","xxx"], keep order:false, stats:pseudo +└─TableRowIDScan(Probe) 10.00 cop[tikv] table:global_priv keep order:false, stats:pseudo +explain format = 'brief' select * from mysql.global_grants where user = 'xxx' or user = 'yyy'; +id estRows task access object operator info +IndexLookUp 20.00 root +├─IndexRangeScan(Build) 20.00 cop[tikv] table:global_grants, index:i_user(USER) range:["xxx","xxx"], ["yyy","yyy"], keep order:false, stats:pseudo +└─TableRowIDScan(Probe) 20.00 cop[tikv] table:global_grants keep order:false, stats:pseudo +explain format = 'brief' select * from mysql.db where user = 'xxx'; +id estRows task access object operator info +IndexLookUp 10.00 root +├─IndexRangeScan(Build) 10.00 cop[tikv] table:db, index:i_user(User) range:["xxx","xxx"], keep order:false, stats:pseudo +└─TableRowIDScan(Probe) 10.00 cop[tikv] table:db keep order:false, stats:pseudo diff --git a/tests/integrationtest/t/explain.test b/tests/integrationtest/t/explain.test index ed679d54c199f..a1e63396e4504 100644 --- a/tests/integrationtest/t/explain.test +++ b/tests/integrationtest/t/explain.test @@ -21,3 +21,10 @@ drop table t; drop view if exists v; create view v as select cast(replace(substring_index(substring_index("",',',1),':',-1),'"','') as CHAR(32)) as event_id; desc v; + +# should use index lookup after adding user index, table scan is not expected +explain format = 'brief' select * from mysql.user where user = 'xxx'; +explain format = 'brief' select * from mysql.user where user = 'xxx' or user = 'yyy'; +explain format = 'brief' select * from mysql.global_priv where user = 'xxx'; +explain format = 'brief' select * from mysql.global_grants where user = 'xxx' or user = 'yyy'; +explain format = 'brief' select * from mysql.db where user = 'xxx'; \ No newline at end of file