From c5df916939319384f306e2731e74ecca110c4018 Mon Sep 17 00:00:00 2001 From: luankz <17682333171@163.com> Date: Tue, 26 Aug 2025 13:11:01 +0800 Subject: [PATCH 1/3] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20RenewAccessToke?= =?UTF-8?q?n=20=E6=9C=AA=E6=9B=B4=E6=96=B0=20redis=20ttl?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/global.go | 2 +- session/memory.go | 4 ++ session/middleware_builder.go | 90 +++++++++++++++++++++++++++++++---- session/redis/provider.go | 9 +++- session/redis/session.go | 4 ++ session/types.go | 2 + 6 files changed, 100 insertions(+), 11 deletions(-) diff --git a/session/global.go b/session/global.go index 55caef9..2fc8069 100644 --- a/session/global.go +++ b/session/global.go @@ -49,7 +49,7 @@ func DefaultProvider() Provider { } func CheckLoginMiddleware() gin.HandlerFunc { - return (&MiddlewareBuilder{sp: defaultProvider, Threshold: time.Minute * 30}).Build() + return NewMiddlewareBuilder(defaultProvider, WithThreshold(time.Minute*30)).Build() } func RenewAccessToken(ctx *gctx.Context) error { diff --git a/session/memory.go b/session/memory.go index c12541f..e17cdc3 100644 --- a/session/memory.go +++ b/session/memory.go @@ -67,3 +67,7 @@ func (m *MemorySession) Get(ctx context.Context, key string) ekit.AnyValue { func (m *MemorySession) Claims() Claims { return m.claims } + +func (m *MemorySession) Expire(ctx context.Context) error { + return nil +} diff --git a/session/middleware_builder.go b/session/middleware_builder.go index ae79b5b..900af8a 100644 --- a/session/middleware_builder.go +++ b/session/middleware_builder.go @@ -17,37 +17,109 @@ package session import ( "log/slog" "net/http" + "sync" "time" "github.com/ecodeclub/ginx/gctx" "github.com/gin-gonic/gin" ) +// MiddlewareOption 定义 middleware 的配置选项 +type MiddlewareOption func(*MiddlewareBuilder) + +// WithThreshold 设置续期阈值 +func WithThreshold(threshold time.Duration) MiddlewareOption { + return func(b *MiddlewareBuilder) { + b.threshold = threshold + } +} + +// WithConcurrencyControl 启用并发控制 +func WithConcurrencyControl() MiddlewareOption { + return func(b *MiddlewareBuilder) { + b.enableConcurrencyControl = true + } +} + +// WithLogger 设置自定义日志器 +func WithLogger(logger *slog.Logger) MiddlewareOption { + return func(b *MiddlewareBuilder) { + b.logger = logger + } +} + // MiddlewareBuilder 登录校验 type MiddlewareBuilder struct { sp Provider // 当 token 的有效时间少于这个值的时候,就会刷新一下 token - Threshold time.Duration + threshold time.Duration + logger *slog.Logger + + // 用于并发控制的锁 + enableConcurrencyControl bool + renewalLocks sync.Map +} + +func NewMiddlewareBuilder(sp Provider, opts ...MiddlewareOption) *MiddlewareBuilder { + builder := &MiddlewareBuilder{ + sp: sp, + threshold: time.Minute * 30, + enableConcurrencyControl: false, + logger: slog.Default(), + } + + // 应用选项 + for _, opt := range opts { + opt(builder) + } + + return builder } func (b *MiddlewareBuilder) Build() gin.HandlerFunc { - threshold := b.Threshold.Milliseconds() + threshold := b.threshold.Milliseconds() return func(ctx *gin.Context) { - ctxx := &gctx.Context{Context: ctx} - sess, err := b.sp.Get(ctxx) + gCtx := &gctx.Context{Context: ctx} + sess, err := b.sp.Get(gCtx) if err != nil { - slog.Debug("未授权", slog.Any("err", err)) + b.logger.Debug("未授权", slog.Any("err", err)) ctx.AbortWithStatus(http.StatusUnauthorized) return } expiration := sess.Claims().Expiration if expiration-time.Now().UnixMilli() < threshold { - // 刷新一个token - err = b.sp.RenewAccessToken(ctxx) - if err != nil { - slog.Warn("刷新 token 失败", slog.String("err", err.Error())) + // 如果需要并发控制,使用锁机制 + if b.enableConcurrencyControl { + b.renewWithConcurrencyControl(gCtx, sess.Claims().SSID) + } else { + // 直接续期 + err = b.sp.RenewAccessToken(gCtx) + if err != nil { + b.logger.Warn("刷新 token 失败", slog.String("err", err.Error())) + } } } ctx.Set(CtxSessionKey, sess) } } + +// renewWithConcurrencyControl 使用并发控制进行续期 +func (b *MiddlewareBuilder) renewWithConcurrencyControl(ctx *gctx.Context, ssid string) { + // 获取或创建该 SSID 的锁 + lockInterface, _ := b.renewalLocks.LoadOrStore(ssid, &sync.Mutex{}) + lock := lockInterface.(*sync.Mutex) + + // 尝试获取锁 + if lock.TryLock() { + defer lock.Unlock() + + // 执行续期 + err := b.sp.RenewAccessToken(ctx) + if err != nil { + b.logger.Warn("刷新 token 失败", slog.String("err", err.Error())) + } + } else { + // 如果获取不到锁,说明其他请求正在处理续期,跳过本次续期 + b.logger.Debug("跳过续期,其他请求正在处理", slog.String("ssid", ssid)) + } +} diff --git a/session/redis/provider.go b/session/redis/provider.go index fa80b24..2e77cc0 100644 --- a/session/redis/provider.go +++ b/session/redis/provider.go @@ -67,10 +67,17 @@ func (rsp *SessionProvider) RenewAccessToken(ctx *ginx.Context) error { if err != nil { return err } + + // 更新 claims 中的过期时间为当前时间加上配置的过期时间 claims := jwtClaims.Data + claims.Expiration = time.Now().Add(rsp.expiration).UnixMilli() + + // 生成新的 access token accessToken, err := rsp.m.GenerateAccessToken(claims) rsp.TokenCarrier.Inject(ctx, accessToken) - return err + + // redis 进行续期 + return newRedisSession(claims.SSID, rsp.expiration, rsp.client, claims).Expire(ctx) } // NewSession 的时候,要先把这个 data 写入到对应的 token 里面 diff --git a/session/redis/session.go b/session/redis/session.go index 327269e..22bd68f 100644 --- a/session/redis/session.go +++ b/session/redis/session.go @@ -70,6 +70,10 @@ func (sess *Session) Claims() session.Claims { return sess.claims } +func (sess *Session) Expire(ctx context.Context) error { + return sess.client.Expire(ctx, sess.key, sess.expiration).Err() +} + func newRedisSession( ssid string, expiration time.Duration, diff --git a/session/types.go b/session/types.go index f734683..46c2720 100644 --- a/session/types.go +++ b/session/types.go @@ -34,6 +34,8 @@ type Session interface { Destroy(ctx context.Context) error // Claims 编码进去了 JWT 里面的数据 Claims() Claims + // Expire session 续期 + Expire(ctx context.Context) error } // Provider 定义了 Session 的整个管理机制。 From f57400bf258190e2773debf86a7731b937e2000a Mon Sep 17 00:00:00 2001 From: luankz <17682333171@163.com> Date: Wed, 3 Sep 2025 21:00:43 +0800 Subject: [PATCH 2/3] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E5=A4=9A?= =?UTF-8?q?=E4=BD=99=E4=B8=94=E5=B9=B6=E4=B8=8D=E4=BC=98=E9=9B=85=E7=9A=84?= =?UTF-8?q?=E8=AE=BE=E8=AE=A1=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/global.go | 2 +- session/memory.go | 4 -- session/middleware_builder.go | 90 ++++------------------------------- session/redis/provider.go | 4 +- session/redis/session.go | 6 ++- session/types.go | 2 - 6 files changed, 17 insertions(+), 91 deletions(-) diff --git a/session/global.go b/session/global.go index 2fc8069..55caef9 100644 --- a/session/global.go +++ b/session/global.go @@ -49,7 +49,7 @@ func DefaultProvider() Provider { } func CheckLoginMiddleware() gin.HandlerFunc { - return NewMiddlewareBuilder(defaultProvider, WithThreshold(time.Minute*30)).Build() + return (&MiddlewareBuilder{sp: defaultProvider, Threshold: time.Minute * 30}).Build() } func RenewAccessToken(ctx *gctx.Context) error { diff --git a/session/memory.go b/session/memory.go index e17cdc3..c12541f 100644 --- a/session/memory.go +++ b/session/memory.go @@ -67,7 +67,3 @@ func (m *MemorySession) Get(ctx context.Context, key string) ekit.AnyValue { func (m *MemorySession) Claims() Claims { return m.claims } - -func (m *MemorySession) Expire(ctx context.Context) error { - return nil -} diff --git a/session/middleware_builder.go b/session/middleware_builder.go index 900af8a..ae79b5b 100644 --- a/session/middleware_builder.go +++ b/session/middleware_builder.go @@ -17,109 +17,37 @@ package session import ( "log/slog" "net/http" - "sync" "time" "github.com/ecodeclub/ginx/gctx" "github.com/gin-gonic/gin" ) -// MiddlewareOption 定义 middleware 的配置选项 -type MiddlewareOption func(*MiddlewareBuilder) - -// WithThreshold 设置续期阈值 -func WithThreshold(threshold time.Duration) MiddlewareOption { - return func(b *MiddlewareBuilder) { - b.threshold = threshold - } -} - -// WithConcurrencyControl 启用并发控制 -func WithConcurrencyControl() MiddlewareOption { - return func(b *MiddlewareBuilder) { - b.enableConcurrencyControl = true - } -} - -// WithLogger 设置自定义日志器 -func WithLogger(logger *slog.Logger) MiddlewareOption { - return func(b *MiddlewareBuilder) { - b.logger = logger - } -} - // MiddlewareBuilder 登录校验 type MiddlewareBuilder struct { sp Provider // 当 token 的有效时间少于这个值的时候,就会刷新一下 token - threshold time.Duration - logger *slog.Logger - - // 用于并发控制的锁 - enableConcurrencyControl bool - renewalLocks sync.Map -} - -func NewMiddlewareBuilder(sp Provider, opts ...MiddlewareOption) *MiddlewareBuilder { - builder := &MiddlewareBuilder{ - sp: sp, - threshold: time.Minute * 30, - enableConcurrencyControl: false, - logger: slog.Default(), - } - - // 应用选项 - for _, opt := range opts { - opt(builder) - } - - return builder + Threshold time.Duration } func (b *MiddlewareBuilder) Build() gin.HandlerFunc { - threshold := b.threshold.Milliseconds() + threshold := b.Threshold.Milliseconds() return func(ctx *gin.Context) { - gCtx := &gctx.Context{Context: ctx} - sess, err := b.sp.Get(gCtx) + ctxx := &gctx.Context{Context: ctx} + sess, err := b.sp.Get(ctxx) if err != nil { - b.logger.Debug("未授权", slog.Any("err", err)) + slog.Debug("未授权", slog.Any("err", err)) ctx.AbortWithStatus(http.StatusUnauthorized) return } expiration := sess.Claims().Expiration if expiration-time.Now().UnixMilli() < threshold { - // 如果需要并发控制,使用锁机制 - if b.enableConcurrencyControl { - b.renewWithConcurrencyControl(gCtx, sess.Claims().SSID) - } else { - // 直接续期 - err = b.sp.RenewAccessToken(gCtx) - if err != nil { - b.logger.Warn("刷新 token 失败", slog.String("err", err.Error())) - } + // 刷新一个token + err = b.sp.RenewAccessToken(ctxx) + if err != nil { + slog.Warn("刷新 token 失败", slog.String("err", err.Error())) } } ctx.Set(CtxSessionKey, sess) } } - -// renewWithConcurrencyControl 使用并发控制进行续期 -func (b *MiddlewareBuilder) renewWithConcurrencyControl(ctx *gctx.Context, ssid string) { - // 获取或创建该 SSID 的锁 - lockInterface, _ := b.renewalLocks.LoadOrStore(ssid, &sync.Mutex{}) - lock := lockInterface.(*sync.Mutex) - - // 尝试获取锁 - if lock.TryLock() { - defer lock.Unlock() - - // 执行续期 - err := b.sp.RenewAccessToken(ctx) - if err != nil { - b.logger.Warn("刷新 token 失败", slog.String("err", err.Error())) - } - } else { - // 如果获取不到锁,说明其他请求正在处理续期,跳过本次续期 - b.logger.Debug("跳过续期,其他请求正在处理", slog.String("ssid", ssid)) - } -} diff --git a/session/redis/provider.go b/session/redis/provider.go index 2e77cc0..d18f8c9 100644 --- a/session/redis/provider.go +++ b/session/redis/provider.go @@ -76,8 +76,8 @@ func (rsp *SessionProvider) RenewAccessToken(ctx *ginx.Context) error { accessToken, err := rsp.m.GenerateAccessToken(claims) rsp.TokenCarrier.Inject(ctx, accessToken) - // redis 进行续期 - return newRedisSession(claims.SSID, rsp.expiration, rsp.client, claims).Expire(ctx) + // Redis 续期 Session + return rsp.client.Expire(ctx, sessionKey(claims.SSID), rsp.expiration).Err() } // NewSession 的时候,要先把这个 data 写入到对应的 token 里面 diff --git a/session/redis/session.go b/session/redis/session.go index 22bd68f..1f8a513 100644 --- a/session/redis/session.go +++ b/session/redis/session.go @@ -80,8 +80,12 @@ func newRedisSession( client redis.Cmdable, cl session.Claims) *Session { return &Session{ client: client, - key: "session:" + ssid, + key: sessionKey(ssid), expiration: expiration, claims: cl, } } + +func sessionKey(ssid string) string { + return "session:" + ssid +} diff --git a/session/types.go b/session/types.go index 46c2720..f734683 100644 --- a/session/types.go +++ b/session/types.go @@ -34,8 +34,6 @@ type Session interface { Destroy(ctx context.Context) error // Claims 编码进去了 JWT 里面的数据 Claims() Claims - // Expire session 续期 - Expire(ctx context.Context) error } // Provider 定义了 Session 的整个管理机制。 From 1f057e31bda398bc0090e1ff5779c78a38f9dc3f Mon Sep 17 00:00:00 2001 From: luankz <17682333171@163.com> Date: Wed, 3 Sep 2025 21:01:30 +0800 Subject: [PATCH 3/3] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E5=A4=9A?= =?UTF-8?q?=E4=BD=99=E4=B8=94=E5=B9=B6=E4=B8=8D=E4=BC=98=E9=9B=85=E7=9A=84?= =?UTF-8?q?=E8=AE=BE=E8=AE=A1=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/redis/session.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/session/redis/session.go b/session/redis/session.go index 1f8a513..01964a9 100644 --- a/session/redis/session.go +++ b/session/redis/session.go @@ -70,10 +70,6 @@ func (sess *Session) Claims() session.Claims { return sess.claims } -func (sess *Session) Expire(ctx context.Context) error { - return sess.client.Expire(ctx, sess.key, sess.expiration).Err() -} - func newRedisSession( ssid string, expiration time.Duration,