diff --git a/src/actions/active-sessions.ts b/src/actions/active-sessions.ts index e03cb7731..65fc63deb 100644 --- a/src/actions/active-sessions.ts +++ b/src/actions/active-sessions.ts @@ -854,9 +854,9 @@ export async function terminateActiveSession(sessionId: string): Promise | null { } } +export class TerminatedSessionError extends Error { + constructor( + public readonly sessionId: string, + public readonly terminatedAt: string | null = null + ) { + // 注意:此错误的 message 不应作为用户可见文案;用户提示由 HTTP 层/ProxySessionGuard 统一映射。 + super("ERR_TERMINATED_SESSION"); + this.name = "TerminatedSessionError"; + } +} + +export type TerminateSessionResult = { + markerOk: boolean; + deletedKeys: number; +}; + type SessionRequestMeta = { url: string; method: string; @@ -105,6 +122,76 @@ export class SessionManager { ); // 短上下文阈值 private static readonly ENABLE_SHORT_CONTEXT_DETECTION = process.env.ENABLE_SHORT_CONTEXT_DETECTION !== "false"; // 默认启用 + // 会话终止标记 TTL(单位:秒) + // 规范环境变量:SESSION_TERMINATION_TTL_SECONDS + // 兼容旧名(计划弃用):SESSION_TERMINATION_TTL / TERMINATED_SESSION_TTL + private static readonly TERMINATED_SESSION_TTL = (() => { + const normalize = (value: string | undefined): string | undefined => { + if (typeof value !== "string") return undefined; + const trimmed = value.trim(); + return trimmed.length > 0 ? trimmed : undefined; + }; + + const rawPrimary = normalize(process.env.SESSION_TERMINATION_TTL_SECONDS); + const rawLegacyA = normalize(process.env.SESSION_TERMINATION_TTL); + const rawLegacyB = normalize(process.env.TERMINATED_SESSION_TTL); + + if (!rawPrimary && (rawLegacyA || rawLegacyB)) { + logger.warn("SessionManager: Deprecated termination TTL env var detected", { + SESSION_TERMINATION_TTL: rawLegacyA ? "set" : "unset", + TERMINATED_SESSION_TTL: rawLegacyB ? "set" : "unset", + preferred: "SESSION_TERMINATION_TTL_SECONDS", + }); + } + + const raw = rawPrimary ?? rawLegacyA ?? rawLegacyB; + const parsed = raw ? Number.parseInt(raw, 10) : Number.NaN; + if (Number.isFinite(parsed) && parsed > 0) { + return parsed; + } + return 24 * 60 * 60; // 1 天 + })(); + + private static readonly SCAN_COUNT = 100; + private static readonly TERMINATE_SCAN_COUNT = 200; + + private static getTerminationMarkerKey(sessionId: string): string { + return `session:${sessionId}:terminated`; + } + + private static async readTerminationMarker( + redis: Redis, + sessionId: string + ): Promise { + const terminatedKey = SessionManager.getTerminationMarkerKey(sessionId); + try { + const value = await redis.get(terminatedKey); + if (typeof value !== "string" || value.length === 0) { + return null; + } + return value; + } catch (error) { + logger.error("SessionManager: Failed to read termination marker", { + error, + sessionId, + }); + return null; + } + } + + /** + * 将用户可控的字符串安全地嵌入 Redis `SCAN MATCH` glob pattern 中(按字面量匹配)。 + * + * Redis glob 语法中 `* ? [] \\` 都具有特殊含义,因此需要转义以避免误匹配/误删。 + */ + private static escapeRedisMatchPatternLiteral(value: string): string { + return value + .replaceAll("\\", "\\\\") + .replaceAll("*", "\\*") + .replaceAll("?", "\\?") + .replaceAll("[", "\\[") + .replaceAll("]", "\\]"); + } /** * 获取 STORE_SESSION_MESSAGES 配置 @@ -358,6 +445,21 @@ export class SessionManager { // 1. 优先使用客户端传递的 session_id (来自 metadata.user_id 或 metadata.session_id) if (clientSessionId) { + // Fail-open:Redis 不可用时,不阻断请求(避免因 Redis 故障导致全站请求失败)。 + // 代价:Redis 故障窗口内无法强制执行 terminated marker,因此 TerminatedSessionError 不会抛出。 + if (redis && redis.status === "ready") { + const terminatedAt = await SessionManager.readTerminationMarker(redis, clientSessionId); + if (terminatedAt) { + logger.info("SessionManager: Rejected terminated client session", { + keyId, + sessionId: clientSessionId, + terminatedAt, + messagesLength, + }); + throw new TerminatedSessionError(clientSessionId, terminatedAt); + } + } + // 2. 短上下文并发检测(方案E) if ( SessionManager.ENABLE_SHORT_CONTEXT_DETECTION && @@ -423,13 +525,25 @@ export class SessionManager { const existingSessionId = await redis.get(hashKey); if (existingSessionId) { - // 找到已有 session,刷新 TTL - await SessionManager.refreshSessionTTL(existingSessionId); - logger.trace("SessionManager: Reusing session via hash", { - sessionId: existingSessionId, - hash: contentHash, - }); - return existingSessionId; + const terminatedAt = await SessionManager.readTerminationMarker(redis, existingSessionId); + if (terminatedAt) { + logger.info( + "SessionManager: Hash hit but session was terminated, creating new session", + { + existingSessionId, + terminatedAt, + hash: contentHash, + } + ); + } else { + // 找到已有 session,刷新 TTL + await SessionManager.refreshSessionTTL(existingSessionId); + logger.trace("SessionManager: Reusing session via hash", { + sessionId: existingSessionId, + hash: contentHash, + }); + return existingSessionId; + } } // 未找到:创建新 session @@ -1154,7 +1268,7 @@ export class SessionManager { "MATCH", "session:*:info", "COUNT", - 100 + SessionManager.SCAN_COUNT )) as [string, string[]]; cursor = nextCursor; @@ -1249,7 +1363,7 @@ export class SessionManager { "MATCH", "session:*:info", "COUNT", - 100 + SessionManager.SCAN_COUNT )) as [string, string[]]; cursor = nextCursor; @@ -1332,14 +1446,15 @@ export class SessionManager { } // 2. 检查新格式:使用 SCAN 搜索 session:{sessionId}:req:*:messages + const escapedSessionId = SessionManager.escapeRedisMatchPatternLiteral(sessionId); let cursor = "0"; do { const [nextCursor, keys] = (await redis.scan( cursor, "MATCH", - `session:${sessionId}:req:*:messages`, + `session:${escapedSessionId}:req:*:messages`, "COUNT", - 100 + SessionManager.SCAN_COUNT )) as [string, string[]]; cursor = nextCursor; @@ -1927,20 +2042,44 @@ export class SessionManager { /** * 终止 Session(主动打断) * - * 功能:删除 Session 在 Redis 中的所有绑定关系,强制下次请求重新选择供应商 - * 用途:管理员主动打断长时间占用同一供应商的 Session + * 功能:写入“终止标记”并清理 Redis 中所有 session:{id}:* 相关 key + * 影响:客户端后续继续携带同一 sessionId 时,将被阻断(getOrCreateSessionId 抛出 TerminatedSessionError) * * @param sessionId - Session ID - * @returns 是否成功删除 + * @returns markerOk: 是否成功写入终止标记; deletedKeys: 清理掉的 key 数量(不含 terminated 标记) */ - static async terminateSession(sessionId: string): Promise { + static async terminateSession(sessionId: string): Promise { const redis = getRedisClient(); if (!redis || redis.status !== "ready") { logger.warn("SessionManager: Redis not ready, cannot terminate session"); - return false; + return { markerOk: false, deletedKeys: 0 }; } + let markerOk = false; + let deletedKeys = 0; + try { + const terminatedKey = SessionManager.getTerminationMarkerKey(sessionId); + const terminatedAt = Date.now().toString(); + const ttlSeconds = SessionManager.TERMINATED_SESSION_TTL; + + // 0. 标记终止(优先写入,避免并发请求在清理窗口内复活) + // 说明:这里允许覆盖旧值,用于刷新 TTL(多次终止时延长阻断窗口)。 + const markerResult = await redis.set(terminatedKey, terminatedAt, "EX", ttlSeconds); + markerOk = markerResult === "OK"; + + if (!markerOk) { + logger.warn( + "SessionManager: Failed to set termination marker; cleanup will still proceed (session may be reusable)", + { + sessionId, + terminatedKey, + terminatedAt, + ttlSeconds, + } + ); + } + // 1. 先查询绑定信息(用于从 ZSET 中移除) let providerId: number | null = null; let keyId: number | null = null; @@ -1957,6 +2096,14 @@ export class SessionManager { keyId = keyIdStr ? parseInt(keyIdStr, 10) : null; userId = userIdStr ? parseInt(userIdStr, 10) : null; + if (!Number.isFinite(providerId)) { + providerId = null; + } + + if (!Number.isFinite(keyId)) { + keyId = null; + } + if (!Number.isFinite(userId)) { userId = null; } @@ -1971,48 +2118,106 @@ export class SessionManager { ); } - // 2. 删除所有 Session 相关的 key + // 2. 从 ZSET 中移除(始终尝试,即使查询失败) const pipeline = redis.pipeline(); + const zremKeys: string[] = []; - // 基础绑定信息 - pipeline.del(`session:${sessionId}:provider`); - pipeline.del(`session:${sessionId}:key`); - pipeline.del(`session:${sessionId}:info`); - pipeline.del(`session:${sessionId}:last_seen`); - pipeline.del(`session:${sessionId}:concurrent_count`); + const globalKey = getGlobalActiveSessionsKey(); + pipeline.zrem(globalKey, sessionId); + zremKeys.push(globalKey); - // 可选:messages 和 response(如果启用了存储) - pipeline.del(`session:${sessionId}:messages`); - pipeline.del(`session:${sessionId}:response`); - - // 3. 从 ZSET 中移除(始终尝试,即使查询失败) - pipeline.zrem(getGlobalActiveSessionsKey(), sessionId); + if (providerId !== null) { + const key = `provider:${providerId}:active_sessions`; + pipeline.zrem(key, sessionId); + zremKeys.push(key); + } - if (providerId) { - pipeline.zrem(`provider:${providerId}:active_sessions`, sessionId); + if (keyId !== null) { + const key = getKeyActiveSessionsKey(keyId); + pipeline.zrem(key, sessionId); + zremKeys.push(key); } - if (keyId) { - pipeline.zrem(getKeyActiveSessionsKey(keyId), sessionId); + if (userId !== null) { + const key = getUserActiveSessionsKey(userId); + pipeline.zrem(key, sessionId); + zremKeys.push(key); } - if (userId) { - pipeline.zrem(getUserActiveSessionsKey(userId), sessionId); + try { + const results = await pipeline.exec(); + if (results) { + for (let i = 0; i < results.length; i++) { + const [err] = results[i]; + if (!err) continue; + logger.warn("SessionManager: Failed to remove session from active_sessions ZSET", { + sessionId, + zsetKey: zremKeys[i], + providerId, + keyId, + userId, + error: err, + }); + } + } + } catch (zremError) { + logger.warn("SessionManager: Failed to cleanup active_sessions ZSET, continuing", { + sessionId, + providerId, + keyId, + userId, + error: zremError, + }); } - // 4. 删除 hash 映射(如果存在) - // 注意:无法直接反查 hash,只能清理已知的 session key - // hash 会在 TTL 后自动过期,不影响功能 + // 3. 删除 session:* 相关 key(包含 req:* 新格式;保留 terminated 标记) + const escapedSessionId = SessionManager.escapeRedisMatchPatternLiteral(sessionId); + const matchPattern = `session:${escapedSessionId}:*`; + + // 说明:Redis SCAN 不提供快照语义;为了减少并发窗口下的遗漏,这里最多执行两轮全量扫描清理。 + const MAX_SCAN_ROUNDS = 2; + for (let round = 0; round < MAX_SCAN_ROUNDS; round++) { + let cursor = "0"; + let deletedInRound = 0; + + do { + const scanResult = (await redis.scan( + cursor, + "MATCH", + matchPattern, + "COUNT", + SessionManager.TERMINATE_SCAN_COUNT + )) as [string, string[]]; + const nextCursor = scanResult[0]; + const keys = scanResult[1] ?? []; + cursor = nextCursor; + + if (keys.length === 0) continue; + + const deletePipeline = redis.pipeline(); + let hasDeletes = false; + for (const key of keys) { + if (key === terminatedKey) continue; + deletePipeline.del(key); + hasDeletes = true; + } + + // 如果这一页 SCAN 只返回了 terminatedKey,则无需发起空 pipeline.exec()。 + if (!hasDeletes) continue; - const results = await pipeline.exec(); + const deleteResults = await deletePipeline.exec(); + if (!deleteResults) continue; - // 5. 检查结果 - let deletedKeys = 0; - if (results) { - for (const [err, result] of results) { - if (!err && typeof result === "number" && result > 0) { - deletedKeys += result; + for (const [err, result] of deleteResults) { + if (!err && typeof result === "number" && result > 0) { + deletedInRound += result; + } } + } while (cursor !== "0"); + + deletedKeys += deletedInRound; + if (deletedInRound === 0) { + break; } } @@ -2020,16 +2225,19 @@ export class SessionManager { sessionId, providerId, keyId, + userId, deletedKeys, + terminatedAt, + markerOk, }); - return deletedKeys > 0; + return { markerOk, deletedKeys }; } catch (error) { logger.error("SessionManager: Failed to terminate session", { error, sessionId, }); - return false; + return { markerOk, deletedKeys }; } } @@ -2061,8 +2269,8 @@ export class SessionManager { const chunk = sessionIds.slice(i, i + CHUNK_SIZE); const results = await Promise.all( chunk.map(async (sessionId) => { - const success = await SessionManager.terminateSession(sessionId); - return success ? 1 : 0; + const result = await SessionManager.terminateSession(sessionId); + return result.markerOk ? 1 : 0; }) ); successCount += results.reduce((sum, value) => sum + value, 0); diff --git a/tests/unit/components/form/client-restrictions-editor.test.tsx b/tests/unit/components/form/client-restrictions-editor.test.tsx index f9df775a4..3241235b2 100644 --- a/tests/unit/components/form/client-restrictions-editor.test.tsx +++ b/tests/unit/components/form/client-restrictions-editor.test.tsx @@ -7,9 +7,13 @@ import { act } from "react"; import { createRoot } from "react-dom/client"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; -vi.mock("@/lib/client-restrictions/client-presets", () => ({ - CLIENT_RESTRICTION_PRESET_OPTIONS: [], -})); +vi.mock("@/lib/client-restrictions/client-presets", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + CLIENT_RESTRICTION_PRESET_OPTIONS: [], + }; +}); vi.mock("@/components/ui/tag-input", () => ({ TagInput: vi.fn(() => null), @@ -40,10 +44,21 @@ function getTagInputOnChange(callIndex: number): (values: string[]) => void { return (call[0] as TagInputProps).onChange; } -describe("ClientRestrictionsEditor", () => { +describe("ClientRestrictionsEditor - custom clients", () => { const onAllowedChange = vi.fn(); const onBlockedChange = vi.fn(); + const translations = { + allowAction: "Allow", + blockAction: "Block", + customAllowedLabel: "Custom Allowed", + customAllowedPlaceholder: "", + customBlockedLabel: "Custom Blocked", + customBlockedPlaceholder: "", + customHelp: "", + presetClients: {}, + }; + beforeEach(() => { vi.clearAllMocks(); }); @@ -61,75 +76,28 @@ describe("ClientRestrictionsEditor", () => { blocked={blocked} onAllowedChange={onAllowedChange} onBlockedChange={onBlockedChange} - allowedLabel="Allowed" - blockedLabel="Blocked" - getPresetLabel={(v) => v} + translations={translations} /> ); } - describe("uniqueOrdered normalization", () => { - it("deduplicates values preserving first occurrence order", () => { - const unmount = renderEditor([], []); - act(() => getTagInputOnChange(0)(["a", "b", "a", "c"])); - expect(onAllowedChange).toHaveBeenCalledWith(["a", "b", "c"]); - unmount(); - }); - - it("trims whitespace from values", () => { - const unmount = renderEditor([], []); - act(() => getTagInputOnChange(0)([" a ", " b", "c "])); - expect(onAllowedChange).toHaveBeenCalledWith(["a", "b", "c"]); - unmount(); - }); - - it("filters out empty and whitespace-only entries", () => { - const unmount = renderEditor([], []); - act(() => getTagInputOnChange(0)(["a", "", " ", "b"])); - expect(onAllowedChange).toHaveBeenCalledWith(["a", "b"]); - unmount(); - }); + it("custom allowed: should deduplicate values preserving order", () => { + const unmount = renderEditor([], []); + + act(() => getTagInputOnChange(0)(["a", "b", "a", "c"])); + + expect(onAllowedChange).toHaveBeenCalledWith(["a", "b", "c"]); + expect(onBlockedChange).not.toHaveBeenCalled(); + unmount(); }); - describe("allow/block mutual exclusion", () => { - it("removes overlapping items from blocked when added to allowed", () => { - const unmount = renderEditor([], ["b", "c"]); - act(() => getTagInputOnChange(0)(["a", "b"])); - expect(onAllowedChange).toHaveBeenCalledWith(["a", "b"]); - expect(onBlockedChange).toHaveBeenCalledWith(["c"]); - unmount(); - }); - - it("does not call onBlockedChange when allowed has no overlap with blocked", () => { - const unmount = renderEditor([], ["c", "d"]); - act(() => getTagInputOnChange(0)(["a", "b"])); - expect(onAllowedChange).toHaveBeenCalledWith(["a", "b"]); - expect(onBlockedChange).not.toHaveBeenCalled(); - unmount(); - }); - - it("removes overlapping items from allowed when added to blocked", () => { - const unmount = renderEditor(["a", "b"], []); - act(() => getTagInputOnChange(1)(["b", "c"])); - expect(onBlockedChange).toHaveBeenCalledWith(["b", "c"]); - expect(onAllowedChange).toHaveBeenCalledWith(["a"]); - unmount(); - }); - - it("does not call onAllowedChange when blocked has no overlap with allowed", () => { - const unmount = renderEditor(["a", "b"], []); - act(() => getTagInputOnChange(1)(["c", "d"])); - expect(onBlockedChange).toHaveBeenCalledWith(["c", "d"]); - expect(onAllowedChange).not.toHaveBeenCalled(); - unmount(); - }); - - it("clears all blocked when all items are moved to allowed", () => { - const unmount = renderEditor([], ["x", "y"]); - act(() => getTagInputOnChange(0)(["x", "y", "z"])); - expect(onAllowedChange).toHaveBeenCalledWith(["x", "y", "z"]); - expect(onBlockedChange).toHaveBeenCalledWith([]); - unmount(); - }); + it("custom blocked: should deduplicate values preserving order", () => { + const unmount = renderEditor([], []); + + act(() => getTagInputOnChange(1)(["x", "x", "y"])); + + expect(onBlockedChange).toHaveBeenCalledWith(["x", "y"]); + expect(onAllowedChange).not.toHaveBeenCalled(); + unmount(); }); }); diff --git a/tests/unit/lib/session-manager-scan-pattern-escape.test.ts b/tests/unit/lib/session-manager-scan-pattern-escape.test.ts new file mode 100644 index 000000000..575bd110a --- /dev/null +++ b/tests/unit/lib/session-manager-scan-pattern-escape.test.ts @@ -0,0 +1,63 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +let redisClientRef: any; +const getRedisClientMock = vi.fn(); + +vi.mock("server-only", () => ({})); + +vi.mock("@/lib/logger", () => ({ + logger: { + warn: vi.fn(), + info: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + trace: vi.fn(), + }, +})); + +vi.mock("@/app/v1/_lib/proxy/errors", () => ({ + sanitizeHeaders: vi.fn(() => "(empty)"), + sanitizeUrl: vi.fn((url: unknown) => String(url)), +})); + +vi.mock("@/lib/session-tracker", () => ({ + SessionTracker: { + getConcurrentCount: vi.fn(async () => 0), + }, +})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: getRedisClientMock, +})); + +describe("SessionManager.hasAnySessionMessages - scan pattern escaping", () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.resetModules(); + + redisClientRef = { + status: "ready", + exists: vi.fn(async () => 0), + scan: vi.fn(async () => ["0", []]), + }; + + getRedisClientMock.mockReturnValue(redisClientRef); + }); + + it("应对 sessionId 中的 glob 特殊字符进行转义(避免误匹配/误删)", async () => { + const { SessionManager } = await import("@/lib/session-manager"); + + const sessionId = "sess_te*st?[x]"; + const ok = await SessionManager.hasAnySessionMessages(sessionId); + + expect(ok).toBe(false); + expect(redisClientRef.exists).toHaveBeenCalledWith(`session:${sessionId}:messages`); + expect(redisClientRef.scan).toHaveBeenCalledWith( + "0", + "MATCH", + "session:sess_te\\*st\\?\\[x\\]:req:*:messages", + "COUNT", + 100 + ); + }); +}); diff --git a/tests/unit/lib/session-manager-terminate-session.test.ts b/tests/unit/lib/session-manager-terminate-session.test.ts index f4de279ac..7629153e3 100644 --- a/tests/unit/lib/session-manager-terminate-session.test.ts +++ b/tests/unit/lib/session-manager-terminate-session.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; let redisClientRef: any; let pipelineRef: any; +let deletePipelineRef: any; vi.mock("server-only", () => ({})); @@ -25,21 +26,33 @@ describe("SessionManager.terminateSession", () => { vi.resetModules(); pipelineRef = { - del: vi.fn(() => pipelineRef), zrem: vi.fn(() => pipelineRef), + exec: vi.fn(async () => []), + }; + + deletePipelineRef = { + del: vi.fn(() => deletePipelineRef), exec: vi.fn(async () => [[null, 1]]), }; redisClientRef = { status: "ready", + set: vi.fn(async () => "OK"), get: vi.fn(async () => null), hget: vi.fn(async () => null), - pipeline: vi.fn(() => pipelineRef), + scan: vi.fn(async () => ["0", []]), + pipeline: vi + .fn() + // 第一次 pipeline:用于 ZSET 清理(global/key/provider/user) + .mockImplementationOnce(() => pipelineRef) + // 后续 pipeline:用于批量删除 session:{id}:* key(可能多页 SCAN) + .mockImplementation(() => deletePipelineRef), }; }); it("应同时从 global/key/user 的 active_sessions ZSET 中移除 sessionId(若可解析到 userId)", async () => { - const sessionId = "sess_test"; + const sessionId = "sess_te*st?[x]"; + const terminatedKey = `session:${sessionId}:terminated`; redisClientRef.get.mockImplementation(async (key: string) => { if (key === `session:${sessionId}:provider`) return "42"; if (key === `session:${sessionId}:key`) return "7"; @@ -49,37 +62,89 @@ describe("SessionManager.terminateSession", () => { if (key === `session:${sessionId}:info` && field === "userId") return "123"; return null; }); + redisClientRef.scan.mockResolvedValueOnce([ + "0", + [ + terminatedKey, + `session:${sessionId}:provider`, + `session:${sessionId}:req:1:messages`, + `session:${sessionId}:req:1:response`, + ], + ]); const { getGlobalActiveSessionsKey, getKeyActiveSessionsKey, getUserActiveSessionsKey } = await import("@/lib/redis/active-session-keys"); const { SessionManager } = await import("@/lib/session-manager"); - const ok = await SessionManager.terminateSession(sessionId); - expect(ok).toBe(true); + const result = await SessionManager.terminateSession(sessionId); + expect(result.markerOk).toBe(true); + expect(redisClientRef.set).toHaveBeenCalledWith(terminatedKey, expect.any(String), "EX", 86400); expect(redisClientRef.hget).toHaveBeenCalledWith(`session:${sessionId}:info`, "userId"); expect(pipelineRef.zrem).toHaveBeenCalledWith(getGlobalActiveSessionsKey(), sessionId); expect(pipelineRef.zrem).toHaveBeenCalledWith("provider:42:active_sessions", sessionId); expect(pipelineRef.zrem).toHaveBeenCalledWith(getKeyActiveSessionsKey(7), sessionId); expect(pipelineRef.zrem).toHaveBeenCalledWith(getUserActiveSessionsKey(123), sessionId); + + expect(deletePipelineRef.del).toHaveBeenCalledWith(`session:${sessionId}:provider`); + expect(deletePipelineRef.del).toHaveBeenCalledWith(`session:${sessionId}:req:1:messages`); + expect(deletePipelineRef.del).toHaveBeenCalledWith(`session:${sessionId}:req:1:response`); + expect(deletePipelineRef.del).not.toHaveBeenCalledWith(terminatedKey); + + // 安全性:SCAN MATCH pattern 必须按字面量匹配 sessionId,避免 glob 注入误删其它 key + expect(redisClientRef.scan).toHaveBeenCalledWith( + "0", + "MATCH", + "session:sess_te\\*st\\?\\[x\\]:*", + "COUNT", + 200 + ); }); it("当 userId 不可用时,不应尝试 zrem user active_sessions key", async () => { const sessionId = "sess_test"; + const terminatedKey = `session:${sessionId}:terminated`; redisClientRef.get.mockImplementation(async (key: string) => { if (key === `session:${sessionId}:provider`) return "42"; if (key === `session:${sessionId}:key`) return "7"; return null; }); redisClientRef.hget.mockResolvedValue(null); + redisClientRef.scan.mockResolvedValueOnce(["0", [terminatedKey]]); const { getUserActiveSessionsKey } = await import("@/lib/redis/active-session-keys"); const { SessionManager } = await import("@/lib/session-manager"); - const ok = await SessionManager.terminateSession(sessionId); - expect(ok).toBe(true); + const result = await SessionManager.terminateSession(sessionId); + expect(result.markerOk).toBe(true); + // SCAN 仅返回 terminatedKey 时,不会发出任何 DEL 命令,因此不应执行 delete pipeline(避免不必要的网络开销)。 + expect(deletePipelineRef.exec).not.toHaveBeenCalled(); expect(pipelineRef.zrem).not.toHaveBeenCalledWith(getUserActiveSessionsKey(123), sessionId); }); + + it("当终止标记写入失败时,markerOk 应为 false(但清理仍会执行)", async () => { + const sessionId = "sess_marker_fail"; + redisClientRef.set.mockResolvedValueOnce(null); + redisClientRef.scan.mockResolvedValueOnce(["0", [`session:${sessionId}:provider`]]); + + const { SessionManager } = await import("@/lib/session-manager"); + const result = await SessionManager.terminateSession(sessionId); + + expect(result.markerOk).toBe(false); + expect(deletePipelineRef.del).toHaveBeenCalledWith(`session:${sessionId}:provider`); + }); + + it("当清理过程抛错时,应尽量保留 markerOk=true(如果终止标记已写入)", async () => { + const sessionId = "sess_cleanup_fail"; + const terminatedKey = `session:${sessionId}:terminated`; + redisClientRef.scan.mockRejectedValueOnce(new Error("scan failed")); + + const { SessionManager } = await import("@/lib/session-manager"); + const result = await SessionManager.terminateSession(sessionId); + + expect(redisClientRef.set).toHaveBeenCalledWith(terminatedKey, expect.any(String), "EX", 86400); + expect(result.markerOk).toBe(true); + }); }); diff --git a/tests/unit/lib/session-manager-terminated-remap.test.ts b/tests/unit/lib/session-manager-terminated-remap.test.ts new file mode 100644 index 000000000..a6132e88c --- /dev/null +++ b/tests/unit/lib/session-manager-terminated-remap.test.ts @@ -0,0 +1,107 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +let redisClientRef: any; + +vi.mock("server-only", () => ({})); + +vi.mock("@/lib/logger", () => ({ + logger: { + warn: vi.fn(), + info: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + trace: vi.fn(), + }, +})); + +vi.mock("@/app/v1/_lib/proxy/errors", () => ({ + sanitizeHeaders: vi.fn(() => "(empty)"), + sanitizeUrl: vi.fn((url: unknown) => String(url)), +})); + +vi.mock("@/lib/session-tracker", () => ({ + SessionTracker: { + getConcurrentCount: vi.fn(async () => 0), + }, +})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: () => redisClientRef, +})); + +function makePipeline() { + const pipeline = { + setex: vi.fn(() => pipeline), + expire: vi.fn(() => pipeline), + exec: vi.fn(async () => []), + }; + return pipeline; +} + +describe("SessionManager.getOrCreateSessionId - terminated blocking", () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.resetModules(); + + redisClientRef = { + status: "ready", + get: vi.fn(async () => null), + pipeline: vi.fn(() => makePipeline()), + }; + }); + + it("未终止时应保持原 sessionId", async () => { + const { SessionManager } = await import("@/lib/session-manager"); + + const keyId = 1; + const oldSessionId = "sess_old"; + const messages = [{ role: "user", content: "hi" }]; + + const sessionId = await SessionManager.getOrCreateSessionId(keyId, messages, oldSessionId); + + expect(sessionId).toBe(oldSessionId); + expect(redisClientRef.get).toHaveBeenCalledWith(`session:${oldSessionId}:terminated`); + }); + + it("已终止时应拒绝复用并抛出 TerminatedSessionError", async () => { + const keyId = 1; + const oldSessionId = "sess_old"; + redisClientRef.get.mockImplementation(async (key: string) => { + if (key === `session:${oldSessionId}:terminated`) return "1"; + return null; + }); + + const { SessionManager, TerminatedSessionError } = await import("@/lib/session-manager"); + + const error = await SessionManager.getOrCreateSessionId(keyId, [], oldSessionId).catch( + (e) => e as any + ); + expect(error).toBeInstanceOf(TerminatedSessionError); + expect(error.sessionId).toBe(oldSessionId); + expect(error.terminatedAt).toBe("1"); + }); + + it("hash 命中已终止 session 时应创建新 session", async () => { + const keyId = 1; + const terminatedSessionId = "sess_terminated"; + + const { SessionManager } = await import("@/lib/session-manager"); + vi.spyOn(SessionManager, "generateSessionId").mockReturnValue("sess_new"); + + redisClientRef.get.mockImplementation(async (key: string) => { + if (key.startsWith("hash:") && key.endsWith(":session")) return terminatedSessionId; + if (key === `session:${terminatedSessionId}:terminated`) return "1"; + return null; + }); + + const messages = [ + { role: "user", content: "hi" }, + { role: "assistant", content: "ok" }, + ]; + + const sessionId = await SessionManager.getOrCreateSessionId(keyId, messages, null); + + expect(sessionId).toBe("sess_new"); + expect(redisClientRef.get).toHaveBeenCalledWith(`session:${terminatedSessionId}:terminated`); + }); +}); diff --git a/tests/unit/proxy/session-guard-terminated-session.test.ts b/tests/unit/proxy/session-guard-terminated-session.test.ts new file mode 100644 index 000000000..9360ef987 --- /dev/null +++ b/tests/unit/proxy/session-guard-terminated-session.test.ts @@ -0,0 +1,140 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import type { ProxySession } from "@/app/v1/_lib/proxy/session"; + +const getCachedSystemSettingsMock = vi.fn(); + +const extractClientSessionIdMock = vi.fn(); +const getOrCreateSessionIdMock = vi.fn(); +const getNextRequestSequenceMock = vi.fn(); +const storeSessionRequestBodyMock = vi.fn(async () => undefined); +const storeSessionClientRequestMetaMock = vi.fn(async () => undefined); +const storeSessionMessagesMock = vi.fn(async () => undefined); +const storeSessionInfoMock = vi.fn(async () => undefined); +const generateSessionIdMock = vi.fn(); + +const trackSessionMock = vi.fn(async () => undefined); + +class TerminatedSessionError extends Error { + sessionId: string; + terminatedAt: string | null; + + constructor(sessionId: string, terminatedAt: string | null = null) { + super("Session has been terminated"); + this.name = "TerminatedSessionError"; + this.sessionId = sessionId; + this.terminatedAt = terminatedAt; + } +} + +vi.mock("@/lib/config", () => ({ + getCachedSystemSettings: () => getCachedSystemSettingsMock(), +})); + +vi.mock("@/lib/session-manager", () => ({ + SessionManager: { + extractClientSessionId: extractClientSessionIdMock, + getOrCreateSessionId: getOrCreateSessionIdMock, + getNextRequestSequence: getNextRequestSequenceMock, + storeSessionRequestBody: storeSessionRequestBodyMock, + storeSessionClientRequestMeta: storeSessionClientRequestMetaMock, + storeSessionMessages: storeSessionMessagesMock, + storeSessionInfo: storeSessionInfoMock, + generateSessionId: generateSessionIdMock, + }, + TerminatedSessionError, +})); + +vi.mock("@/lib/session-tracker", () => ({ + SessionTracker: { + trackSession: trackSessionMock, + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + fatal: vi.fn(), + trace: vi.fn(), + }, +})); + +async function loadGuard() { + const mod = await import("@/app/v1/_lib/proxy/session-guard"); + return mod.ProxySessionGuard; +} + +function createMockSession(overrides: Partial = {}): ProxySession { + const session: any = { + authState: { + success: true, + user: { id: 1, name: "u" }, + key: { id: 1, name: "k" }, + apiKey: "api-key", + }, + request: { + message: {}, + model: "claude-sonnet-4-5-20250929", + }, + headers: new Headers(), + userAgent: "claude_cli/1.0", + requestUrl: "http://localhost/v1/messages", + method: "POST", + originalFormat: "claude", + + sessionId: null, + setSessionId(id: string) { + this.sessionId = id; + }, + setRequestSequence(seq: number) { + this.requestSequence = seq; + }, + getRequestSequence() { + return this.requestSequence ?? 1; + }, + getMessages() { + return []; + }, + getMessagesLength() { + return 1; + }, + isWarmupRequest() { + return false; + }, + } satisfies Partial; + + return { ...session, ...overrides } as ProxySession; +} + +beforeEach(() => { + vi.clearAllMocks(); + getCachedSystemSettingsMock.mockResolvedValue({ + interceptAnthropicWarmupRequests: false, + enableCodexSessionIdCompletion: false, + }); + extractClientSessionIdMock.mockReturnValue("sess_terminated"); + getNextRequestSequenceMock.mockResolvedValue(1); +}); + +describe("ProxySessionGuard - terminated session", () => { + test("当 clientSessionId 已终止时应阻断请求并抛出 ProxyError(410)", async () => { + const ProxySessionGuard = await loadGuard(); + const session = createMockSession(); + + getOrCreateSessionIdMock.mockRejectedValueOnce( + new TerminatedSessionError("sess_terminated", "1") + ); + + await expect(ProxySessionGuard.ensure(session)).rejects.toMatchObject({ + name: "ProxyError", + statusCode: 410, + message: "Session 已被终止,请创建新的会话后重试", + }); + + expect(generateSessionIdMock).not.toHaveBeenCalled(); + expect(trackSessionMock).not.toHaveBeenCalled(); + expect(session.sessionId).toBeNull(); + }); +}); diff --git a/tests/unit/proxy/session-guard-warmup-intercept.test.ts b/tests/unit/proxy/session-guard-warmup-intercept.test.ts index f7443b936..f2c69a5df 100644 --- a/tests/unit/proxy/session-guard-warmup-intercept.test.ts +++ b/tests/unit/proxy/session-guard-warmup-intercept.test.ts @@ -14,6 +14,18 @@ const generateSessionIdMock = vi.fn(); const trackSessionMock = vi.fn(async () => undefined); +class TerminatedSessionError extends Error { + sessionId: string; + terminatedAt: string | null; + + constructor(sessionId: string, terminatedAt: string | null = null) { + super("Session has been terminated"); + this.name = "TerminatedSessionError"; + this.sessionId = sessionId; + this.terminatedAt = terminatedAt; + } +} + vi.mock("@/lib/config", () => ({ getCachedSystemSettings: () => getCachedSystemSettingsMock(), })); @@ -29,6 +41,7 @@ vi.mock("@/lib/session-manager", () => ({ storeSessionInfo: storeSessionInfoMock, generateSessionId: generateSessionIdMock, }, + TerminatedSessionError, })); vi.mock("@/lib/session-tracker", () => ({ diff --git a/tests/unit/repository/provider.test.ts b/tests/unit/repository/provider.test.ts index 694c29e98..fff4538ef 100644 --- a/tests/unit/repository/provider.test.ts +++ b/tests/unit/repository/provider.test.ts @@ -51,7 +51,7 @@ describe("provider repository - updateProviderPrioritiesBatch", () => { test("returns 0 and does not execute SQL when updates is empty", async () => { vi.resetModules(); - const executeMock = vi.fn(async () => ({ rowCount: 0 })); + const executeMock = vi.fn(async () => ({ count: 0 })); vi.doMock("@/drizzle/db", () => ({ db: { @@ -69,7 +69,7 @@ describe("provider repository - updateProviderPrioritiesBatch", () => { test("generates CASE batch update SQL and returns affected rows", async () => { vi.resetModules(); - const executeMock = vi.fn(async () => ({ rowCount: 2 })); + const executeMock = vi.fn(async () => ({ count: 2 })); vi.doMock("@/drizzle/db", () => ({ db: { @@ -101,7 +101,7 @@ describe("provider repository - updateProviderPrioritiesBatch", () => { test("deduplicates provider ids (last update wins)", async () => { vi.resetModules(); - const executeMock = vi.fn(async () => ({ rowCount: 1 })); + const executeMock = vi.fn(async () => ({ count: 1 })); vi.doMock("@/drizzle/db", () => ({ db: {