diff --git a/.agents/skills/grill-me/SKILL.md b/.agents/skills/grill-me/SKILL.md new file mode 100644 index 000000000..f1543a911 --- /dev/null +++ b/.agents/skills/grill-me/SKILL.md @@ -0,0 +1,8 @@ +--- +name: grill-me +description: Interview the user relentlessly about a plan or design until reaching shared understanding, resolving each branch of the decision tree. Use when user wants to stress-test a plan, get grilled on their design, or mentions "grill me". +--- + +Interview me relentlessly about every aspect of this plan until we reach a shared understanding. Walk down each branch of the design tree, resolving dependencies between decisions one-by-one. For each question, provide your recommended answer. + +If a question can be answered by exploring the codebase, explore the codebase instead. diff --git a/.agents/skills/simplify/SKILL.md b/.agents/skills/simplify/SKILL.md new file mode 100644 index 000000000..a57d263e1 --- /dev/null +++ b/.agents/skills/simplify/SKILL.md @@ -0,0 +1,54 @@ +--- +name: simplify +description: 'Review changed code for reuse, quality, and efficiency, then fix any issues found. Use when the user asks to simplify, clean up, dedupe, reduce abstraction, remove pointless types, collapse duplicate contracts, replace bespoke parsing or validation with direct schema usage, or make a recent diff easier to read and maintain.' +--- + +# Simplify: Code Review and Cleanup + +Review all changed files for reuse, quality, and efficiency. Fix any issues found. + +## Phase 1: Identify Changes + +Run `git diff` (or `git diff HEAD` if there are staged changes) to see what changed. If there are no git changes, review the most recently modified files that the user mentioned or that you edited earlier in this conversation. + +## Phase 2: Launch Three Review Agents in Parallel + +Launch all three agents concurrently in a single message. Pass each agent the full diff so it has the complete context. + +### Agent 1: Code Reuse Review + +For each change: + +1. **Search for existing utilities and helpers** that could replace newly written code. Look for similar patterns elsewhere in the codebase — common locations are utility directories, shared modules, and files adjacent to the changed ones. +2. **Flag any new function that duplicates existing functionality.** Suggest the existing function to use instead. +3. **Flag any inline logic that could use an existing utility** — hand-rolled string manipulation, manual path handling, custom environment checks, ad-hoc type guards, and similar patterns are common candidates. + +### Agent 2: Code Quality Review + +Review the same changes for hacky patterns: + +1. **Redundant state**: state that duplicates existing state, cached values that could be derived, observers/effects that could be direct calls +2. **Parameter sprawl**: adding new parameters to a function instead of generalizing or restructuring existing ones +3. **Copy-paste with slight variation**: near-duplicate code blocks that should be unified with a shared abstraction +4. **Leaky abstractions**: exposing internal details that should be encapsulated, or breaking existing abstraction boundaries +5. **Stringly-typed code**: using raw strings where constants, enums (string unions), or branded types already exist in the codebase +6. **Unnecessary JSX nesting**: wrapper Boxes/elements that add no layout value — check if inner component props (flexShrink, alignItems, etc.) already provide the needed behavior +7. **Unnecessary comments**: comments explaining WHAT the code does (well-named identifiers already do that), narrating the change, or referencing the task/caller — delete; keep only non-obvious WHY (hidden constraints, subtle invariants, workarounds) + +### Agent 3: Efficiency Review + +Review the same changes for efficiency: + +1. **Unnecessary work**: redundant computations, repeated file reads, duplicate network/API calls, N+1 patterns +2. **Missed concurrency**: independent operations run sequentially when they could run in parallel +3. **Hot-path bloat**: new blocking work added to startup or per-request/per-render hot paths +4. **Recurring no-op updates**: state/store updates inside polling loops, intervals, or event handlers that fire unconditionally — add a change-detection guard so downstream consumers aren't notified when nothing changed. Also: if a wrapper function takes an updater/reducer callback, verify it honors same-reference returns (or whatever the "no change" signal is) — otherwise callers' early-return no-ops are silently defeated +5. **Unnecessary existence checks**: pre-checking file/resource existence before operating (TOCTOU anti-pattern) — operate directly and handle the error +6. **Memory**: unbounded data structures, missing cleanup, event listener leaks +7. **Overly broad operations**: reading entire files when only a portion is needed, loading all items when filtering for one + +## Phase 3: Fix Issues + +Wait for all three agents to complete. Aggregate their findings and fix each issue directly. If a finding is a false positive or not worth addressing, note it and move on — do not argue with the finding, just skip it. + +When done, briefly summarize what was fixed (or confirm the code was already clean). diff --git a/.agents/skills/simplify/agents/openai.yaml b/.agents/skills/simplify/agents/openai.yaml new file mode 100644 index 000000000..9dd3596de --- /dev/null +++ b/.agents/skills/simplify/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: 'Simplify' + short_description: 'Remove accidental complexity from code' + default_prompt: 'Use $simplify to remove accidental complexity from this diff without changing behavior.' diff --git a/.agents/skills/tdd/SKILL.md b/.agents/skills/tdd/SKILL.md new file mode 100644 index 000000000..78f5077d6 --- /dev/null +++ b/.agents/skills/tdd/SKILL.md @@ -0,0 +1,107 @@ +--- +name: tdd +description: Test-driven development with red-green-refactor loop. Use when user wants to build features or fix bugs using TDD, mentions "red-green-refactor", wants integration tests, or asks for test-first development. +--- + +# Test-Driven Development + +## Philosophy + +**Core principle**: Tests should verify behavior through public interfaces, not implementation details. Code can change entirely; tests shouldn't. + +**Good tests** are integration-style: they exercise real code paths through public APIs. They describe _what_ the system does, not _how_ it does it. A good test reads like a specification - "user can checkout with valid cart" tells you exactly what capability exists. These tests survive refactors because they don't care about internal structure. + +**Bad tests** are coupled to implementation. They mock internal collaborators, test private methods, or verify through external means (like querying a database directly instead of using the interface). The warning sign: your test breaks when you refactor, but behavior hasn't changed. If you rename an internal function and tests fail, those tests were testing implementation, not behavior. + +See [tests.md](tests.md) for examples and [mocking.md](mocking.md) for mocking guidelines. + +## Anti-Pattern: Horizontal Slices + +**DO NOT write all tests first, then all implementation.** This is "horizontal slicing" - treating RED as "write all tests" and GREEN as "write all code." + +This produces **crap tests**: + +- Tests written in bulk test _imagined_ behavior, not _actual_ behavior +- You end up testing the _shape_ of things (data structures, function signatures) rather than user-facing behavior +- Tests become insensitive to real changes - they pass when behavior breaks, fail when behavior is fine +- You outrun your headlights, committing to test structure before understanding the implementation + +**Correct approach**: Vertical slices via tracer bullets. One test → one implementation → repeat. Each test responds to what you learned from the previous cycle. Because you just wrote the code, you know exactly what behavior matters and how to verify it. + +``` +WRONG (horizontal): + RED: test1, test2, test3, test4, test5 + GREEN: impl1, impl2, impl3, impl4, impl5 + +RIGHT (vertical): + RED→GREEN: test1→impl1 + RED→GREEN: test2→impl2 + RED→GREEN: test3→impl3 + ... +``` + +## Workflow + +### 1. Planning + +Before writing any code: + +- [ ] Confirm with user what interface changes are needed +- [ ] Confirm with user which behaviors to test (prioritize) +- [ ] Identify opportunities for [deep modules](deep-modules.md) (small interface, deep implementation) +- [ ] Design interfaces for [testability](interface-design.md) +- [ ] List the behaviors to test (not implementation steps) +- [ ] Get user approval on the plan + +Ask: "What should the public interface look like? Which behaviors are most important to test?" + +**You can't test everything.** Confirm with the user exactly which behaviors matter most. Focus testing effort on critical paths and complex logic, not every possible edge case. + +### 2. Tracer Bullet + +Write ONE test that confirms ONE thing about the system: + +``` +RED: Write test for first behavior → test fails +GREEN: Write minimal code to pass → test passes +``` + +This is your tracer bullet - proves the path works end-to-end. + +### 3. Incremental Loop + +For each remaining behavior: + +``` +RED: Write next test → fails +GREEN: Minimal code to pass → passes +``` + +Rules: + +- One test at a time +- Only enough code to pass current test +- Don't anticipate future tests +- Keep tests focused on observable behavior + +### 4. Refactor + +After all tests pass, look for [refactor candidates](refactoring.md): + +- [ ] Extract duplication +- [ ] Deepen modules (move complexity behind simple interfaces) +- [ ] Apply SOLID principles where natural +- [ ] Consider what new code reveals about existing code +- [ ] Run tests after each refactor step + +**Never refactor while RED.** Get to GREEN first. + +## Checklist Per Cycle + +``` +[ ] Test describes behavior, not implementation +[ ] Test uses public interface only +[ ] Test would survive internal refactor +[ ] Code is minimal for this test +[ ] No speculative features added +``` diff --git a/.agents/skills/tdd/deep-modules.md b/.agents/skills/tdd/deep-modules.md new file mode 100644 index 000000000..0d9720cf1 --- /dev/null +++ b/.agents/skills/tdd/deep-modules.md @@ -0,0 +1,33 @@ +# Deep Modules + +From "A Philosophy of Software Design": + +**Deep module** = small interface + lots of implementation + +``` +┌─────────────────────┐ +│ Small Interface │ ← Few methods, simple params +├─────────────────────┤ +│ │ +│ │ +│ Deep Implementation│ ← Complex logic hidden +│ │ +│ │ +└─────────────────────┘ +``` + +**Shallow module** = large interface + little implementation (avoid) + +``` +┌─────────────────────────────────┐ +│ Large Interface │ ← Many methods, complex params +├─────────────────────────────────┤ +│ Thin Implementation │ ← Just passes through +└─────────────────────────────────┘ +``` + +When designing interfaces, ask: + +- Can I reduce the number of methods? +- Can I simplify the parameters? +- Can I hide more complexity inside? diff --git a/.agents/skills/tdd/interface-design.md b/.agents/skills/tdd/interface-design.md new file mode 100644 index 000000000..c1ed64d41 --- /dev/null +++ b/.agents/skills/tdd/interface-design.md @@ -0,0 +1,31 @@ +# Interface Design for Testability + +Good interfaces make testing natural: + +1. **Accept dependencies, don't create them** + + ```typescript + // Testable + function processOrder(order, paymentGateway) {} + + // Hard to test + function processOrder(order) { + const gateway = new StripeGateway(); + } + ``` + +2. **Return results, don't produce side effects** + + ```typescript + // Testable + function calculateDiscount(cart): Discount {} + + // Hard to test + function applyDiscount(cart): void { + cart.total -= discount; + } + ``` + +3. **Small surface area** + - Fewer methods = fewer tests needed + - Fewer params = simpler test setup diff --git a/.agents/skills/tdd/mocking.md b/.agents/skills/tdd/mocking.md new file mode 100644 index 000000000..fede286dc --- /dev/null +++ b/.agents/skills/tdd/mocking.md @@ -0,0 +1,60 @@ +# When to Mock + +Mock at **system boundaries** only: + +- External APIs (payment, email, etc.) +- Databases (sometimes - prefer test DB) +- Time/randomness +- File system (sometimes) + +Don't mock: + +- Your own classes/modules +- Internal collaborators +- Anything you control + +## Designing for Mockability + +At system boundaries, design interfaces that are easy to mock: + +**1. Use dependency injection** + +Pass external dependencies in rather than creating them internally: + +```typescript +// Easy to mock +function processPayment(order, paymentClient) { + return paymentClient.charge(order.total); +} + +// Hard to mock +function processPayment(order) { + const client = new StripeClient(process.env.STRIPE_KEY); + return client.charge(order.total); +} +``` + +**2. Prefer SDK-style interfaces over generic fetchers** + +Create specific functions for each external operation instead of one generic function with conditional logic: + +```typescript +// GOOD: Each function is independently mockable +const api = { + getUser: (id) => fetch(`/users/${id}`), + getOrders: (userId) => fetch(`/users/${userId}/orders`), + createOrder: (data) => fetch('/orders', { method: 'POST', body: data }), +}; + +// BAD: Mocking requires conditional logic inside the mock +const api = { + fetch: (endpoint, options) => fetch(endpoint, options), +}; +``` + +The SDK approach means: + +- Each mock returns one specific shape +- No conditional logic in test setup +- Easier to see which endpoints a test exercises +- Type safety per endpoint diff --git a/.agents/skills/tdd/refactoring.md b/.agents/skills/tdd/refactoring.md new file mode 100644 index 000000000..8a4443924 --- /dev/null +++ b/.agents/skills/tdd/refactoring.md @@ -0,0 +1,10 @@ +# Refactor Candidates + +After TDD cycle, look for: + +- **Duplication** → Extract function/class +- **Long methods** → Break into private helpers (keep tests on public interface) +- **Shallow modules** → Combine or deepen +- **Feature envy** → Move logic to where data lives +- **Primitive obsession** → Introduce value objects +- **Existing code** the new code reveals as problematic diff --git a/.agents/skills/tdd/tests.md b/.agents/skills/tdd/tests.md new file mode 100644 index 000000000..e2849048f --- /dev/null +++ b/.agents/skills/tdd/tests.md @@ -0,0 +1,61 @@ +# Good and Bad Tests + +## Good Tests + +**Integration-style**: Test through real interfaces, not mocks of internal parts. + +```typescript +// GOOD: Tests observable behavior +test('user can checkout with valid cart', async () => { + const cart = createCart(); + cart.add(product); + const result = await checkout(cart, paymentMethod); + expect(result.status).toBe('confirmed'); +}); +``` + +Characteristics: + +- Tests behavior users/callers care about +- Uses public API only +- Survives internal refactors +- Describes WHAT, not HOW +- One logical assertion per test + +## Bad Tests + +**Implementation-detail tests**: Coupled to internal structure. + +```typescript +// BAD: Tests implementation details +test('checkout calls paymentService.process', async () => { + const mockPayment = jest.mock(paymentService); + await checkout(cart, payment); + expect(mockPayment.process).toHaveBeenCalledWith(cart.total); +}); +``` + +Red flags: + +- Mocking internal collaborators +- Testing private methods +- Asserting on call counts/order +- Test breaks when refactoring without behavior change +- Test name describes HOW not WHAT +- Verifying through external means instead of interface + +```typescript +// BAD: Bypasses interface to verify +test('createUser saves to database', async () => { + await createUser({ name: 'Alice' }); + const row = await db.query('SELECT * FROM users WHERE name = ?', ['Alice']); + expect(row).toBeDefined(); +}); + +// GOOD: Verifies through interface +test('createUser makes user retrievable', async () => { + const user = await createUser({ name: 'Alice' }); + const retrieved = await getUser(user.id); + expect(retrieved.name).toBe('Alice'); +}); +``` diff --git a/.changeset/four-beers-march.md b/.changeset/four-beers-march.md new file mode 100644 index 000000000..d847e5340 --- /dev/null +++ b/.changeset/four-beers-march.md @@ -0,0 +1,7 @@ +--- +"@dexto/core": patch +"@dexto/tui": patch +"dexto": patch +--- + +Persist core-owned interaction state with the existing storage layers and update the TUI/CLI call sites for the new async session tool preference APIs. diff --git a/packages/cli/src/cli/cloud-chat.ts b/packages/cli/src/cli/cloud-chat.ts index 9677698aa..9929d1cf4 100644 --- a/packages/cli/src/cli/cloud-chat.ts +++ b/packages/cli/src/cli/cloud-chat.ts @@ -809,15 +809,15 @@ export function createCloudAgentBackend( globalDisabledTools.splice(0, globalDisabledTools.length, ...toolNames); }) as CloudChatBackend['setGlobalDisabledTools'], - setSessionDisabledTools: ((sessionId, toolNames) => { + setSessionDisabledTools: (async (sessionId, toolNames) => { sessionDisabledTools.set(sessionId, [...toolNames]); }) as CloudChatBackend['setSessionDisabledTools'], - setSessionAutoApproveTools: ((sessionId, toolNames) => { + setSessionAutoApproveTools: (async (sessionId, toolNames) => { sessionAutoApproveTools.set(sessionId, [...toolNames]); }) as CloudChatBackend['setSessionAutoApproveTools'], - getSessionAutoApproveTools: ((sessionId) => { + getSessionAutoApproveTools: (async (sessionId) => { return sessionAutoApproveTools.get(sessionId) ?? []; }) as CloudChatBackend['getSessionAutoApproveTools'], diff --git a/packages/core/src/agent/DextoAgent.lifecycle.test.ts b/packages/core/src/agent/DextoAgent.lifecycle.test.ts index bc2bfdad0..3ceb41101 100644 --- a/packages/core/src/agent/DextoAgent.lifecycle.test.ts +++ b/packages/core/src/agent/DextoAgent.lifecycle.test.ts @@ -499,6 +499,29 @@ describe('DextoAgent Lifecycle Management', () => { expect(updateLLM).not.toHaveBeenCalled(); expect(switchLLMForSpecificSession).not.toHaveBeenCalled(); }); + + test('switchLLM should reject empty session ids instead of falling back to global state', async () => { + const agent = createTestAgent(mockValidatedConfig); + await agent.start(); + + const updateLLM = mockServices.stateManager.updateLLM as ReturnType; + const getSession = mockServices.sessionManager.getSession as ReturnType; + const switchLLMForSpecificSession = mockServices.sessionManager + .switchLLMForSpecificSession as ReturnType; + const switchLLMForAllSessions = mockServices.sessionManager + .switchLLMForAllSessions as ReturnType; + + await expect(agent.switchLLM({ model: 'gpt-5-nano' }, '')).rejects.toMatchObject({ + code: AgentErrorCode.API_VALIDATION_ERROR, + scope: ErrorScope.AGENT, + type: ErrorType.USER, + }); + + expect(updateLLM).not.toHaveBeenCalled(); + expect(getSession).not.toHaveBeenCalled(); + expect(switchLLMForSpecificSession).not.toHaveBeenCalled(); + expect(switchLLMForAllSessions).not.toHaveBeenCalled(); + }); }); describe('Session Auto-Approve Tools Cleanup (Memory Leak Fix)', () => { diff --git a/packages/core/src/agent/DextoAgent.ts b/packages/core/src/agent/DextoAgent.ts index 4a9ff57bc..b9e1d461d 100644 --- a/packages/core/src/agent/DextoAgent.ts +++ b/packages/core/src/agent/DextoAgent.ts @@ -1973,8 +1973,9 @@ export class DextoAgent { } /** - * Resets the conversation history for a specific session. - * Keeps the session alive but the conversation history is cleared. + * Resets the conversation and session-scoped interaction state for a specific session. + * Keeps the session alive, but clears persisted mid-session state such as queued follow-ups, + * approval memory, tool preferences, and session-level LLM overrides. * @param sessionId Session ID (required) */ public async resetConversation(sessionId: string): Promise { @@ -2389,10 +2390,17 @@ export class DextoAgent { } const validatedUpdates = parseResult.data; + if (sessionId !== undefined && sessionId !== '*' && sessionId.trim() === '') { + throw AgentError.apiValidationError( + 'sessionId must be a non-empty string when provided' + ); + } + // Get current config for the session - const currentLLMConfig = sessionId - ? this.stateManager.getRuntimeConfig(sessionId).llm - : this.stateManager.getRuntimeConfig().llm; + const currentLLMConfig = + sessionId !== undefined && sessionId !== '*' + ? this.stateManager.getRuntimeConfig(sessionId).llm + : this.stateManager.getRuntimeConfig().llm; // Build and validate the new configuration using Result pattern internally const result = await resolveAndValidateLLMConfig( @@ -2433,16 +2441,13 @@ export class DextoAgent { ): Promise { // Switch LLM in session(s) if (sessionScope === '*') { - // Update state manager (no validation needed - already validated) - this.stateManager.updateLLM(validatedConfig, sessionScope); await this.sessionManager.switchLLMForAllSessions(validatedConfig); - } else if (sessionScope) { + } else if (sessionScope !== undefined) { // Verify session exists before switching LLM const session = await this.sessionManager.getSession(sessionScope); if (!session) { throw SessionError.notFound(sessionScope); } - this.stateManager.updateLLM(validatedConfig, sessionScope); await this.sessionManager.switchLLMForSpecificSession(validatedConfig, sessionScope); } else { // No sessionScope provided - this is a configuration-level switch only @@ -2987,6 +2992,9 @@ export class DextoAgent { if (sessionId !== undefined && (!sessionId || typeof sessionId !== 'string')) { throw AgentError.apiValidationError('sessionId must be a non-empty string'); } + if (sessionId !== undefined) { + await this.toolManager.restoreSessionState(sessionId); + } return this.toolManager.filterToolsForSession( await this.toolManager.getAllTools(), sessionId @@ -3018,7 +3026,7 @@ export class DextoAgent { /** * Set session-level disabled tools (session override). */ - public setSessionDisabledTools(sessionId: string, toolNames: string[]): void { + public async setSessionDisabledTools(sessionId: string, toolNames: string[]): Promise { this.ensureStarted(); if (!sessionId || typeof sessionId !== 'string') { throw AgentError.apiValidationError( @@ -3031,39 +3039,40 @@ export class DextoAgent { ) { throw AgentError.apiValidationError('toolNames must be an array of non-empty strings'); } - this.toolManager.setSessionDisabledTools(sessionId, toolNames); + await this.toolManager.setSessionDisabledTools(sessionId, toolNames); } /** * Clear session-level disabled tools (session override). */ - public clearSessionDisabledTools(sessionId: string): void { + public async clearSessionDisabledTools(sessionId: string): Promise { this.ensureStarted(); if (!sessionId || typeof sessionId !== 'string') { throw AgentError.apiValidationError( 'sessionId is required and must be a non-empty string' ); } - this.toolManager.clearSessionDisabledTools(sessionId); + await this.toolManager.clearSessionDisabledTools(sessionId); } /** * Get session-level auto-approve tools. */ - public getSessionAutoApproveTools(sessionId: string): string[] { + public async getSessionAutoApproveTools(sessionId: string): Promise { this.ensureStarted(); if (!sessionId || typeof sessionId !== 'string') { throw AgentError.apiValidationError( 'sessionId is required and must be a non-empty string' ); } + await this.toolManager.restoreSessionState(sessionId); return this.toolManager.getSessionUserAutoApproveTools(sessionId) ?? []; } /** * Set session-level auto-approve tools (user selection). */ - public setSessionAutoApproveTools(sessionId: string, toolNames: string[]): void { + public async setSessionAutoApproveTools(sessionId: string, toolNames: string[]): Promise { this.ensureStarted(); if (!sessionId || typeof sessionId !== 'string') { throw AgentError.apiValidationError( @@ -3076,7 +3085,7 @@ export class DextoAgent { ) { throw AgentError.apiValidationError('toolNames must be an array of non-empty strings'); } - this.toolManager.setSessionUserAutoApproveTools(sessionId, toolNames); + await this.toolManager.setSessionUserAutoApproveTools(sessionId, toolNames); } /** diff --git a/packages/core/src/approval/manager.test.ts b/packages/core/src/approval/manager.test.ts index d41babd17..192281aa3 100644 --- a/packages/core/src/approval/manager.test.ts +++ b/packages/core/src/approval/manager.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, beforeEach } from 'vitest'; +import { describe, it, expect, beforeEach, vi } from 'vitest'; import * as path from 'node:path'; import * as os from 'node:os'; import { mkdtempSync, mkdirSync, rmSync, symlinkSync } from 'node:fs'; @@ -8,18 +8,38 @@ import { AgentEventBus } from '../events/index.js'; import { DextoRuntimeError } from '../errors/index.js'; import { ApprovalErrorCode } from './error-codes.js'; import { createMockLogger } from '../logger/v2/test-utils.js'; +import type { Logger } from '../logger/v2/types.js'; +import type { SessionApprovalState } from './session-approval-store.js'; +import { createInMemorySessionApprovalStore } from '../test-utils/session-state-stores.js'; + +function createDeferred() { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject }; +} describe('ApprovalManager', () => { let agentEventBus: AgentEventBus; const mockLogger = createMockLogger(); + function createApprovalManager( + config: ConstructorParameters[0], + logger: Logger = mockLogger + ) { + return new ApprovalManager(config, logger, createInMemorySessionApprovalStore(logger)); + } + beforeEach(() => { agentEventBus = new AgentEventBus(); }); describe('Configuration - Separate tool and elicitation control', () => { it('should allow auto-approve for tools while elicitation is enabled', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -44,7 +64,7 @@ describe('ApprovalManager', () => { }); it('should reject elicitation when disabled, even if tools are auto-approved', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -87,7 +107,7 @@ describe('ApprovalManager', () => { }); it('should auto-deny tools while elicitation is enabled', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-deny', @@ -112,7 +132,7 @@ describe('ApprovalManager', () => { }); it('should use separate timeouts for tools and elicitation', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -134,7 +154,7 @@ describe('ApprovalManager', () => { describe('Approval routing by type', () => { it('should route tool approvals to tool approval handler', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -158,7 +178,7 @@ describe('ApprovalManager', () => { }); it('should route command confirmations to tool confirmation handler', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -182,7 +202,7 @@ describe('ApprovalManager', () => { }); it('should route elicitation to elicitation provider when enabled', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-deny', // Different mode for tools @@ -216,7 +236,7 @@ describe('ApprovalManager', () => { describe('Pending approvals tracking', () => { it('should track pending approvals across both providers', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -238,7 +258,7 @@ describe('ApprovalManager', () => { }); it('should cancel approvals in both providers', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -260,7 +280,7 @@ describe('ApprovalManager', () => { describe('Error handling', () => { it('should throw clear error when elicitation is disabled', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -289,7 +309,7 @@ describe('ApprovalManager', () => { }); it('should provide helpful error message about enabling elicitation', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -321,11 +341,43 @@ describe('ApprovalManager', () => { expect((error as Error).message).toContain('agent configuration'); } }); + + it('should treat approved elicitations without formData as an empty object', async () => { + const manager = createApprovalManager( + { + permissions: { + mode: 'auto-approve', + timeout: 120000, + }, + elicitation: { + enabled: true, + timeout: 120000, + }, + }, + mockLogger + ); + + manager.setHandler(async (request) => ({ + approvalId: request.approvalId, + status: ApprovalStatus.APPROVED, + })); + + await expect( + manager.getElicitationData({ + schema: { + type: 'object' as const, + properties: {}, + }, + prompt: 'Anything to add?', + serverName: 'Test Server', + }) + ).resolves.toEqual({}); + }); }); describe('Timeout Configuration', () => { it('should allow undefined timeout (infinite wait) for tool confirmation', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -344,7 +396,7 @@ describe('ApprovalManager', () => { }); it('should allow undefined timeout (infinite wait) for elicitation', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -363,7 +415,7 @@ describe('ApprovalManager', () => { }); it('should allow both timeouts to be undefined (infinite wait for all approvals)', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -383,7 +435,7 @@ describe('ApprovalManager', () => { }); it('should use per-request timeout override when provided', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', // Auto-approve so we can test immediately @@ -410,7 +462,7 @@ describe('ApprovalManager', () => { }); it('should not timeout when timeout is undefined in auto-approve mode', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -433,7 +485,7 @@ describe('ApprovalManager', () => { }); it('should not timeout when timeout is undefined in auto-deny mode', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-deny', @@ -459,7 +511,7 @@ describe('ApprovalManager', () => { describe('Backward compatibility', () => { it('should work with manual mode for both tools and elicitation', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -486,7 +538,7 @@ describe('ApprovalManager', () => { }); it('should respect explicitly set elicitation enabled value', () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'manual', @@ -506,7 +558,7 @@ describe('ApprovalManager', () => { describe('Denial Reasons', () => { it('should include system_denied reason in auto-deny mode', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-deny', @@ -532,7 +584,7 @@ describe('ApprovalManager', () => { }); it('should throw error with specific reason when tool is denied', async () => { - const manager = new ApprovalManager( + const manager = createApprovalManager( { permissions: { mode: 'auto-deny', @@ -564,7 +616,7 @@ describe('ApprovalManager', () => { }); it('should handle user_denied reason in error message', async () => { - const _manager = new ApprovalManager( + const _manager = createApprovalManager( { permissions: { mode: 'manual', @@ -628,7 +680,7 @@ describe('ApprovalManager', () => { const toolName = 'bash_exec'; beforeEach(() => { - manager = new ApprovalManager( + manager = createApprovalManager( { permissions: { mode: 'manual', @@ -643,15 +695,15 @@ describe('ApprovalManager', () => { }); describe('addPattern', () => { - it('should add a pattern to the approved list', () => { - manager.addPattern(toolName, 'git *'); + it('should add a pattern to the approved list', async () => { + await manager.addPattern(toolName, 'git *'); expect(manager.getToolPatterns(toolName).has('git *')).toBe(true); }); - it('should add multiple patterns', () => { - manager.addPattern(toolName, 'git *'); - manager.addPattern(toolName, 'npm *'); - manager.addPattern(toolName, 'ls *'); + it('should add multiple patterns', async () => { + await manager.addPattern(toolName, 'git *'); + await manager.addPattern(toolName, 'npm *'); + await manager.addPattern(toolName, 'ls *'); const patterns = manager.getToolPatterns(toolName); expect(patterns.size).toBe(3); @@ -660,9 +712,9 @@ describe('ApprovalManager', () => { expect(patterns.has('ls *')).toBe(true); }); - it('should not duplicate patterns', () => { - manager.addPattern(toolName, 'git *'); - manager.addPattern(toolName, 'git *'); + it('should not duplicate patterns', async () => { + await manager.addPattern(toolName, 'git *'); + await manager.addPattern(toolName, 'git *'); expect(manager.getToolPatterns(toolName).size).toBe(1); }); @@ -672,32 +724,32 @@ describe('ApprovalManager', () => { // Note: matchesPattern expects pattern keys (e.g., "git push *"), // not raw commands. ToolManager generates pattern keys from commands. - it('should match exact pattern against exact stored pattern', () => { - manager.addPattern(toolName, 'git status *'); + it('should match exact pattern against exact stored pattern', async () => { + await manager.addPattern(toolName, 'git status *'); expect(manager.matchesPattern(toolName, 'git status *')).toBe(true); expect(manager.matchesPattern(toolName, 'git push *')).toBe(false); }); - it('should cover narrower pattern with broader pattern', () => { + it('should cover narrower pattern with broader pattern', async () => { // "git *" is broader and should cover "git push *", "git status *", etc. - manager.addPattern(toolName, 'git *'); + await manager.addPattern(toolName, 'git *'); expect(manager.matchesPattern(toolName, 'git *')).toBe(true); expect(manager.matchesPattern(toolName, 'git push *')).toBe(true); expect(manager.matchesPattern(toolName, 'git status *')).toBe(true); expect(manager.matchesPattern(toolName, 'npm *')).toBe(false); }); - it('should not let narrower pattern cover broader pattern', () => { + it('should not let narrower pattern cover broader pattern', async () => { // "git push *" should NOT cover "git *" - manager.addPattern(toolName, 'git push *'); + await manager.addPattern(toolName, 'git push *'); expect(manager.matchesPattern(toolName, 'git push *')).toBe(true); expect(manager.matchesPattern(toolName, 'git *')).toBe(false); expect(manager.matchesPattern(toolName, 'git status *')).toBe(false); }); - it('should match against multiple patterns', () => { - manager.addPattern(toolName, 'git *'); - manager.addPattern(toolName, 'npm install *'); + it('should match against multiple patterns', async () => { + await manager.addPattern(toolName, 'git *'); + await manager.addPattern(toolName, 'npm install *'); expect(manager.matchesPattern(toolName, 'git status *')).toBe(true); expect(manager.matchesPattern(toolName, 'npm install *')).toBe(true); @@ -709,40 +761,113 @@ describe('ApprovalManager', () => { expect(manager.matchesPattern(toolName, 'git status *')).toBe(false); }); - it('should not cross-match unrelated commands', () => { - manager.addPattern(toolName, 'npm *'); + it('should not cross-match unrelated commands', async () => { + await manager.addPattern(toolName, 'npm *'); // "npx" starts with "np" but is not "npm " + something expect(manager.matchesPattern(toolName, 'npx *')).toBe(false); }); - it('should handle multi-level subcommands', () => { - manager.addPattern(toolName, 'docker compose *'); + it('should handle multi-level subcommands', async () => { + await manager.addPattern(toolName, 'docker compose *'); expect(manager.matchesPattern(toolName, 'docker compose *')).toBe(true); expect(manager.matchesPattern(toolName, 'docker compose up *')).toBe(true); expect(manager.matchesPattern(toolName, 'docker *')).toBe(false); }); - it('should isolate patterns by tool', () => { - manager.addPattern('tool-a', 'git *'); + it('should isolate patterns by tool', async () => { + await manager.addPattern('tool-a', 'git *'); expect(manager.matchesPattern('tool-a', 'git push *')).toBe(true); expect(manager.matchesPattern('tool-b', 'git push *')).toBe(false); }); + + it('should serialize deleteSessionState with in-flight pattern persistence', async () => { + const sessionId = 'locked-delete-session'; + const saveStarted = createDeferred(); + const releaseSave = createDeferred(); + const persistedState = new Map(); + const emptyState: SessionApprovalState = { + toolPatterns: {}, + approvedDirectories: [], + }; + const store = { + load: vi.fn().mockImplementation(async (requestedSessionId?: string) => { + return structuredClone( + persistedState.get(requestedSessionId ?? '__global__') ?? emptyState + ); + }), + save: vi + .fn() + .mockImplementation( + async ( + requestedSessionId: string | undefined, + state: SessionApprovalState + ) => { + saveStarted.resolve(); + await releaseSave.promise; + persistedState.set( + requestedSessionId ?? '__global__', + structuredClone(state) + ); + } + ), + delete: vi.fn().mockImplementation(async (requestedSessionId?: string) => { + persistedState.delete(requestedSessionId ?? '__global__'); + }), + }; + const manager = new ApprovalManager( + { + permissions: { + mode: 'auto-approve', + timeout: 120000, + }, + elicitation: { + enabled: true, + timeout: 120000, + }, + }, + mockLogger, + store as unknown as ConstructorParameters[2] + ); + + const addPatternPromise = manager.addPattern('bash_exec', 'git *', sessionId); + await saveStarted.promise; + + let deleteFinished = false; + const deletePromise = manager.deleteSessionState(sessionId).then(() => { + deleteFinished = true; + }); + + await Promise.resolve(); + expect(deleteFinished).toBe(false); + + releaseSave.resolve(); + await addPatternPromise; + await deletePromise; + + expect( + persistedState.get(sessionId) ?? { + toolPatterns: {}, + approvedDirectories: [], + } + ).toEqual(emptyState); + expect(manager.matchesPattern('bash_exec', 'git status *', sessionId)).toBe(false); + }); }); describe('clearPatterns', () => { - it('should clear patterns for a tool', () => { - manager.addPattern(toolName, 'git *'); - manager.addPattern(toolName, 'npm *'); + it('should clear patterns for a tool', async () => { + await manager.addPattern(toolName, 'git *'); + await manager.addPattern(toolName, 'npm *'); expect(manager.getToolPatterns(toolName).size).toBe(2); - manager.clearPatterns(toolName); + await manager.clearPatterns(toolName); expect(manager.getToolPatterns(toolName).size).toBe(0); }); - it('should allow adding patterns after clearing', () => { - manager.addPattern(toolName, 'git *'); - manager.clearPatterns(toolName); - manager.addPattern(toolName, 'npm *'); + it('should allow adding patterns after clearing', async () => { + await manager.addPattern(toolName, 'git *'); + await manager.clearPatterns(toolName); + await manager.addPattern(toolName, 'npm *'); expect(manager.getToolPatterns(toolName).size).toBe(1); expect(manager.getToolPatterns(toolName).has('npm *')).toBe(true); @@ -754,8 +879,8 @@ describe('ApprovalManager', () => { expect(manager.getToolPatterns(toolName).size).toBe(0); }); - it('should return a copy that reflects current patterns', () => { - manager.addPattern(toolName, 'git *'); + it('should return a copy that reflects current patterns', async () => { + await manager.addPattern(toolName, 'git *'); const patterns = manager.getToolPatterns(toolName); expect(patterns.has('git *')).toBe(true); }); @@ -766,7 +891,7 @@ describe('ApprovalManager', () => { let manager: ApprovalManager; beforeEach(() => { - manager = new ApprovalManager( + manager = createApprovalManager( { permissions: { mode: 'manual', @@ -781,26 +906,26 @@ describe('ApprovalManager', () => { }); describe('initializeWorkingDirectory', () => { - it('should add working directory as session-approved', () => { - manager.initializeWorkingDirectory('/home/user/project'); + it('should add working directory as session-approved', async () => { + await manager.initializeWorkingDirectory('/home/user/project'); expect(manager.isDirectorySessionApproved('/home/user/project/src/file.ts')).toBe( true ); }); - it('should normalize the path before adding', () => { - manager.initializeWorkingDirectory('/home/user/../user/project'); + it('should normalize the path before adding', async () => { + await manager.initializeWorkingDirectory('/home/user/../user/project'); expect(manager.isDirectorySessionApproved('/home/user/project/file.ts')).toBe(true); }); }); describe('addApprovedDirectory', () => { - it('should add directory with session type by default', () => { - manager.addApprovedDirectory('/external/project'); + it('should add directory with session type by default', async () => { + await manager.addApprovedDirectory('/external/project'); expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(true); }); - it('should treat symlink-approved directory as approved for its realpath', () => { + it('should treat symlink-approved directory as approved for its realpath', async () => { const baseDir = mkdtempSync(path.join(os.tmpdir(), 'dexto-approval-symlink-')); try { const actualDir = path.join(baseDir, 'actual'); @@ -810,7 +935,7 @@ describe('ApprovalManager', () => { const symlinkType = process.platform === 'win32' ? 'junction' : 'dir'; symlinkSync(actualDir, linkDir, symlinkType); - manager.addApprovedDirectory(linkDir, 'session'); + await manager.addApprovedDirectory(linkDir, 'session'); expect(manager.isDirectoryApproved(path.join(actualDir, 'file.ts'))).toBe(true); expect( @@ -821,7 +946,7 @@ describe('ApprovalManager', () => { } }); - it('should treat approved directory as approved for its realpath even if the directory did not exist yet', () => { + it('should treat approved directory as approved for its realpath even if the directory did not exist yet', async () => { const baseDir = mkdtempSync( path.join(os.tmpdir(), 'dexto-approval-symlink-missing-leaf-') ); @@ -837,7 +962,7 @@ describe('ApprovalManager', () => { const approvedDir = path.join(linkDir, 'child'); // Approve a directory that doesn't exist yet (common for write/create flows). - manager.addApprovedDirectory(approvedDir, 'session'); + await manager.addApprovedDirectory(approvedDir, 'session'); const actualChildDir = path.join(actualDir, 'child'); mkdirSync(actualChildDir); @@ -850,69 +975,69 @@ describe('ApprovalManager', () => { } }); - it('should add directory with explicit session type', () => { - manager.addApprovedDirectory('/external/project', 'session'); + it('should add directory with explicit session type', async () => { + await manager.addApprovedDirectory('/external/project', 'session'); expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(true); }); - it('should add directory with once type', () => { - manager.addApprovedDirectory('/external/project', 'once'); + it('should add directory with once type', async () => { + await manager.addApprovedDirectory('/external/project', 'once'); // 'once' type should NOT be session-approved (requires prompt each time) expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(false); // But should be generally approved for execution expect(manager.isDirectoryApproved('/external/project/file.ts')).toBe(true); }); - it('should not downgrade from session to once', () => { - manager.addApprovedDirectory('/external/project', 'session'); - manager.addApprovedDirectory('/external/project', 'once'); + it('should not downgrade from session to once', async () => { + await manager.addApprovedDirectory('/external/project', 'session'); + await manager.addApprovedDirectory('/external/project', 'once'); // Should still be session-approved expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(true); }); - it('should upgrade from once to session', () => { - manager.addApprovedDirectory('/external/project', 'once'); + it('should upgrade from once to session', async () => { + await manager.addApprovedDirectory('/external/project', 'once'); expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(false); - manager.addApprovedDirectory('/external/project', 'session'); + await manager.addApprovedDirectory('/external/project', 'session'); expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(true); }); - it('should normalize paths before adding', () => { - manager.addApprovedDirectory('/external/../external/project'); + it('should normalize paths before adding', async () => { + await manager.addApprovedDirectory('/external/../external/project'); expect(manager.isDirectoryApproved('/external/project/file.ts')).toBe(true); }); }); describe('isDirectorySessionApproved', () => { - it('should return true for files within session-approved directory', () => { - manager.addApprovedDirectory('/external/project', 'session'); + it('should return true for files within session-approved directory', async () => { + await manager.addApprovedDirectory('/external/project', 'session'); expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(true); expect( manager.isDirectorySessionApproved('/external/project/src/deep/file.ts') ).toBe(true); }); - it('should return false for files within once-approved directory', () => { - manager.addApprovedDirectory('/external/project', 'once'); + it('should return false for files within once-approved directory', async () => { + await manager.addApprovedDirectory('/external/project', 'once'); expect(manager.isDirectorySessionApproved('/external/project/file.ts')).toBe(false); }); - it('should return false for files outside approved directories', () => { - manager.addApprovedDirectory('/external/project', 'session'); + it('should return false for files outside approved directories', async () => { + await manager.addApprovedDirectory('/external/project', 'session'); expect(manager.isDirectorySessionApproved('/other/file.ts')).toBe(false); }); - it('should handle path containment correctly', () => { - manager.addApprovedDirectory('/external', 'session'); + it('should handle path containment correctly', async () => { + await manager.addApprovedDirectory('/external', 'session'); // Approving /external should cover /external/sub/file.ts expect(manager.isDirectorySessionApproved('/external/sub/file.ts')).toBe(true); // But not /external-other/file.ts (different directory) expect(manager.isDirectorySessionApproved('/external-other/file.ts')).toBe(false); }); - it('should return true when working directory is initialized', () => { - manager.initializeWorkingDirectory('/home/user/project'); + it('should return true when working directory is initialized', async () => { + await manager.initializeWorkingDirectory('/home/user/project'); expect(manager.isDirectorySessionApproved('/home/user/project/any/file.ts')).toBe( true ); @@ -920,32 +1045,32 @@ describe('ApprovalManager', () => { }); describe('isDirectoryApproved', () => { - it('should return true for files within session-approved directory', () => { - manager.addApprovedDirectory('/external/project', 'session'); + it('should return true for files within session-approved directory', async () => { + await manager.addApprovedDirectory('/external/project', 'session'); expect(manager.isDirectoryApproved('/external/project/file.ts')).toBe(true); }); - it('should return true for files within once-approved directory', () => { - manager.addApprovedDirectory('/external/project', 'once'); + it('should return true for files within once-approved directory', async () => { + await manager.addApprovedDirectory('/external/project', 'once'); expect(manager.isDirectoryApproved('/external/project/file.ts')).toBe(true); }); - it('should return false for files outside approved directories', () => { - manager.addApprovedDirectory('/external/project', 'session'); + it('should return false for files outside approved directories', async () => { + await manager.addApprovedDirectory('/external/project', 'session'); expect(manager.isDirectoryApproved('/other/file.ts')).toBe(false); }); - it('should handle multiple approved directories', () => { - manager.addApprovedDirectory('/external/project1', 'session'); - manager.addApprovedDirectory('/external/project2', 'once'); + it('should handle multiple approved directories', async () => { + await manager.addApprovedDirectory('/external/project1', 'session'); + await manager.addApprovedDirectory('/external/project2', 'once'); expect(manager.isDirectoryApproved('/external/project1/file.ts')).toBe(true); expect(manager.isDirectoryApproved('/external/project2/file.ts')).toBe(true); expect(manager.isDirectoryApproved('/external/project3/file.ts')).toBe(false); }); - it('should handle nested directory approvals', () => { - manager.addApprovedDirectory('/external', 'session'); + it('should handle nested directory approvals', async () => { + await manager.addApprovedDirectory('/external', 'session'); // Approving /external should cover all subdirectories expect(manager.isDirectoryApproved('/external/sub/deep/file.ts')).toBe(true); }); @@ -956,9 +1081,9 @@ describe('ApprovalManager', () => { expect(manager.getApprovedDirectories().size).toBe(0); }); - it('should return map with type information', () => { - manager.addApprovedDirectory('/external/project1', 'session'); - manager.addApprovedDirectory('/external/project2', 'once'); + it('should return map with type information', async () => { + await manager.addApprovedDirectory('/external/project1', 'session'); + await manager.addApprovedDirectory('/external/project2', 'once'); const dirs = manager.getApprovedDirectories(); expect(dirs.size).toBeGreaterThanOrEqual(2); @@ -968,8 +1093,8 @@ describe('ApprovalManager', () => { expect(keys.some((k) => k.includes('project2'))).toBe(true); }); - it('should include working directory after initialization', () => { - manager.initializeWorkingDirectory('/home/user/project'); + it('should include working directory after initialization', async () => { + await manager.initializeWorkingDirectory('/home/user/project'); const dirs = manager.getApprovedDirectories(); expect(dirs.size).toBeGreaterThanOrEqual(1); expect(dirs.get(path.resolve('/home/user/project'))).toBe('session'); @@ -980,22 +1105,22 @@ describe('ApprovalManager', () => { describe('Session vs Once Prompting Behavior', () => { // These tests verify the expected prompting flow - it('working directory should not require prompt (session-approved)', () => { - manager.initializeWorkingDirectory('/home/user/project'); + it('working directory should not require prompt (session-approved)', async () => { + await manager.initializeWorkingDirectory('/home/user/project'); // isDirectorySessionApproved returns true → no directory prompt needed expect(manager.isDirectorySessionApproved('/home/user/project/src/file.ts')).toBe( true ); }); - it('external dir after session approval should not require prompt', () => { - manager.addApprovedDirectory('/external', 'session'); + it('external dir after session approval should not require prompt', async () => { + await manager.addApprovedDirectory('/external', 'session'); // isDirectorySessionApproved returns true → no directory prompt needed expect(manager.isDirectorySessionApproved('/external/file.ts')).toBe(true); }); - it('external dir after once approval should require prompt each time', () => { - manager.addApprovedDirectory('/external', 'once'); + it('external dir after once approval should require prompt each time', async () => { + await manager.addApprovedDirectory('/external', 'once'); // isDirectorySessionApproved returns false → directory prompt needed expect(manager.isDirectorySessionApproved('/external/file.ts')).toBe(false); // But isDirectoryApproved returns true → execution allowed diff --git a/packages/core/src/approval/manager.ts b/packages/core/src/approval/manager.ts index ffb75e149..c16204d22 100644 --- a/packages/core/src/approval/manager.ts +++ b/packages/core/src/approval/manager.ts @@ -17,6 +17,18 @@ import { DextoLogComponent } from '../logger/v2/types.js'; import { ApprovalError } from './errors.js'; import { patternCovers } from '../tools/pattern-utils.js'; import type { PermissionsMode } from '../tools/schemas.js'; +import { + SessionApprovalStore, + type PersistedApprovedDirectory, + type SessionApprovalState, +} from './session-approval-store.js'; + +const GLOBAL_APPROVAL_SCOPE = '__global__'; + +type ApprovalScopeState = { + toolPatterns: Map>; + approvedDirectories: Map; +}; function tryRealpathSync(targetPath: string): string | null { try { @@ -97,25 +109,15 @@ export class ApprovalManager { private handler: ApprovalHandler | undefined; private config: ApprovalManagerConfig; private logger: Logger; - - /** - * Tool approval patterns, keyed by tool id. - * - * Patterns use simple glob syntax (e.g. "git *", "npm install *") and are matched - * using pattern-to-pattern covering (see {@link patternCovers}). - */ - private toolPatterns: Map> = new Map(); - - /** - * Directories approved for file access for the current session. - * Stores normalized absolute paths mapped to their approval type: - * - 'session': No directory prompt, follows tool config (working dir + user session-approved) - * - 'once': Prompts each time, but tool can execute - * Cleared when session ends. - */ - private approvedDirectories: Map = new Map(); - - constructor(config: ApprovalManagerConfig, logger: Logger) { + private readonly loadedScopes = new Set(); + private readonly scopeLocks = new Map>(); + private readonly scopes = new Map(); + + constructor( + config: ApprovalManagerConfig, + logger: Logger, + private readonly sessionApprovalStore: SessionApprovalStore + ) { this.config = config; this.logger = logger.createChild(DextoLogComponent.APPROVAL); @@ -124,22 +126,170 @@ export class ApprovalManager { ); } + private getScopeKey(sessionId?: string): string { + return sessionId ?? GLOBAL_APPROVAL_SCOPE; + } + + private getScopeLabel(sessionId?: string): string { + return sessionId ?? 'global'; + } + + private getApprovalTimeout(type: ApprovalType, timeout?: number): number | undefined { + return timeout ?? this.getDefaultTimeout(type); + } + + private getDefaultTimeout(type: ApprovalType): number | undefined { + return type === ApprovalType.ELICITATION + ? this.config.elicitation.timeout + : this.config.permissions.timeout; + } + + private createEmptyScopeState(): ApprovalScopeState { + return { + toolPatterns: new Map(), + approvedDirectories: new Map(), + }; + } + + private getOrCreateScope(scopeKey: string): ApprovalScopeState { + const existing = this.scopes.get(scopeKey); + if (existing) return existing; + const created = this.createEmptyScopeState(); + this.scopes.set(scopeKey, created); + return created; + } + + private getScope(scopeKey: string): ApprovalScopeState { + return this.scopes.get(scopeKey) ?? this.createEmptyScopeState(); + } + + private async runWithScopeLock(scopeKey: string, fn: () => Promise): Promise { + const previousLock = this.scopeLocks.get(scopeKey) ?? Promise.resolve(); + const currentResult = previousLock.catch(() => {}).then(() => fn()); + const currentLock = currentResult.then( + () => undefined, + () => undefined + ); + + this.scopeLocks.set(scopeKey, currentLock); + + try { + return await currentResult; + } finally { + if (this.scopeLocks.get(scopeKey) === currentLock) { + this.scopeLocks.delete(scopeKey); + } + } + } + + private snapshotToolPatterns(scopeKey: string): Record { + const snapshot: Record = {}; + for (const [toolName, patterns] of this.getScope(scopeKey).toolPatterns) { + snapshot[toolName] = Array.from(patterns); + } + return snapshot; + } + + private snapshotApprovedDirectories(scopeKey: string): PersistedApprovedDirectory[] { + return Array.from(this.getScope(scopeKey).approvedDirectories.entries()).map( + ([path, type]) => ({ + path, + type, + }) + ); + } + + private async persistScope(sessionId?: string): Promise { + const scopeKey = this.getScopeKey(sessionId); + const state: SessionApprovalState = { + toolPatterns: this.snapshotToolPatterns(scopeKey), + approvedDirectories: this.snapshotApprovedDirectories(scopeKey), + }; + await this.sessionApprovalStore.save(sessionId, state); + } + + private hydrateScope(sessionId: string | undefined, state: SessionApprovalState): void { + const scopeKey = this.getScopeKey(sessionId); + + const toolPatterns = new Map>(); + for (const [toolName, patterns] of Object.entries(state.toolPatterns)) { + toolPatterns.set(toolName, new Set(patterns)); + } + + const approvedDirectories = new Map(); + for (const entry of state.approvedDirectories) { + approvedDirectories.set(entry.path, entry.type); + } + + this.scopes.set(scopeKey, { + toolPatterns, + approvedDirectories, + }); + } + + async restoreSessionState(sessionId?: string): Promise { + const scopeKey = this.getScopeKey(sessionId); + if (this.loadedScopes.has(scopeKey)) { + return; + } + + await this.runWithScopeLock(scopeKey, async () => { + if (this.loadedScopes.has(scopeKey)) { + return; + } + + const state = await this.sessionApprovalStore.load(sessionId); + this.hydrateScope(sessionId, state); + this.loadedScopes.add(scopeKey); + + this.logger.debug('Restored persisted approval state', { + sessionId: this.getScopeLabel(sessionId), + toolCount: Object.keys(state.toolPatterns).length, + directoryCount: state.approvedDirectories.length, + }); + }); + } + + evictSessionState(sessionId?: string): void { + const scopeKey = this.getScopeKey(sessionId); + this.scopes.delete(scopeKey); + this.loadedScopes.delete(scopeKey); + } + + async deleteSessionState(sessionId?: string): Promise { + const scopeKey = this.getScopeKey(sessionId); + await this.runWithScopeLock(scopeKey, async () => { + this.evictSessionState(sessionId); + await this.sessionApprovalStore.delete(sessionId); + }); + } + // ==================== Pattern Methods ==================== - private getOrCreateToolPatternSet(toolName: string): Set { - const existing = this.toolPatterns.get(toolName); + private getOrCreateToolPatternSet(toolName: string, scopeKey: string): Set { + const scope = this.getOrCreateScope(scopeKey).toolPatterns; + const existing = scope.get(toolName); if (existing) return existing; const created = new Set(); - this.toolPatterns.set(toolName, created); + scope.set(toolName, created); return created; } /** * Add an approval pattern for a tool. */ - addPattern(toolName: string, pattern: string): void { - this.getOrCreateToolPatternSet(toolName).add(pattern); - this.logger.debug(`Added pattern for '${toolName}': "${pattern}"`); + async addPattern(toolName: string, pattern: string, sessionId?: string): Promise { + await this.restoreSessionState(sessionId); + const scopeKey = this.getScopeKey(sessionId); + + await this.runWithScopeLock(scopeKey, async () => { + this.getOrCreateToolPatternSet(toolName, scopeKey).add(pattern); + await this.persistScope(sessionId); + }); + + this.logger.debug( + `Added pattern for '${toolName}' in '${this.getScopeLabel(sessionId)}': "${pattern}"` + ); } /** @@ -148,8 +298,9 @@ export class ApprovalManager { * Note: This expects a pattern key (e.g. "git push *"), not raw arguments. * Tools are responsible for generating the key via `tool.approval.patternKey()`. */ - matchesPattern(toolName: string, patternKey: string): boolean { - const patterns = this.toolPatterns.get(toolName); + matchesPattern(toolName: string, patternKey: string, sessionId?: string): boolean { + const scopeKey = this.getScopeKey(sessionId); + const patterns = this.getScope(scopeKey).toolPatterns.get(toolName); if (!patterns || patterns.size === 0) return false; for (const storedPattern of patterns) { @@ -166,40 +317,51 @@ export class ApprovalManager { /** * Clear all patterns for a tool (or all tools when omitted). */ - clearPatterns(toolName?: string): void { - if (toolName) { - const patterns = this.toolPatterns.get(toolName); - if (!patterns) return; - const count = patterns.size; - patterns.clear(); - if (count > 0) { - this.logger.debug(`Cleared ${count} pattern(s) for '${toolName}'`); + async clearPatterns(toolName?: string, sessionId?: string): Promise { + await this.restoreSessionState(sessionId); + const scopeKey = this.getScopeKey(sessionId); + + await this.runWithScopeLock(scopeKey, async () => { + const scope = this.getOrCreateScope(scopeKey).toolPatterns; + if (toolName) { + const patterns = scope.get(toolName); + if (!patterns) return; + const count = patterns.size; + scope.delete(toolName); + await this.persistScope(sessionId); + if (count > 0) { + this.logger.debug( + `Cleared ${count} pattern(s) for '${toolName}' in '${this.getScopeLabel(sessionId)}'` + ); + } + return; } - return; - } - const count = Array.from(this.toolPatterns.values()).reduce( - (sum, set) => sum + set.size, - 0 - ); - this.toolPatterns.clear(); - if (count > 0) { - this.logger.debug(`Cleared ${count} total tool pattern(s)`); - } + const count = Array.from(scope.values()).reduce((sum, set) => sum + set.size, 0); + scope.clear(); + await this.persistScope(sessionId); + if (count > 0) { + this.logger.debug( + `Cleared ${count} total tool pattern(s) in '${this.getScopeLabel(sessionId)}'` + ); + } + }); } /** * Get patterns for a tool (for debugging/display). */ - getToolPatterns(toolName: string): ReadonlySet { - return this.toolPatterns.get(toolName) ?? new Set(); + getToolPatterns(toolName: string, sessionId?: string): ReadonlySet { + const scopeKey = this.getScopeKey(sessionId); + return this.getScope(scopeKey).toolPatterns.get(toolName) ?? new Set(); } /** * Get all tool patterns (for debugging/display). */ - getAllToolPatterns(): ReadonlyMap> { - return this.toolPatterns; + getAllToolPatterns(sessionId?: string): ReadonlyMap> { + const scopeKey = this.getScopeKey(sessionId); + return this.getScope(scopeKey).toolPatterns; } // ==================== Directory Access Methods ==================== @@ -211,8 +373,8 @@ export class ApprovalManager { * continue to work even when other subsystems canonicalize paths via realpath * (e.g. macOS /tmp -> /private/tmp or custom symlinked directories). */ - private getDirectoryApprovalKeys(directory: string): string[] { - const resolved = path.resolve(directory); + private getPathApprovalKeys(targetPath: string): string[] { + const resolved = path.resolve(targetPath); const real = tryRealpathSyncWithExistingParent(resolved); if (real && real !== resolved) { return [resolved, real]; @@ -220,13 +382,30 @@ export class ApprovalManager { return [resolved]; } - private getFileApprovalKeys(filePath: string): string[] { - const resolved = path.resolve(filePath); - const real = tryRealpathSyncWithExistingParent(resolved); - if (real && real !== resolved) { - return [resolved, real]; + private isPathWithinApprovedDirectory( + targetPath: string, + sessionId: string | undefined, + approvedTypes: ReadonlySet<'session' | 'once'> + ): boolean { + const scopeKey = this.getScopeKey(sessionId); + const directoryScope = this.getScope(scopeKey).approvedDirectories; + for (const normalized of this.getPathApprovalKeys(targetPath)) { + for (const [approvedDir, type] of directoryScope) { + if (!approvedTypes.has(type)) { + continue; + } + + const relative = path.relative(approvedDir, normalized); + if (!relative.startsWith('..') && !path.isAbsolute(relative)) { + this.logger.debug( + `Path "${normalized}" is within approved directory "${approvedDir}" (type: ${type})` + ); + return true; + } + } } - return [resolved]; + + return false; } /** @@ -236,8 +415,8 @@ export class ApprovalManager { * * @param workingDir The working directory path */ - initializeWorkingDirectory(workingDir: string): void { - this.addApprovedDirectory(workingDir, 'session'); + async initializeWorkingDirectory(workingDir: string, sessionId?: string): Promise { + await this.addApprovedDirectory(workingDir, 'session', sessionId); } /** @@ -257,40 +436,52 @@ export class ApprovalManager { * // Tool can access, but will prompt again next time * ``` */ - addApprovedDirectory(directory: string, type: 'session' | 'once' = 'session'): void { - const keys = this.getDirectoryApprovalKeys(directory); - - const existingTypes = keys - .map((key) => this.approvedDirectories.get(key)) - .filter((value): value is 'session' | 'once' => value !== undefined); - const hasSessionApproval = existingTypes.includes('session'); + async addApprovedDirectory( + directory: string, + type: 'session' | 'once' = 'session', + sessionId?: string + ): Promise { + await this.restoreSessionState(sessionId); + const scopeKey = this.getScopeKey(sessionId); + + await this.runWithScopeLock(scopeKey, async () => { + const keys = this.getPathApprovalKeys(directory); + const directoryScope = this.getOrCreateScope(scopeKey).approvedDirectories; + + const existingTypes = keys + .map((key) => directoryScope.get(key)) + .filter((value): value is 'session' | 'once' => value !== undefined); + const hasSessionApproval = existingTypes.includes('session'); + + // Never downgrade from session to once, even across realpath aliases + const effectiveType: 'session' | 'once' = + type === 'session' || hasSessionApproval ? 'session' : 'once'; + + for (const key of keys) { + const existing = directoryScope.get(key); + if (existing === 'session') { + continue; + } + directoryScope.set(key, effectiveType); + } - // Never downgrade from session to once, even across realpath aliases - const effectiveType: 'session' | 'once' = - type === 'session' || hasSessionApproval ? 'session' : 'once'; + await this.persistScope(sessionId); - for (const key of keys) { - const existing = this.approvedDirectories.get(key); - if (existing === 'session') { - continue; + const resolvedKey = keys[0]!; + if (effectiveType === 'session' && type === 'once' && hasSessionApproval) { + this.logger.debug( + `Directory "${resolvedKey}" already approved as 'session', not downgrading to 'once'` + ); + return; } - this.approvedDirectories.set(key, effectiveType); - } - const resolvedKey = keys[0]!; - if (effectiveType === 'session' && type === 'once' && hasSessionApproval) { + const realKey = keys.length > 1 ? keys[1] : null; this.logger.debug( - `Directory "${resolvedKey}" already approved as 'session', not downgrading to 'once'` + `Added approved directory in '${this.getScopeLabel(sessionId)}': "${resolvedKey}" (type: ${effectiveType})${ + realKey ? `, realpath: "${realKey}"` : '' + }` ); - return; - } - - const realKey = keys.length > 1 ? keys[1] : null; - this.logger.debug( - `Added approved directory: "${resolvedKey}" (type: ${effectiveType})${ - realKey ? `, realpath: "${realKey}"` : '' - }` - ); + }); } /** @@ -301,22 +492,8 @@ export class ApprovalManager { * @param filePath The file path to check (can be relative or absolute) * @returns true if the path is within a session-approved directory */ - isDirectorySessionApproved(filePath: string): boolean { - for (const normalized of this.getFileApprovalKeys(filePath)) { - for (const [approvedDir, type] of this.approvedDirectories) { - // Only check 'session' type directories for prompting decisions - if (type !== 'session') continue; - - const relative = path.relative(approvedDir, normalized); - if (!relative.startsWith('..') && !path.isAbsolute(relative)) { - this.logger.debug( - `Path "${normalized}" is within session-approved directory "${approvedDir}"` - ); - return true; - } - } - } - return false; + isDirectorySessionApproved(filePath: string, sessionId?: string): boolean { + return this.isPathWithinApprovedDirectory(filePath, sessionId, new Set(['session'])); } /** @@ -327,55 +504,133 @@ export class ApprovalManager { * @param filePath The file path to check (can be relative or absolute) * @returns true if the path is within any approved directory */ - isDirectoryApproved(filePath: string): boolean { - for (const normalized of this.getFileApprovalKeys(filePath)) { - for (const [approvedDir] of this.approvedDirectories) { - const relative = path.relative(approvedDir, normalized); - if (!relative.startsWith('..') && !path.isAbsolute(relative)) { - this.logger.debug( - `Path "${normalized}" is within approved directory "${approvedDir}"` - ); - return true; - } - } - } - return false; + isDirectoryApproved(filePath: string, sessionId?: string): boolean { + return this.isPathWithinApprovedDirectory( + filePath, + sessionId, + new Set(['session', 'once']) + ); } /** * Clear all approved directories. * Should be called when session ends. */ - clearApprovedDirectories(): void { - const count = this.approvedDirectories.size; - this.approvedDirectories.clear(); - if (count > 0) { - this.logger.debug(`Cleared ${count} approved directories`); - } + async clearApprovedDirectories(sessionId?: string): Promise { + await this.restoreSessionState(sessionId); + const scopeKey = this.getScopeKey(sessionId); + + await this.runWithScopeLock(scopeKey, async () => { + const scope = this.getOrCreateScope(scopeKey).approvedDirectories; + const count = scope.size; + scope.clear(); + await this.persistScope(sessionId); + if (count > 0) { + this.logger.debug( + `Cleared ${count} approved directories in '${this.getScopeLabel(sessionId)}'` + ); + } + }); } /** * Get the current map of approved directories with their types (for debugging/display). */ - getApprovedDirectories(): ReadonlyMap { - return this.approvedDirectories; + getApprovedDirectories(sessionId?: string): ReadonlyMap { + const scopeKey = this.getScopeKey(sessionId); + return this.getScope(scopeKey).approvedDirectories; } /** * Get just the directory paths that are approved (for debugging/display). */ - getApprovedDirectoryPaths(): string[] { - return Array.from(this.approvedDirectories.keys()); + getApprovedDirectoryPaths(sessionId?: string): string[] { + return Array.from(this.getApprovedDirectories(sessionId).keys()); } /** * Clear all session-scoped approvals (tool patterns and directories). * Convenience method for clearing all session state at once. */ - clearSessionApprovals(): void { - this.clearPatterns(); - this.clearApprovedDirectories(); - this.logger.debug('Cleared all session approvals'); + async clearSessionApprovals(sessionId?: string): Promise { + await this.restoreSessionState(sessionId); + const scopeKey = this.getScopeKey(sessionId); + + await this.runWithScopeLock(scopeKey, async () => { + const scope = this.getOrCreateScope(scopeKey); + const patternCount = Array.from(scope.toolPatterns.values()).reduce( + (sum, set) => sum + set.size, + 0 + ); + const directoryCount = scope.approvedDirectories.size; + + scope.toolPatterns.clear(); + scope.approvedDirectories.clear(); + await this.persistScope(sessionId); + + if (patternCount > 0 || directoryCount > 0) { + this.logger.debug( + `Cleared ${patternCount} tool pattern(s) and ${directoryCount} approved director${directoryCount === 1 ? 'y' : 'ies'} in '${this.getScopeLabel(sessionId)}'` + ); + } + }); + } + + private createApprovalDetails( + type: ApprovalType, + metadata: ApprovalRequestDetails['metadata'], + sessionId: string | undefined, + timeout?: number + ): ApprovalRequestDetails { + const details: ApprovalRequestDetails = { + type, + timeout: this.getApprovalTimeout(type, timeout), + metadata, + }; + + if (sessionId !== undefined) { + details.sessionId = sessionId; + } + + return details; + } + + private createResponse( + request: ApprovalRequest, + response: Omit + ): ApprovalResponse { + return { + approvalId: request.approvalId, + ...(request.sessionId !== undefined ? { sessionId: request.sessionId } : {}), + ...response, + }; + } + + private getElicitationFormData(response: ApprovalResponse): Record { + if ( + response.data && + typeof response.data === 'object' && + 'formData' in response.data && + typeof (response.data as { formData: unknown }).formData === 'object' && + (response.data as { formData: unknown }).formData !== null + ) { + return (response.data as { formData: Record }).formData; + } + + if ( + response.data === undefined || + (typeof response.data === 'object' && + response.data !== null && + !('formData' in response.data)) + ) { + return {}; + } + + throw ApprovalError.invalidResponse('Approved elicitation response is missing formData', { + approvalId: response.approvalId, + type: ApprovalType.ELICITATION, + field: 'formData', + }); } /** @@ -397,19 +652,14 @@ export class ApprovalManager { metadata: DirectoryAccessMetadata & { sessionId?: string; timeout?: number } ): Promise { const { sessionId, timeout, ...directoryMetadata } = metadata; - - const details: ApprovalRequestDetails = { - type: ApprovalType.DIRECTORY_ACCESS, - // Use provided timeout, fallback to config timeout, or undefined (no timeout) - timeout: timeout !== undefined ? timeout : this.config.permissions.timeout, - metadata: directoryMetadata, - }; - - if (sessionId !== undefined) { - details.sessionId = sessionId; - } - - return this.requestApproval(details); + return this.requestApproval( + this.createApprovalDetails( + ApprovalType.DIRECTORY_ACCESS, + directoryMetadata, + sessionId, + timeout + ) + ); } /** @@ -451,14 +701,9 @@ export class ApprovalManager { this.logger.info( `Auto-approve approval '${request.type}', approvalId: ${request.approvalId}` ); - const response: ApprovalResponse = { - approvalId: request.approvalId, + return this.createResponse(request, { status: ApprovalStatus.APPROVED, - }; - if (request.sessionId !== undefined) { - response.sessionId = request.sessionId; - } - return response; + }); } // Auto-deny mode @@ -466,16 +711,11 @@ export class ApprovalManager { this.logger.info( `Auto-deny approval '${request.type}', approvalId: ${request.approvalId}` ); - const response: ApprovalResponse = { - approvalId: request.approvalId, + return this.createResponse(request, { status: ApprovalStatus.DENIED, reason: DenialReason.SYSTEM_DENIED, message: `Approval automatically denied by system policy (auto-deny mode)`, - }; - if (request.sessionId !== undefined) { - response.sessionId = request.sessionId; - } - return response; + }); } // Manual mode - delegate to handler @@ -497,19 +737,9 @@ export class ApprovalManager { metadata: ToolApprovalMetadata & { sessionId?: string; timeout?: number } ): Promise { const { sessionId, timeout, ...toolMetadata } = metadata; - - const details: ApprovalRequestDetails = { - type: ApprovalType.TOOL_APPROVAL, - // Use provided timeout, fallback to config timeout, or undefined (no timeout) - timeout: timeout !== undefined ? timeout : this.config.permissions.timeout, - metadata: toolMetadata, - }; - - if (sessionId !== undefined) { - details.sessionId = sessionId; - } - - return this.requestApproval(details); + return this.requestApproval( + this.createApprovalDetails(ApprovalType.TOOL_APPROVAL, toolMetadata, sessionId, timeout) + ); } /** @@ -537,19 +767,14 @@ export class ApprovalManager { metadata: CommandConfirmationMetadata & { sessionId?: string; timeout?: number } ): Promise { const { sessionId, timeout, ...commandMetadata } = metadata; - - const details: ApprovalRequestDetails = { - type: ApprovalType.COMMAND_CONFIRMATION, - // Use provided timeout, fallback to config timeout, or undefined (no timeout) - timeout: timeout !== undefined ? timeout : this.config.permissions.timeout, - metadata: commandMetadata, - }; - - if (sessionId !== undefined) { - details.sessionId = sessionId; - } - - return this.requestApproval(details); + return this.requestApproval( + this.createApprovalDetails( + ApprovalType.COMMAND_CONFIRMATION, + commandMetadata, + sessionId, + timeout + ) + ); } /** @@ -563,19 +788,14 @@ export class ApprovalManager { metadata: ElicitationMetadata & { sessionId?: string; timeout?: number } ): Promise { const { sessionId, timeout, ...elicitationMetadata } = metadata; - - const details: ApprovalRequestDetails = { - type: ApprovalType.ELICITATION, - // Use provided timeout, fallback to config timeout, or undefined (no timeout) - timeout: timeout !== undefined ? timeout : this.config.elicitation.timeout, - metadata: elicitationMetadata, - }; - - if (sessionId !== undefined) { - details.sessionId = sessionId; - } - - return this.requestApproval(details); + return this.requestApproval( + this.createApprovalDetails( + ApprovalType.ELICITATION, + elicitationMetadata, + sessionId, + timeout + ) + ); } /** @@ -615,18 +835,7 @@ export class ApprovalManager { const response = await this.requestElicitation(metadata); if (response.status === ApprovalStatus.APPROVED) { - // Extract formData from response (handler always provides formData for elicitation) - if ( - response.data && - typeof response.data === 'object' && - 'formData' in response.data && - typeof (response.data as { formData: unknown }).formData === 'object' && - (response.data as { formData: unknown }).formData !== null - ) { - return (response.data as { formData: Record }).formData; - } - // Fallback to empty form if data is missing (edge case) - return {}; + return this.getElicitationFormData(response); } else if (response.status === ApprovalStatus.DENIED) { throw ApprovalError.elicitationDenied( metadata.serverName, diff --git a/packages/core/src/approval/session-approval-store.ts b/packages/core/src/approval/session-approval-store.ts new file mode 100644 index 000000000..461e679a1 --- /dev/null +++ b/packages/core/src/approval/session-approval-store.ts @@ -0,0 +1,89 @@ +import { z } from 'zod'; +import type { StorageManager } from '../storage/index.js'; +import type { Logger } from '../logger/v2/types.js'; + +const ApprovedDirectoryTypeSchema = z.enum(['session', 'once']); + +const PersistedApprovedDirectorySchema = z + .object({ + path: z.string(), + type: ApprovedDirectoryTypeSchema, + }) + .strict(); + +const SessionApprovalStateSchema = z + .object({ + toolPatterns: z.record(z.array(z.string())).default({}), + approvedDirectories: z.array(PersistedApprovedDirectorySchema).default([]), + }) + .strict(); + +export type PersistedApprovedDirectory = z.output; +export type SessionApprovalState = z.output; + +const DEFAULT_APPROVAL_STATE: SessionApprovalState = { + toolPatterns: {}, + approvedDirectories: [], +}; + +export class SessionApprovalStore { + private readonly cacheTtlSeconds: number; + + constructor( + private readonly storageManager: StorageManager, + private readonly logger: Logger, + options: { cacheTtlMs?: number } = {} + ) { + const cacheTtlMs = options.cacheTtlMs ?? 3600000; + this.cacheTtlSeconds = Math.max(1, Math.floor(cacheTtlMs / 1000)); + } + + private buildKey(sessionId?: string): string { + return sessionId ? `session-approvals:${sessionId}` : 'session-approvals:global'; + } + + async load(sessionId?: string): Promise { + const key = this.buildKey(sessionId); + const cached = await this.storageManager.getCache().get(key); + if (cached !== undefined) { + return this.parseState(cached, key); + } + + const stored = await this.storageManager.getDatabase().get(key); + if (stored === undefined) { + return structuredClone(DEFAULT_APPROVAL_STATE); + } + + const parsed = this.parseState(stored, key); + await this.storageManager.getCache().set(key, parsed, this.cacheTtlSeconds); + return parsed; + } + + async save(sessionId: string | undefined, state: SessionApprovalState): Promise { + const key = this.buildKey(sessionId); + const normalized = SessionApprovalStateSchema.parse(state); + await this.storageManager.getDatabase().set(key, normalized); + await this.storageManager.getCache().set(key, normalized, this.cacheTtlSeconds); + } + + async delete(sessionId?: string): Promise { + const key = this.buildKey(sessionId); + await Promise.all([ + this.storageManager.getDatabase().delete(key), + this.storageManager.getCache().delete(key), + ]); + } + + private parseState(value: unknown, key: string): SessionApprovalState { + const result = SessionApprovalStateSchema.safeParse(value); + if (result.success) { + return result.data; + } + + this.logger.warn('Invalid persisted approval state encountered; using defaults', { + key, + error: result.error.message, + }); + return structuredClone(DEFAULT_APPROVAL_STATE); + } +} diff --git a/packages/core/src/llm/executor/turn-executor.integration.test.ts b/packages/core/src/llm/executor/turn-executor.integration.test.ts index bfc3fdc9d..86b2af60c 100644 --- a/packages/core/src/llm/executor/turn-executor.integration.test.ts +++ b/packages/core/src/llm/executor/turn-executor.integration.test.ts @@ -19,6 +19,11 @@ import type { LanguageModel, ModelMessage } from 'ai'; import type { LLMContext } from '../types.js'; import type { ValidatedLLMConfig } from '../schemas.js'; import type { Logger } from '../../logger/v2/types.js'; +import { + createInMemoryMessageQueueStore, + createInMemorySessionApprovalStore, + createInMemorySessionToolPreferencesStore, +} from '../../test-utils/session-state-stores.js'; // Only mock the AI SDK's streamText/generateText - everything else is real vi.mock('ai', async (importOriginal) => { @@ -206,7 +211,8 @@ describe('TurnExecutor Integration Tests', () => { permissions: { mode: 'auto-approve', timeout: 120000 }, elicitation: { enabled: false, timeout: 120000 }, }, - logger + logger, + createInMemorySessionApprovalStore(logger) ); // Create real tool manager (minimal setup - no internal tools) @@ -224,12 +230,18 @@ describe('TurnExecutor Integration Tests', () => { agentEventBus, { alwaysAllow: [], alwaysDeny: [] }, [], - logger + logger, + createInMemorySessionToolPreferencesStore(logger) ); await toolManager.initialize(); // Create real message queue - messageQueue = new MessageQueueService(sessionEventBus, logger); + messageQueue = new MessageQueueService( + sessionEventBus, + logger, + sessionId, + createInMemoryMessageQueueStore() + ); // Default streamText mock - simple text response vi.mocked(streamText).mockImplementation( @@ -392,7 +404,7 @@ describe('TurnExecutor Integration Tests', () => { describe('Message Queue Injection', () => { it('should inject queued messages into context', async () => { - messageQueue.enqueue({ + await messageQueue.enqueue({ content: [{ type: 'text', text: 'User guidance: focus on performance' }], }); @@ -415,13 +427,21 @@ describe('TurnExecutor Integration Tests', () => { vi.mocked(streamText).mockImplementation(() => { callCount++; if (callCount === 1) { - messageQueue.enqueue({ + const queuedFollowUp = messageQueue.enqueue({ content: [{ type: 'text', text: 'Follow-up question' }], }); - return createMockStream({ + const firstStream = createMockStream({ text: 'First response', finishReason: 'stop', - }) as unknown as ReturnType; + }); + return { + fullStream: (async function* () { + await queuedFollowUp; + for await (const event of firstStream.fullStream) { + yield event; + } + })(), + } as unknown as ReturnType; } return createMockStream({ text: 'Second response', @@ -469,7 +489,12 @@ describe('TurnExecutor Integration Tests', () => { expect(generateText).toHaveBeenCalledTimes(1); // Second executor with same baseURL should use cache - const newMessageQueue = new MessageQueueService(sessionEventBus, logger); + const newMessageQueue = new MessageQueueService( + sessionEventBus, + logger, + 'session-2', + createInMemoryMessageQueueStore() + ); const executor2 = new TurnExecutor( createMockModel(), toolManager, @@ -668,16 +693,16 @@ describe('TurnExecutor Integration Tests', () => { describe('Cleanup and Resource Management', () => { it('should clear message queue on normal completion', async () => { - messageQueue.enqueue({ content: [{ type: 'text', text: 'Pending' }] }); + await messageQueue.enqueue({ content: [{ type: 'text', text: 'Pending' }] }); await contextManager.addUserMessage([{ type: 'text', text: 'Hello' }]); await executor.execute({ mcpManager }, true); - expect(messageQueue.dequeueAll()).toBeNull(); + await expect(messageQueue.dequeueAll()).resolves.toBeNull(); }); it('should clear message queue on error', async () => { - messageQueue.enqueue({ content: [{ type: 'text', text: 'Pending' }] }); + await messageQueue.enqueue({ content: [{ type: 'text', text: 'Pending' }] }); vi.mocked(streamText).mockImplementation(() => { throw new Error('Failed'); @@ -686,7 +711,7 @@ describe('TurnExecutor Integration Tests', () => { await contextManager.addUserMessage([{ type: 'text', text: 'Hello' }]); await expect(executor.execute({ mcpManager }, true)).rejects.toThrow(); - expect(messageQueue.dequeueAll()).toBeNull(); + await expect(messageQueue.dequeueAll()).resolves.toBeNull(); }); }); diff --git a/packages/core/src/llm/executor/turn-executor.ts b/packages/core/src/llm/executor/turn-executor.ts index c4c1920e9..96ca0a96f 100644 --- a/packages/core/src/llm/executor/turn-executor.ts +++ b/packages/core/src/llm/executor/turn-executor.ts @@ -224,7 +224,7 @@ export class TurnExecutor { } // 1. Check for queued messages (mid-loop injection) - const coalesced = this.messageQueue.dequeueAll(); + const coalesced = await this.messageQueue.dequeueAll(); if (coalesced) { await this.injectQueuedMessages(coalesced); } @@ -435,7 +435,7 @@ export class TurnExecutor { // Check queue before terminating - process queued messages if any // Note: Hard cancel clears the queue BEFORE aborting, so if messages exist // here it means soft cancel - we should continue processing them - const queuedOnTerminate = this.messageQueue.dequeueAll(); + const queuedOnTerminate = await this.messageQueue.dequeueAll(); if (queuedOnTerminate) { this.logger.debug( `Continuing: ${queuedOnTerminate.messages.length} queued message(s) to process` @@ -1022,7 +1022,13 @@ export class TurnExecutor { } // Clear any pending queued messages - this.messageQueue.clear(); + void this.messageQueue.clear().catch((error) => { + this.logger.warn( + `Failed to clear queued follow-up messages during cleanup: ${ + error instanceof Error ? error.message : String(error) + }` + ); + }); } /** diff --git a/packages/core/src/llm/services/factory.ts b/packages/core/src/llm/services/factory.ts index 2a110163d..efd066b7b 100644 --- a/packages/core/src/llm/services/factory.ts +++ b/packages/core/src/llm/services/factory.ts @@ -356,10 +356,10 @@ export function createLLMService( sessionId: string, resourceManager: import('../../resources/index.js').ResourceManager, logger: Logger, - options: CreateLLMServiceOptions = {}, + options: CreateLLMServiceOptions, languageModelFactory?: LanguageModelFactory ): VercelLLMService { - const { usageScopeId, compactionStrategy } = options; + const { usageScopeId, compactionStrategy, messageQueue } = options; const providerContext: DextoProviderContext = { sessionId, @@ -390,6 +390,7 @@ export function createLLMService( sessionId, resourceManager, logger, + messageQueue, usageScopeId, compactionStrategy ); diff --git a/packages/core/src/llm/services/types.ts b/packages/core/src/llm/services/types.ts index 035c23973..43c7d2062 100644 --- a/packages/core/src/llm/services/types.ts +++ b/packages/core/src/llm/services/types.ts @@ -3,6 +3,7 @@ import type { LanguageModel } from 'ai'; import type { CodexRateLimitSnapshot } from '../providers/codex-app-server.js'; import type { ValidatedLLMConfig } from '../schemas.js'; import type { LLMProvider } from '../types.js'; +import type { MessageQueueService } from '../../session/message-queue.js'; /** * Configuration object returned by the default session LLM service. @@ -18,6 +19,7 @@ export interface CreateLLMServiceOptions { usageScopeId?: string | undefined; compactionStrategy?: CompactionStrategy | null | undefined; cwd?: string | undefined; + messageQueue: MessageQueueService; } /** diff --git a/packages/core/src/llm/services/vercel.ts b/packages/core/src/llm/services/vercel.ts index 02f12b2e5..d99284ce4 100644 --- a/packages/core/src/llm/services/vercel.ts +++ b/packages/core/src/llm/services/vercel.ts @@ -73,6 +73,7 @@ export class VercelLLMService { sessionId: string, resourceManager: ResourceManager, logger: Logger, + messageQueue: MessageQueueService, usageScopeId?: string, compactionStrategy?: import('../../context/compaction/types.js').CompactionStrategy | null ) { @@ -86,8 +87,7 @@ export class VercelLLMService { this.usageScopeId = usageScopeId; this.compactionStrategy = compactionStrategy ?? null; - // Create session-level message queue for mid-task user messages - this.messageQueue = new MessageQueueService(this.sessionEventBus, this.logger); + this.messageQueue = messageQueue; // Create properly-typed ContextManager for Vercel const formatter = new VercelMessageFormatter(this.logger); diff --git a/packages/core/src/session/chat-session.test.ts b/packages/core/src/session/chat-session.test.ts index 0764e224a..8ef88b507 100644 --- a/packages/core/src/session/chat-session.test.ts +++ b/packages/core/src/session/chat-session.test.ts @@ -163,6 +163,11 @@ describe('ChatSession', () => { toolManager: { getAllTools: vi.fn().mockReturnValue([]), }, + messageQueueStore: { + load: vi.fn().mockResolvedValue([]), + save: vi.fn().mockResolvedValue(undefined), + delete: vi.fn().mockResolvedValue(undefined), + }, hookManager: { executeHooks: vi.fn().mockImplementation(async (_point, payload) => payload), cleanup: vi.fn(), @@ -232,6 +237,7 @@ describe('ChatSession', () => { usageScopeId: undefined, compactionStrategy: null, cwd: '/tmp/dexto-cloud', + messageQueue: expect.any(Object), }), undefined ); @@ -255,10 +261,11 @@ describe('ChatSession', () => { sessionId, mockServices.resourceManager, expect.any(Object), - { + expect.objectContaining({ usageScopeId: undefined, compactionStrategy: null, - }, + messageQueue: expect.any(Object), + }), languageModelFactory ); expect(chatSession.getLLMService()).toBe(mockLLMService); @@ -341,10 +348,11 @@ describe('ChatSession', () => { sessionId, mockServices.resourceManager, expect.any(Object), - { + expect.objectContaining({ usageScopeId: undefined, compactionStrategy: null, - }, + messageQueue: expect.any(Object), + }), undefined ); }); @@ -371,10 +379,11 @@ describe('ChatSession', () => { sessionId, mockServices.resourceManager, expect.any(Object), - { + expect.objectContaining({ usageScopeId: undefined, compactionStrategy: null, - }, + messageQueue: expect.any(Object), + }), undefined ); }); @@ -492,10 +501,11 @@ describe('ChatSession', () => { sessionId, mockServices.resourceManager, // ResourceManager parameter expect.any(Object), // Logger parameter - { + expect.objectContaining({ usageScopeId: undefined, compactionStrategy: null, - }, + messageQueue: expect.any(Object), + }), undefined ); diff --git a/packages/core/src/session/chat-session.ts b/packages/core/src/session/chat-session.ts index a6352bda5..9c8bdcb5f 100644 --- a/packages/core/src/session/chat-session.ts +++ b/packages/core/src/session/chat-session.ts @@ -24,7 +24,8 @@ import { DextoLogComponent } from '../logger/v2/types.js'; import { DextoRuntimeError, ErrorScope, ErrorType } from '../errors/index.js'; import { HookErrorCode } from '../hooks/error-codes.js'; import type { InternalMessage, ContentPart } from '../context/types.js'; -import type { UserMessageInput } from './message-queue.js'; +import { MessageQueueService, type UserMessageInput } from './message-queue.js'; +import type { MessageQueueStore } from './message-queue-store.js'; import type { ContentInput } from '../agent/types.js'; import { getUsagePricingMetadata, hasMeaningfulTokenUsage } from '../llm/usage-metadata.js'; import type { CompactionStrategy } from '../context/compaction/types.js'; @@ -103,6 +104,12 @@ export class ChatSession { */ private llmService!: VercelLLMService; + /** + * Durable queued follow-up messages for this session. + * Reused across LLM switches so mid-task follow-ups survive service recreation. + */ + private messageQueue!: MessageQueueService; + /** * Map of event forwarder functions for cleanup. * Stores the bound functions so they can be removed from the event bus. @@ -146,6 +153,7 @@ export class ChatSession { hookManager: HookManager; mcpManager: MCPManager; sessionManager: import('./session-manager.js').SessionManager; + messageQueueStore: Pick; languageModelFactory?: LanguageModelFactory; workspaceManager?: import('../workspace/manager.js').WorkspaceManager; compactionStrategy: CompactionStrategy | null; @@ -156,6 +164,12 @@ export class ChatSession { this.logger = logger.createChild(DextoLogComponent.SESSION); // Create session-specific event bus this.eventBus = new SessionEventBus(); + this.messageQueue = new MessageQueueService( + this.eventBus, + this.logger, + this.id, + this.services.messageQueueStore + ); // Set up event forwarding to agent's global bus this.setupEventForwarding(); @@ -271,6 +285,8 @@ export class ChatSession { const runtimeConfig = this.services.stateManager.getRuntimeConfig(this.id); const llmConfig = runtimeConfig.llm; + await this.messageQueue.initialize(); + // Create session-specific history provider directly with database backend // This persists across LLM switches to maintain conversation history this.historyProvider = createDatabaseHistoryProvider( @@ -293,6 +309,7 @@ export class ChatSession { usageScopeId, compactionStrategy: this.services.compactionStrategy, ...(workspace?.path !== undefined && { cwd: workspace.path }), + messageQueue: this.messageQueue, }; return createLLMService( @@ -769,8 +786,10 @@ export class ChatSession { * @param message The user message to queue * @returns Queue position and message ID */ - public queueMessage(message: UserMessageInput): { queued: true; position: number; id: string } { - return this.llmService.getMessageQueue().enqueue(message); + public async queueMessage( + message: UserMessageInput + ): Promise<{ queued: true; position: number; id: string }> { + return await this.llmService.getMessageQueue().enqueue(message); } /** @@ -786,18 +805,18 @@ export class ChatSession { * @param id Message ID to remove * @returns true if message was found and removed; false otherwise */ - public removeQueuedMessage(id: string): boolean { - return this.llmService.getMessageQueue().remove(id); + public async removeQueuedMessage(id: string): Promise { + return await this.llmService.getMessageQueue().remove(id); } /** * Clear all queued messages. * @returns Number of messages that were cleared */ - public clearMessageQueue(): number { + public async clearMessageQueue(): Promise { const queue = this.llmService.getMessageQueue(); const count = queue.pendingCount(); - queue.clear(); + await queue.clear(); return count; } diff --git a/packages/core/src/session/message-queue-store.ts b/packages/core/src/session/message-queue-store.ts new file mode 100644 index 000000000..4a9c98b84 --- /dev/null +++ b/packages/core/src/session/message-queue-store.ts @@ -0,0 +1,62 @@ +import type { Logger } from '../logger/v2/types.js'; +import type { StorageManager } from '../storage/index.js'; +import type { QueuedMessage } from './types.js'; + +export class MessageQueueStore { + private readonly cacheTtlSeconds: number; + + constructor( + private readonly storageManager: StorageManager, + private readonly logger: Logger, + options: { cacheTtlMs?: number } = {} + ) { + const cacheTtlMs = options.cacheTtlMs ?? 3600000; + this.cacheTtlSeconds = Math.max(1, Math.floor(cacheTtlMs / 1000)); + } + + private buildKey(sessionId: string): string { + return `session-message-queue:${sessionId}`; + } + + async load(sessionId: string): Promise { + const key = this.buildKey(sessionId); + const cached = await this.storageManager.getCache().get(key); + if (Array.isArray(cached)) { + return structuredClone(cached); + } + + const stored = await this.storageManager.getDatabase().get(key); + if (!Array.isArray(stored)) { + if (stored !== undefined) { + this.logger.warn('Invalid persisted message queue encountered; ignoring state', { + key, + }); + } + return []; + } + + const cloned = structuredClone(stored); + await this.storageManager.getCache().set(key, cloned, this.cacheTtlSeconds); + return cloned; + } + + async save(sessionId: string, queue: QueuedMessage[]): Promise { + const key = this.buildKey(sessionId); + if (queue.length === 0) { + await this.delete(sessionId); + return; + } + + const cloned = structuredClone(queue); + await this.storageManager.getDatabase().set(key, cloned); + await this.storageManager.getCache().set(key, cloned, this.cacheTtlSeconds); + } + + async delete(sessionId: string): Promise { + const key = this.buildKey(sessionId); + await Promise.all([ + this.storageManager.getDatabase().delete(key), + this.storageManager.getCache().delete(key), + ]); + } +} diff --git a/packages/core/src/session/message-queue.test.ts b/packages/core/src/session/message-queue.test.ts index a72f7fc04..6d51ba6e1 100644 --- a/packages/core/src/session/message-queue.test.ts +++ b/packages/core/src/session/message-queue.test.ts @@ -4,6 +4,17 @@ import type { SessionEventBus } from '../events/index.js'; import type { ContentPart } from '../context/types.js'; import { createMockLogger } from '../logger/v2/test-utils.js'; import type { Logger } from '../logger/v2/types.js'; +import { createInMemoryMessageQueueStore } from '../test-utils/session-state-stores.js'; + +function createDeferred() { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject }; +} // Create a mock SessionEventBus function createMockEventBus(): SessionEventBus { @@ -24,36 +35,41 @@ describe('MessageQueueService', () => { beforeEach(() => { eventBus = createMockEventBus(); logger = createMockLogger(); - queue = new MessageQueueService(eventBus, logger); + queue = new MessageQueueService( + eventBus, + logger, + 'session-1', + createInMemoryMessageQueueStore() + ); }); describe('enqueue()', () => { - it('should add a message to the queue and return position and id', () => { + it('should add a message to the queue and return position and id', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; - const result = queue.enqueue({ content }); + const result = await queue.enqueue({ content }); expect(result.queued).toBe(true); expect(result.position).toBe(1); expect(result.id).toMatch(/^msg_\d+_[a-z0-9]+$/); }); - it('should increment position for multiple enqueued messages', () => { + it('should increment position for multiple enqueued messages', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; - const result1 = queue.enqueue({ content }); - const result2 = queue.enqueue({ content }); - const result3 = queue.enqueue({ content }); + const result1 = await queue.enqueue({ content }); + const result2 = await queue.enqueue({ content }); + const result3 = await queue.enqueue({ content }); expect(result1.position).toBe(1); expect(result2.position).toBe(2); expect(result3.position).toBe(3); }); - it('should emit message:queued event with correct data', () => { + it('should emit message:queued event with correct data', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; - const result = queue.enqueue({ content }); + const result = await queue.enqueue({ content }); expect(eventBus.emit).toHaveBeenCalledWith('message:queued', { position: 1, @@ -61,23 +77,23 @@ describe('MessageQueueService', () => { }); }); - it('should include metadata when provided', () => { + it('should include metadata when provided', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; const metadata = { source: 'api', priority: 'high' }; - queue.enqueue({ content, metadata }); - const coalesced = queue.dequeueAll(); + await queue.enqueue({ content, metadata }); + const coalesced = await queue.dequeueAll(); expect(coalesced).not.toBeNull(); const firstMessage = coalesced?.messages[0]; expect(firstMessage?.metadata).toEqual(metadata); }); - it('should not include metadata field when not provided', () => { + it('should not include metadata field when not provided', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; - queue.enqueue({ content }); - const coalesced = queue.dequeueAll(); + await queue.enqueue({ content }); + const coalesced = await queue.dequeueAll(); expect(coalesced).not.toBeNull(); const firstMessage = coalesced?.messages[0]; @@ -86,35 +102,35 @@ describe('MessageQueueService', () => { }); describe('dequeueAll()', () => { - it('should return null when queue is empty', () => { - const result = queue.dequeueAll(); + it('should return null when queue is empty', async () => { + const result = await queue.dequeueAll(); expect(result).toBeNull(); }); - it('should return CoalescedMessage with single message', () => { + it('should return CoalescedMessage with single message', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; - queue.enqueue({ content }); + await queue.enqueue({ content }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result).not.toBeNull(); expect(result?.messages).toHaveLength(1); expect(result?.combinedContent).toEqual(content); }); - it('should clear the queue after dequeue', () => { - queue.enqueue({ content: [{ type: 'text', text: 'hello' }] }); - queue.dequeueAll(); + it('should clear the queue after dequeue', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'hello' }] }); + await queue.dequeueAll(); expect(queue.hasPending()).toBe(false); expect(queue.pendingCount()).toBe(0); }); - it('should emit message:dequeued event with correct data', () => { - queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); - queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); + it('should emit message:dequeued event with correct data', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); - queue.dequeueAll(); + await queue.dequeueAll(); expect(eventBus.emit).toHaveBeenCalledWith('message:dequeued', { count: 2, @@ -125,10 +141,10 @@ describe('MessageQueueService', () => { }); }); - it('should set coalesced to false for single message', () => { - queue.enqueue({ content: [{ type: 'text', text: 'solo' }] }); + it('should set coalesced to false for single message', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'solo' }] }); - queue.dequeueAll(); + await queue.dequeueAll(); expect(eventBus.emit).toHaveBeenCalledWith('message:dequeued', { count: 1, @@ -141,23 +157,23 @@ describe('MessageQueueService', () => { }); describe('coalescing', () => { - it('should return single message content as-is', () => { + it('should return single message content as-is', async () => { const content: ContentPart[] = [ { type: 'text', text: 'hello world' }, { type: 'image', image: 'base64data', mimeType: 'image/png' }, ]; - queue.enqueue({ content }); + await queue.enqueue({ content }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result?.combinedContent).toEqual(content); }); - it('should prefix two messages with First and Also', () => { - queue.enqueue({ content: [{ type: 'text', text: 'stop' }] }); - queue.enqueue({ content: [{ type: 'text', text: 'try another way' }] }); + it('should prefix two messages with First and Also', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'stop' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'try another way' }] }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result?.combinedContent).toHaveLength(3); // First + separator + Also expect(result?.combinedContent[0]).toEqual({ type: 'text', text: 'First: stop' }); @@ -168,12 +184,12 @@ describe('MessageQueueService', () => { }); }); - it('should number three or more messages', () => { - queue.enqueue({ content: [{ type: 'text', text: 'one' }] }); - queue.enqueue({ content: [{ type: 'text', text: 'two' }] }); - queue.enqueue({ content: [{ type: 'text', text: 'three' }] }); + it('should number three or more messages', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'one' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'two' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'three' }] }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result?.combinedContent).toHaveLength(5); // 3 messages + 2 separators expect(result?.combinedContent[0]).toEqual({ type: 'text', text: '[1]: one' }); @@ -181,18 +197,18 @@ describe('MessageQueueService', () => { expect(result?.combinedContent[4]).toEqual({ type: 'text', text: '[3]: three' }); }); - it('should preserve multimodal content (text + images)', () => { - queue.enqueue({ + it('should preserve multimodal content (text + images)', async () => { + await queue.enqueue({ content: [ { type: 'text', text: 'look at this' }, { type: 'image', image: 'base64img1', mimeType: 'image/png' }, ], }); - queue.enqueue({ + await queue.enqueue({ content: [{ type: 'image', image: 'base64img2', mimeType: 'image/jpeg' }], }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); // Should have: "First: look at this", image1, separator, "Also: ", image2 expect(result?.combinedContent).toHaveLength(5); @@ -213,14 +229,14 @@ describe('MessageQueueService', () => { }); }); - it('should tag user messages in mixed batches', () => { - queue.enqueue({ content: [{ type: 'text', text: 'user note' }] }); - queue.enqueue({ + it('should tag user messages in mixed batches', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'user note' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'bg payload' }], kind: 'background', }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result?.combinedContent[0]).toEqual({ type: 'text', @@ -229,11 +245,11 @@ describe('MessageQueueService', () => { expect(result?.combinedContent[2]).toEqual({ type: 'text', text: 'bg payload' }); }); - it('should handle empty message content with placeholder', () => { - queue.enqueue({ content: [{ type: 'text', text: 'first' }] }); - queue.enqueue({ content: [] }); + it('should handle empty message content with placeholder', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'first' }] }); + await queue.enqueue({ content: [] }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result?.combinedContent).toContainEqual({ type: 'text', @@ -242,14 +258,14 @@ describe('MessageQueueService', () => { }); it('should set correct firstQueuedAt and lastQueuedAt timestamps', async () => { - queue.enqueue({ content: [{ type: 'text', text: 'first' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'first' }] }); // Small delay to ensure different timestamps await new Promise((resolve) => setTimeout(resolve, 10)); - queue.enqueue({ content: [{ type: 'text', text: 'second' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'second' }] }); - const result = queue.dequeueAll(); + const result = await queue.dequeueAll(); expect(result?.firstQueuedAt).toBeLessThan(result?.lastQueuedAt ?? 0); }); @@ -261,35 +277,83 @@ describe('MessageQueueService', () => { expect(queue.pendingCount()).toBe(0); }); - it('should return true and correct count for non-empty queue', () => { - queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); - queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); + it('should return true and correct count for non-empty queue', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); expect(queue.hasPending()).toBe(true); expect(queue.pendingCount()).toBe(2); }); - it('should update after dequeue', () => { - queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); + it('should update after dequeue', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); expect(queue.pendingCount()).toBe(1); - queue.dequeueAll(); + await queue.dequeueAll(); expect(queue.hasPending()).toBe(false); expect(queue.pendingCount()).toBe(0); }); }); + describe('initialize()', () => { + it('should serialize initialization with concurrent queued mutations', async () => { + const loadStarted = createDeferred(); + const releaseLoad = + createDeferred>(); + const savedQueues: Array< + Array<{ id: string; content: ContentPart[]; queuedAt: number }> + > = []; + const serializedQueue = new MessageQueueService(eventBus, logger, 'session-2', { + load: vi.fn().mockImplementation(async () => { + loadStarted.resolve(); + return await releaseLoad.promise; + }), + save: vi.fn().mockImplementation(async (_sessionId, nextQueue) => { + savedQueues.push(structuredClone(nextQueue)); + }), + delete: vi.fn().mockResolvedValue(undefined), + }); + + const initializePromise = serializedQueue.initialize(); + await loadStarted.promise; + + const enqueuePromise = serializedQueue.enqueue({ + content: [{ type: 'text', text: 'new follow-up' }], + }); + + releaseLoad.resolve([ + { + id: 'restored-message', + content: [{ type: 'text', text: 'restored follow-up' }], + queuedAt: 1, + }, + ]); + + await initializePromise; + const enqueued = await enqueuePromise; + + expect(serializedQueue.getAll().map((message) => message.id)).toEqual([ + 'restored-message', + enqueued.id, + ]); + expect(savedQueues.at(-1)?.map((message) => message.id)).toEqual([ + 'restored-message', + enqueued.id, + ]); + }); + }); + describe('clear()', () => { - it('should empty the queue', () => { - queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); - queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); + it('should empty the queue', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); + await queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); - queue.clear(); + await queue.clear(); expect(queue.hasPending()).toBe(false); expect(queue.pendingCount()).toBe(0); - expect(queue.dequeueAll()).toBeNull(); + await expect(queue.dequeueAll()).resolves.toBeNull(); }); }); @@ -298,9 +362,9 @@ describe('MessageQueueService', () => { expect(queue.getAll()).toEqual([]); }); - it('should return shallow copy of queued messages', () => { - const result1 = queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); - const result2 = queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); + it('should return shallow copy of queued messages', async () => { + const result1 = await queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); + const result2 = await queue.enqueue({ content: [{ type: 'text', text: 'msg2' }] }); const all = queue.getAll(); @@ -309,8 +373,8 @@ describe('MessageQueueService', () => { expect(all[1]?.id).toBe(result2.id); }); - it('should not allow external mutation of queue', () => { - queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); + it('should not allow external mutation of queue', async () => { + await queue.enqueue({ content: [{ type: 'text', text: 'msg1' }] }); const all = queue.getAll(); all.push({ @@ -328,9 +392,9 @@ describe('MessageQueueService', () => { expect(queue.get('non-existent')).toBeUndefined(); }); - it('should return message by id', () => { + it('should return message by id', async () => { const content: ContentPart[] = [{ type: 'text', text: 'hello' }]; - const result = queue.enqueue({ content }); + const result = await queue.enqueue({ content }); const msg = queue.get(result.id); @@ -341,8 +405,8 @@ describe('MessageQueueService', () => { }); describe('remove()', () => { - it('should return false for non-existent id', () => { - const result = queue.remove('non-existent'); + it('should return false for non-existent id', async () => { + const result = await queue.remove('non-existent'); expect(result).toBe(false); expect(logger.debug).toHaveBeenCalledWith( @@ -350,42 +414,42 @@ describe('MessageQueueService', () => { ); }); - it('should remove message and return true', () => { - const result = queue.enqueue({ content: [{ type: 'text', text: 'to remove' }] }); + it('should remove message and return true', async () => { + const result = await queue.enqueue({ content: [{ type: 'text', text: 'to remove' }] }); - const removed = queue.remove(result.id); + const removed = await queue.remove(result.id); expect(removed).toBe(true); expect(queue.get(result.id)).toBeUndefined(); expect(queue.pendingCount()).toBe(0); }); - it('should emit message:removed event', () => { - const result = queue.enqueue({ content: [{ type: 'text', text: 'to remove' }] }); + it('should emit message:removed event', async () => { + const result = await queue.enqueue({ content: [{ type: 'text', text: 'to remove' }] }); - queue.remove(result.id); + await queue.remove(result.id); expect(eventBus.emit).toHaveBeenCalledWith('message:removed', { id: result.id, }); }); - it('should log debug message on successful removal', () => { - const result = queue.enqueue({ content: [{ type: 'text', text: 'to remove' }] }); + it('should log debug message on successful removal', async () => { + const result = await queue.enqueue({ content: [{ type: 'text', text: 'to remove' }] }); - queue.remove(result.id); + await queue.remove(result.id); expect(logger.debug).toHaveBeenCalledWith( `Message removed: ${result.id}, remaining: 0` ); }); - it('should maintain order of remaining messages', () => { - const r1 = queue.enqueue({ content: [{ type: 'text', text: 'first' }] }); - const r2 = queue.enqueue({ content: [{ type: 'text', text: 'second' }] }); - const r3 = queue.enqueue({ content: [{ type: 'text', text: 'third' }] }); + it('should maintain order of remaining messages', async () => { + const r1 = await queue.enqueue({ content: [{ type: 'text', text: 'first' }] }); + const r2 = await queue.enqueue({ content: [{ type: 'text', text: 'second' }] }); + const r3 = await queue.enqueue({ content: [{ type: 'text', text: 'third' }] }); - queue.remove(r2.id); + await queue.remove(r2.id); const all = queue.getAll(); expect(all).toHaveLength(2); diff --git a/packages/core/src/session/message-queue.ts b/packages/core/src/session/message-queue.ts index 7e333ee25..20db4453d 100644 --- a/packages/core/src/session/message-queue.ts +++ b/packages/core/src/session/message-queue.ts @@ -2,6 +2,25 @@ import type { SessionEventBus } from '../events/index.js'; import type { QueuedMessage, CoalescedMessage } from './types.js'; import type { ContentPart } from '../context/types.js'; import type { Logger } from '../logger/v2/types.js'; +import type { MessageQueueStore } from './message-queue-store.js'; + +type MessageQueueBackingStore = Pick; + +class EphemeralMessageQueueStore implements MessageQueueBackingStore { + async load(sessionId: string): Promise { + void sessionId; + return []; + } + + async save(sessionId: string, queue: QueuedMessage[]): Promise { + void sessionId; + void queue; + } + + async delete(sessionId: string): Promise { + void sessionId; + } +} /** * Generates a unique ID for queued messages. @@ -54,12 +73,65 @@ export interface UserMessageInput { */ export class MessageQueueService { private queue: QueuedMessage[] = []; + private mutationLock: Promise = Promise.resolve(); + private initialized = false; + private initializationPromise: Promise | null = null; + + static createEphemeral( + eventBus: SessionEventBus, + logger: Logger, + sessionId: string + ): MessageQueueService { + return new MessageQueueService( + eventBus, + logger, + sessionId, + new EphemeralMessageQueueStore() + ); + } constructor( private eventBus: SessionEventBus, - private logger: Logger + private logger: Logger, + private sessionId: string, + private store: MessageQueueBackingStore ) {} + async initialize(): Promise { + this.initializationPromise ??= this.runWithMutationLock(async () => { + if (this.initialized) { + return; + } + + this.queue = await this.store.load(this.sessionId); + if (this.queue.length > 0) { + this.logger.debug( + `Restored ${this.queue.length} queued message(s) for session ${this.sessionId}` + ); + } + + this.initialized = true; + }).catch((error) => { + this.initializationPromise = null; + throw error; + }); + + await this.initializationPromise; + } + + private async persistQueue(): Promise { + await this.store.save(this.sessionId, this.queue); + } + + private runWithMutationLock(fn: () => Promise): Promise { + const currentResult = this.mutationLock.catch(() => {}).then(() => fn()); + this.mutationLock = currentResult.then( + () => undefined, + () => undefined + ); + return currentResult; + } + /** * Add a message to the queue. * Called by API endpoint - returns immediately with queue position. @@ -67,29 +139,40 @@ export class MessageQueueService { * @param message The user message to queue * @returns Queue position and message ID */ - enqueue(message: UserMessageInput): { queued: true; position: number; id: string } { - const queuedMsg: QueuedMessage = { - id: generateId(), - content: message.content, - queuedAt: Date.now(), - ...(message.metadata !== undefined && { metadata: message.metadata }), - ...(message.kind !== undefined && { kind: message.kind }), - }; + async enqueue( + message: UserMessageInput + ): Promise<{ queued: true; position: number; id: string }> { + return await this.runWithMutationLock(async () => { + const queuedMsg: QueuedMessage = { + id: generateId(), + content: message.content, + queuedAt: Date.now(), + ...(message.metadata !== undefined && { metadata: message.metadata }), + ...(message.kind !== undefined && { kind: message.kind }), + }; - this.queue.push(queuedMsg); + this.queue.push(queuedMsg); - this.logger.debug(`Message queued: ${queuedMsg.id}, position: ${this.queue.length}`); + try { + await this.persistQueue(); + } catch (error) { + this.queue.pop(); + throw error; + } - this.eventBus.emit('message:queued', { - position: this.queue.length, - id: queuedMsg.id, - }); + this.logger.debug(`Message queued: ${queuedMsg.id}, position: ${this.queue.length}`); - return { - queued: true, - position: this.queue.length, - id: queuedMsg.id, - }; + this.eventBus.emit('message:queued', { + position: this.queue.length, + id: queuedMsg.id, + }); + + return { + queued: true, + position: this.queue.length, + id: queuedMsg.id, + }; + }); } /** @@ -111,27 +194,36 @@ export class MessageQueueService { * * @returns Coalesced message or null if queue is empty */ - dequeueAll(): CoalescedMessage | null { - if (this.queue.length === 0) return null; + async dequeueAll(): Promise { + return await this.runWithMutationLock(async () => { + if (this.queue.length === 0) return null; + + const messages = [...this.queue]; + this.queue = []; + + try { + await this.persistQueue(); + } catch (error) { + this.queue = messages; + throw error; + } - const messages = [...this.queue]; - this.queue = []; + const combined = this.coalesce(messages); - const combined = this.coalesce(messages); + this.logger.debug( + `Dequeued ${messages.length} message(s): ${messages.map((m) => m.id).join(', ')}` + ); - this.logger.debug( - `Dequeued ${messages.length} message(s): ${messages.map((m) => m.id).join(', ')}` - ); + this.eventBus.emit('message:dequeued', { + count: messages.length, + ids: messages.map((m) => m.id), + coalesced: messages.length > 1, + content: combined.combinedContent, + messages, + }); - this.eventBus.emit('message:dequeued', { - count: messages.length, - ids: messages.map((m) => m.id), - coalesced: messages.length > 1, - content: combined.combinedContent, - messages, + return combined; }); - - return combined; } /** @@ -259,8 +351,22 @@ export class MessageQueueService { * Clear all pending messages without processing. * Used during cleanup/abort. */ - clear(): void { - this.queue = []; + async clear(): Promise { + await this.runWithMutationLock(async () => { + if (this.queue.length === 0) { + return; + } + + const previousQueue = [...this.queue]; + this.queue = []; + + try { + await this.persistQueue(); + } catch (error) { + this.queue = previousQueue; + throw error; + } + }); } /** @@ -282,16 +388,28 @@ export class MessageQueueService { * Remove a single queued message by ID. * @returns true if message was found and removed; false otherwise */ - remove(id: string): boolean { - const index = this.queue.findIndex((m) => m.id === id); - if (index === -1) { - this.logger.debug(`Remove failed: message ${id} not found in queue`); - return false; - } + async remove(id: string): Promise { + return await this.runWithMutationLock(async () => { + const index = this.queue.findIndex((m) => m.id === id); + if (index === -1) { + this.logger.debug(`Remove failed: message ${id} not found in queue`); + return false; + } + + const [removed] = this.queue.splice(index, 1); + + try { + await this.persistQueue(); + } catch (error) { + if (removed) { + this.queue.splice(index, 0, removed); + } + throw error; + } - this.queue.splice(index, 1); - this.logger.debug(`Message removed: ${id}, remaining: ${this.queue.length}`); - this.eventBus.emit('message:removed', { id }); - return true; + this.logger.debug(`Message removed: ${id}, remaining: ${this.queue.length}`); + this.eventBus.emit('message:removed', { id }); + return true; + }); } } diff --git a/packages/core/src/session/session-manager.integration.test.ts b/packages/core/src/session/session-manager.integration.test.ts index 4589705a8..ffab42f3b 100644 --- a/packages/core/src/session/session-manager.integration.test.ts +++ b/packages/core/src/session/session-manager.integration.test.ts @@ -1,4 +1,7 @@ +import os from 'node:os'; +import path from 'node:path'; import { describe, test, expect, beforeEach, afterEach } from 'vitest'; +import { z } from 'zod'; import { DextoAgent } from '../agent/DextoAgent.js'; import type { AgentRuntimeSettings } from '../agent/runtime-config.js'; import { SystemPromptConfigSchema } from '../systemPrompt/schemas.js'; @@ -151,6 +154,21 @@ describe('Session Integration: Chat History Preservation', () => { expect(finalHistory![4]).toEqual(newMessage); }); + test('session LLM overrides stay visible after ending a session', async () => { + const sessionId = 'override-visible-after-end'; + + await agent.createSession(sessionId); + await agent.switchLLM({ model: 'gpt-5' }, sessionId); + + expect(agent.hasSessionLLMOverride(sessionId)).toBe(true); + expect(agent.getCurrentLLMConfig(sessionId).model).toBe('gpt-5'); + + await agent.endSession(sessionId); + + expect(agent.hasSessionLLMOverride(sessionId)).toBe(true); + expect(agent.getCurrentLLMConfig(sessionId).model).toBe('gpt-5'); + }); + test('full integration: explicit session deletion removes everything', async () => { const sessionId = 'deletion-test-session'; @@ -301,6 +319,326 @@ describe('Session Integration: Chat History Preservation', () => { // The core functionality (chat history preservation) is thoroughly tested above }); +describe('Session Integration: Core-owned Interaction State Persistence', () => { + let agents: DextoAgent[] = []; + + const baseSettings: AgentRuntimeSettings = { + systemPrompt: SystemPromptConfigSchema.parse('You are a helpful assistant.'), + llm: LLMConfigSchema.parse({ + provider: 'openai', + model: 'gpt-5-mini', + apiKey: 'test-key-123', + }), + agentId: 'interaction-state-test-agent', + mcpServers: ServersConfigSchema.parse({}), + sessions: SessionConfigSchema.parse({ + maxSessions: 10, + sessionTTL: 60000, + }), + permissions: PermissionsConfigSchema.parse({ + mode: 'auto-approve', + timeout: 120000, + }), + elicitation: ElicitationConfigSchema.parse({ + enabled: false, + timeout: 120000, + }), + resources: ResourcesConfigSchema.parse([]), + prompts: PromptsSchema.parse([]), + }; + + async function createAgentWithSharedStorage( + agentId: string, + storage: { + blob: ReturnType; + cache: ReturnType; + database: ReturnType; + } + ): Promise { + const loggerConfig = LoggerConfigSchema.parse({ + level: 'warn', + transports: [{ type: 'console', colorize: false }], + }); + const logger = createLogger({ config: loggerConfig, agentId }); + + const agent = new DextoAgent({ + ...baseSettings, + agentId, + logger, + storage, + tools: [ + { + id: 'allowed_tool', + description: 'Allowed tool', + inputSchema: z.object({}).strict(), + execute: async () => null, + }, + { + id: 'disabled_tool', + description: 'Disabled tool', + inputSchema: z.object({}).strict(), + execute: async () => null, + }, + ], + hooks: [], + }); + await agent.start(); + agents.push(agent); + return agent; + } + + afterEach(async () => { + for (const agent of [...agents].reverse()) { + if (agent.isStarted()) { + await agent.stop(); + } + } + agents = []; + }); + + test('restores queued messages, session overrides, approvals, and tool preferences after agent restart', async () => { + const originalOpenAiApiKey = process.env.OPENAI_API_KEY; + process.env.OPENAI_API_KEY = 'test-key-123'; + + try { + const sharedStorage = { + blob: createInMemoryBlobStore(), + cache: createInMemoryCache(), + database: createInMemoryDatabase(), + }; + const sessionId = 'persisted-interaction-session'; + const approvedDirectory = path.join(os.tmpdir(), 'dexto-persisted-approval'); + + const agent1 = await createAgentWithSharedStorage( + 'interaction-state-agent-1', + sharedStorage + ); + await agent1.createSession(sessionId); + await agent1.switchLLM({ model: 'gpt-5' }, sessionId); + await agent1.queueMessage(sessionId, { + content: [{ type: 'text', text: 'resume with plan B' }], + metadata: { source: 'integration-test' }, + }); + await agent1.setSessionAutoApproveTools(sessionId, ['allowed_tool']); + await agent1.setSessionDisabledTools(sessionId, ['disabled_tool']); + await agent1.services.approvalManager.addPattern('bash_exec', 'git *', sessionId); + await agent1.services.approvalManager.addApprovedDirectory( + approvedDirectory, + 'session', + sessionId + ); + + const persistedSession = await agent1.services.storageManager + .getDatabase() + .get(`session:${sessionId}`); + expect(persistedSession?.llmOverride).toEqual( + expect.objectContaining({ + provider: 'openai', + model: 'gpt-5', + }) + ); + + const persistedQueue = await agent1.services.storageManager + .getDatabase() + .get< + Array<{ content: Array<{ type: string; text?: string }> }> + >(`session-message-queue:${sessionId}`); + expect(persistedQueue).toHaveLength(1); + expect(persistedQueue?.[0]?.content).toEqual([ + { type: 'text', text: 'resume with plan B' }, + ]); + + expect( + await agent1.services.storageManager + .getDatabase() + .get(`session-tool-preferences:${sessionId}`) + ).toEqual({ + userAutoApproveTools: ['allowed_tool'], + disabledTools: ['disabled_tool'], + }); + + const persistedApprovals = await agent1.services.storageManager.getDatabase().get<{ + toolPatterns?: Record; + approvedDirectories?: Array<{ path: string; type: string }>; + }>(`session-approvals:${sessionId}`); + expect(persistedApprovals?.toolPatterns).toEqual({ + bash_exec: ['git *'], + }); + expect( + persistedApprovals?.approvedDirectories?.some( + (entry) => + entry.type === 'session' && + entry.path.endsWith(path.normalize('dexto-persisted-approval')) + ) + ).toBe(true); + + const agent2 = await createAgentWithSharedStorage( + 'interaction-state-agent-2', + sharedStorage + ); + + expect(agent2.services.stateManager.getLLMConfig(sessionId).model).toBe('gpt-5-mini'); + + const restoredSession = await agent2.getSession(sessionId); + expect(restoredSession).toBeDefined(); + expect(agent2.services.stateManager.getLLMConfig(sessionId).model).toBe('gpt-5'); + + const queuedMessages = await agent2.getQueuedMessages(sessionId); + expect(queuedMessages).toHaveLength(1); + expect(queuedMessages[0]?.content).toEqual([ + { type: 'text', text: 'resume with plan B' }, + ]); + + expect(await agent2.getSessionAutoApproveTools(sessionId)).toEqual(['allowed_tool']); + + const enabledTools = await agent2.getEnabledTools(sessionId); + expect(Object.keys(enabledTools)).toContain('allowed_tool'); + expect(Object.keys(enabledTools)).not.toContain('disabled_tool'); + + expect( + agent2.services.approvalManager.matchesPattern( + 'bash_exec', + 'git status *', + sessionId + ) + ).toBe(true); + expect( + agent2.services.approvalManager.isDirectorySessionApproved( + path.join(approvedDirectory, 'file.ts'), + sessionId + ) + ).toBe(true); + } finally { + if (originalOpenAiApiKey === undefined) { + delete process.env.OPENAI_API_KEY; + } else { + process.env.OPENAI_API_KEY = originalOpenAiApiKey; + } + } + }); + + test('drops persisted interaction state when startup cleanup purges an expired session', async () => { + const sharedStorage = { + blob: createInMemoryBlobStore(), + cache: createInMemoryCache(), + database: createInMemoryDatabase(), + }; + const sessionId = 'expired-persisted-interaction-session'; + const approvedDirectory = path.join(os.tmpdir(), 'dexto-expired-persisted-approval'); + + const agent1 = await createAgentWithSharedStorage('expired-state-agent-1', sharedStorage); + await agent1.createSession(sessionId); + await agent1.switchLLM({ model: 'gpt-5' }, sessionId); + await agent1.queueMessage(sessionId, { + content: [{ type: 'text', text: 'stale queued follow-up' }], + }); + await agent1.setSessionAutoApproveTools(sessionId, ['allowed_tool']); + await agent1.setSessionDisabledTools(sessionId, ['disabled_tool']); + await agent1.services.approvalManager.addPattern('bash_exec', 'git *', sessionId); + await agent1.services.approvalManager.addApprovedDirectory( + approvedDirectory, + 'session', + sessionId + ); + + const database = agent1.services.storageManager.getDatabase(); + const expiredSession = await database.get(`session:${sessionId}`); + if (!expiredSession) { + throw new Error(`Expected session '${sessionId}' to exist`); + } + + expiredSession.lastActivity = Date.now() - 120000; + await database.set(`session:${sessionId}`, expiredSession); + await agent1.stop(); + + const agent2 = await createAgentWithSharedStorage('expired-state-agent-2', sharedStorage); + + expect(await database.get(`session:${sessionId}`)).toBeUndefined(); + expect(await database.get(`session-message-queue:${sessionId}`)).toBeUndefined(); + expect(await database.get(`session-tool-preferences:${sessionId}`)).toBeUndefined(); + expect(await database.get(`session-approvals:${sessionId}`)).toBeUndefined(); + + await agent2.createSession(sessionId); + + expect(agent2.hasSessionLLMOverride(sessionId)).toBe(false); + expect(agent2.getCurrentLLMConfig(sessionId).model).toBe('gpt-5-mini'); + expect(await agent2.getQueuedMessages(sessionId)).toEqual([]); + expect(await agent2.getSessionAutoApproveTools(sessionId)).toEqual([]); + + const enabledTools = await agent2.getEnabledTools(sessionId); + expect(Object.keys(enabledTools)).toContain('allowed_tool'); + expect(Object.keys(enabledTools)).toContain('disabled_tool'); + + expect( + agent2.services.approvalManager.matchesPattern('bash_exec', 'git status *', sessionId) + ).toBe(false); + expect( + agent2.services.approvalManager.isDirectorySessionApproved( + path.join(approvedDirectory, 'file.ts'), + sessionId + ) + ).toBe(false); + }); + + test('newly created sessions do not inherit orphaned persisted interaction state', async () => { + const sharedStorage = { + blob: createInMemoryBlobStore(), + cache: createInMemoryCache(), + database: createInMemoryDatabase(), + }; + const sessionId = 'orphaned-interaction-session'; + const approvedDirectory = path.join(os.tmpdir(), 'dexto-orphaned-persisted-approval'); + + const agent1 = await createAgentWithSharedStorage('orphaned-state-agent-1', sharedStorage); + await agent1.createSession(sessionId); + await agent1.switchLLM({ model: 'gpt-5' }, sessionId); + await agent1.queueMessage(sessionId, { + content: [{ type: 'text', text: 'stale orphaned follow-up' }], + }); + await agent1.setSessionAutoApproveTools(sessionId, ['allowed_tool']); + await agent1.setSessionDisabledTools(sessionId, ['disabled_tool']); + await agent1.services.approvalManager.addPattern('bash_exec', 'git *', sessionId); + await agent1.services.approvalManager.addApprovedDirectory( + approvedDirectory, + 'session', + sessionId + ); + + await sharedStorage.database.delete(`session:${sessionId}`); + await agent1.stop(); + + const agent2 = await createAgentWithSharedStorage('orphaned-state-agent-2', sharedStorage); + await agent2.createSession(sessionId); + + expect(agent2.hasSessionLLMOverride(sessionId)).toBe(false); + expect(agent2.getCurrentLLMConfig(sessionId).model).toBe('gpt-5-mini'); + expect(await agent2.getQueuedMessages(sessionId)).toEqual([]); + expect(await agent2.getSessionAutoApproveTools(sessionId)).toEqual([]); + + const enabledTools = await agent2.getEnabledTools(sessionId); + expect(Object.keys(enabledTools)).toContain('allowed_tool'); + expect(Object.keys(enabledTools)).toContain('disabled_tool'); + + expect( + agent2.services.approvalManager.matchesPattern('bash_exec', 'git status *', sessionId) + ).toBe(false); + expect( + agent2.services.approvalManager.isDirectorySessionApproved( + path.join(approvedDirectory, 'file.ts'), + sessionId + ) + ).toBe(false); + + expect( + await sharedStorage.database.get(`session-message-queue:${sessionId}`) + ).toBeUndefined(); + expect( + await sharedStorage.database.get(`session-tool-preferences:${sessionId}`) + ).toBeUndefined(); + expect(await sharedStorage.database.get(`session-approvals:${sessionId}`)).toBeUndefined(); + }); +}); + describe('Session Integration: Multi-Model Token Tracking', () => { let agent: DextoAgent; diff --git a/packages/core/src/session/session-manager.test.ts b/packages/core/src/session/session-manager.test.ts index 7f0ac015d..fc2b5f07a 100644 --- a/packages/core/src/session/session-manager.test.ts +++ b/packages/core/src/session/session-manager.test.ts @@ -108,6 +108,8 @@ describe('SessionManager', () => { agentCard: { name: 'test-agent' }, })), updateLLM: vi.fn().mockReturnValue({ isValid: true, errors: [], warnings: [] }), + clearSessionOverride: vi.fn(), + hasSessionLLMOverride: vi.fn().mockReturnValue(false), }, systemPromptManager: { getSystemPrompt: vi.fn().mockReturnValue('System prompt'), @@ -126,11 +128,24 @@ describe('SessionManager', () => { }, toolManager: { getAllTools: vi.fn().mockReturnValue([]), + restoreSessionState: vi.fn().mockResolvedValue(undefined), + deleteSessionState: vi.fn().mockResolvedValue(undefined), + evictSessionState: vi.fn(), + }, + approvalManager: { + restoreSessionState: vi.fn().mockResolvedValue(undefined), + deleteSessionState: vi.fn().mockResolvedValue(undefined), + evictSessionState: vi.fn(), }, hookManager: { executeHooks: vi.fn().mockImplementation(async (_point, payload) => payload), cleanup: vi.fn(), }, + messageQueueStore: { + load: vi.fn().mockResolvedValue([]), + save: vi.fn().mockResolvedValue(undefined), + delete: vi.fn().mockResolvedValue(undefined), + }, }; // Parse LLM config now that mocks are set up @@ -160,6 +175,7 @@ describe('SessionManager', () => { run: vi.fn().mockResolvedValue('Mock response'), reset: vi.fn().mockResolvedValue(undefined), dispose: vi.fn(), + clearMessageQueue: vi.fn().mockResolvedValue(0), cleanup: vi.fn().mockImplementation(async () => { // Simulate the new cleanup behavior - only call dispose, not reset mockSession.dispose(); @@ -275,6 +291,17 @@ describe('SessionManager', () => { expect(mockStorageManager.database.delete).toHaveBeenCalledWith( 'session:expired-session' ); + expect(mockStorageManager.cache.delete).toHaveBeenCalledWith('session:expired-session'); + expect(mockServices.toolManager.deleteSessionState).toHaveBeenCalledWith( + 'expired-session' + ); + expect(mockServices.approvalManager.deleteSessionState).toHaveBeenCalledWith( + 'expired-session' + ); + expect(mockServices.messageQueueStore.delete).toHaveBeenCalledWith('expired-session'); + expect(mockServices.stateManager.clearSessionOverride).toHaveBeenCalledWith( + 'expired-session' + ); }); }); @@ -1091,6 +1118,23 @@ describe('SessionManager', () => { expect(session.cleanup).toHaveBeenCalled(); expect(mockStorageManager.database.delete).toHaveBeenCalledWith(`session:${sessionId}`); + expect(mockServices.stateManager.clearSessionOverride).toHaveBeenCalledWith(sessionId); + }); + + test('should evict restored interaction state when ending a session', async () => { + const sessionId = 'test-session'; + + await sessionManager.createSession(sessionId); + mockServices.toolManager.evictSessionState.mockClear(); + mockServices.approvalManager.evictSessionState.mockClear(); + mockServices.stateManager.clearSessionOverride.mockClear(); + await sessionManager.endSession(sessionId); + + expect(mockServices.toolManager.evictSessionState).toHaveBeenCalledWith(sessionId); + expect(mockServices.approvalManager.evictSessionState).toHaveBeenCalledWith(sessionId); + expect(mockServices.stateManager.clearSessionOverride).not.toHaveBeenCalledWith( + sessionId + ); }); test('should handle deleting non-existent sessions gracefully', async () => { @@ -1155,6 +1199,104 @@ describe('SessionManager', () => { expect(result.warnings).toEqual([]); }); + test('should not switch live session when persisting the override fails', async () => { + const sessionId = 'persist-failure-session'; + const newLLMConfig: ValidatedLLMConfig = { + ...mockLLMConfig, + provider: 'anthropic', + model: 'claude-4-opus-20250514', + }; + + const session = await sessionManager.createSession(sessionId); + + mockStorageManager.database.get.mockImplementation(async (key: string) => { + if (key === `session:${sessionId}`) { + return { + ...mockSessionData, + id: sessionId, + }; + } + return null; + }); + mockStorageManager.database.set.mockRejectedValue(new Error('Persist failed')); + + await expect( + sessionManager.switchLLMForSpecificSession(newLLMConfig, sessionId) + ).rejects.toThrow('Persist failed'); + + expect(mockServices.stateManager.updateLLM).not.toHaveBeenCalled(); + expect(session.switchLLM).not.toHaveBeenCalled(); + }); + + test('should roll back session overrides when switching the live session fails', async () => { + const sessionId = 'rollback-session'; + const previousLLMConfig: ValidatedLLMConfig = { + ...mockLLMConfig, + model: 'gpt-5-mini', + }; + const newLLMConfig: ValidatedLLMConfig = { + ...mockLLMConfig, + provider: 'anthropic', + model: 'claude-4-opus-20250514', + }; + + const session = await sessionManager.createSession(sessionId); + + mockServices.stateManager.getRuntimeConfig.mockImplementation( + (requestedSessionId?: string) => ({ + llm: requestedSessionId === sessionId ? previousLLMConfig : mockLLMConfig, + agentCard: { name: 'test-agent' }, + }) + ); + mockServices.stateManager.hasSessionLLMOverride.mockImplementation( + (requestedSessionId?: string) => requestedSessionId === sessionId + ); + + const persistedSessionData = { + ...mockSessionData, + id: sessionId, + llmOverride: { + provider: previousLLMConfig.provider, + model: previousLLMConfig.model, + maxIterations: previousLLMConfig.maxIterations, + maxInputTokens: previousLLMConfig.maxInputTokens, + }, + }; + + mockStorageManager.database.get.mockImplementation(async (key: string) => { + if (key === `session:${sessionId}`) { + return persistedSessionData; + } + return null; + }); + (session.switchLLM as ReturnType) + .mockRejectedValueOnce(new Error('Session switch failed')) + .mockResolvedValueOnce(undefined); + + await expect( + sessionManager.switchLLMForSpecificSession(newLLMConfig, sessionId) + ).rejects.toThrow('Session switch failed'); + + expect(mockServices.stateManager.updateLLM).toHaveBeenNthCalledWith( + 1, + newLLMConfig, + sessionId + ); + expect(mockServices.stateManager.updateLLM).toHaveBeenNthCalledWith( + 2, + previousLLMConfig, + sessionId + ); + expect(session.switchLLM).toHaveBeenNthCalledWith(1, newLLMConfig); + expect(session.switchLLM).toHaveBeenNthCalledWith(2, previousLLMConfig); + expect(persistedSessionData.llmOverride).toEqual({ + provider: previousLLMConfig.provider, + model: previousLLMConfig.model, + maxIterations: previousLLMConfig.maxIterations, + maxInputTokens: previousLLMConfig.maxInputTokens, + }); + }); + test('should handle LLM switch for non-existent session', async () => { const newLLMConfig: ValidatedLLMConfig = { ...mockLLMConfig, @@ -1375,6 +1517,41 @@ describe('SessionManager', () => { // The expired session should have been cleaned up from memory // This is tested indirectly through the cleanup process }); + + test('should reset persisted interaction state alongside conversation history', async () => { + const sessionId = 'reset-session'; + const session = await sessionManager.createSession(sessionId); + mockServices.stateManager.hasSessionLLMOverride.mockReturnValue(true); + mockStorageManager.database.get.mockImplementation(async (key: string) => { + if (key === `session:${sessionId}`) { + return { + ...mockSessionData, + id: sessionId, + llmOverride: { + provider: 'openai', + model: 'gpt-5', + maxInputTokens: 128000, + }, + }; + } + return null; + }); + + await sessionManager.resetSession(sessionId); + + expect(session.reset).toHaveBeenCalled(); + expect(session.clearMessageQueue).toHaveBeenCalled(); + expect(mockServices.toolManager.deleteSessionState).toHaveBeenCalledWith(sessionId); + expect(mockServices.approvalManager.deleteSessionState).toHaveBeenCalledWith(sessionId); + expect(mockServices.stateManager.clearSessionOverride).toHaveBeenCalledWith(sessionId); + expect(session.switchLLM).toHaveBeenCalledWith(mockLLMConfig); + expect(mockStorageManager.database.set).toHaveBeenCalledWith( + `session:${sessionId}`, + expect.not.objectContaining({ + llmOverride: expect.anything(), + }) + ); + }); }); describe('Periodic Cleanup', () => { @@ -1465,6 +1642,9 @@ describe('SessionManager', () => { lastActivity: Date.now() - 7200000, // 2 hours ago (expired) messageCount: 5, }; + mockServices.toolManager.evictSessionState.mockClear(); + mockServices.approvalManager.evictSessionState.mockClear(); + mockServices.stateManager.clearSessionOverride.mockClear(); mockStorageManager.database.get.mockResolvedValue(expiredSessionData); // Trigger cleanup - should remove from memory but preserve storage @@ -1472,6 +1652,11 @@ describe('SessionManager', () => { // Session should be removed from memory expect(sessionManager['sessions'].has(sessionId)).toBe(false); + expect(mockServices.toolManager.evictSessionState).toHaveBeenCalledWith(sessionId); + expect(mockServices.approvalManager.evictSessionState).toHaveBeenCalledWith(sessionId); + expect(mockServices.stateManager.clearSessionOverride).not.toHaveBeenCalledWith( + sessionId + ); // But session should still exist in storage (not deleted) expect(mockStorageManager.database.delete).not.toHaveBeenCalledWith(sessionKey); diff --git a/packages/core/src/session/session-manager.ts b/packages/core/src/session/session-manager.ts index 85efc299e..ff6c0fec5 100644 --- a/packages/core/src/session/session-manager.ts +++ b/packages/core/src/session/session-manager.ts @@ -9,6 +9,7 @@ import type { AgentStateManager } from '../agent/state-manager.js'; import type { ValidatedLLMConfig } from '../llm/schemas.js'; import type { StorageManager } from '../storage/index.js'; import type { HookManager } from '../hooks/manager.js'; +import type { ApprovalManager } from '../approval/manager.js'; import { SessionError } from './errors.js'; import type { TokenUsage } from '../llm/types.js'; import type { LanguageModelFactory } from '../llm/services/types.js'; @@ -18,6 +19,7 @@ import { SessionPromptContributorSchema, type SessionPromptContributor, } from '../systemPrompt/schemas.js'; +import type { MessageQueueStore } from './message-queue-store.js'; export type SessionLoggerFactory = (options: { baseLogger: Logger; agentId: string; @@ -139,11 +141,13 @@ export class SessionManager { stateManager: AgentStateManager; systemPromptManager: SystemPromptManager; toolManager: ToolManager; + approvalManager: ApprovalManager; agentEventBus: AgentEventBus; storageManager: StorageManager; resourceManager: import('../resources/index.js').ResourceManager; hookManager: HookManager; mcpManager: import('../mcp/manager.js').MCPManager; + messageQueueStore: Pick; compactionStrategy: CompactionStrategy | null; workspaceManager?: import('../workspace/manager.js').WorkspaceManager; }, @@ -221,8 +225,13 @@ export class SessionManager { // Session is still valid, but don't create ChatSession until requested this.logger.debug(`Session ${sessionId} restored from storage`); } else { - // Session expired, clean it up - await this.services.storageManager.getDatabase().delete(sessionKey); + // Session expired, purge the session record plus any persisted + // interaction state keyed off the same session ID. + await Promise.all([ + this.services.storageManager.getDatabase().delete(sessionKey), + this.services.storageManager.getCache().delete(sessionKey), + this.deleteSessionInteractionState(sessionId), + ]); this.logger.debug(`Expired session ${sessionId} cleaned up during restore`); } } @@ -486,6 +495,8 @@ export class SessionManager { const session = new ChatSession(this.getChatSessionServices(), id, sessionLogger); await session.init(); + await this.services.toolManager.restoreSessionState(id); + await this.services.approvalManager.restoreSessionState(id); this.sessions.set(id, session); this.logger.info(`Restored session from storage: ${id}`); @@ -499,6 +510,10 @@ export class SessionManager { throw SessionError.maxSessionsExceeded(activeSessionKeys.length, this.maxSessions); } + // A newly-created session claims a clean interaction-state namespace. + // If stale per-session buckets exist without metadata, they belong to an orphaned session. + await this.deleteSessionInteractionState(id); + const workspace = await this.services.workspaceManager?.getWorkspace(); // Create new session metadata first to "reserve" the session slot @@ -621,6 +636,8 @@ export class SessionManager { sessionLogger ); await session.init(); + await this.services.toolManager.restoreSessionState(sessionId); + await this.services.approvalManager.restoreSessionState(sessionId); this.sessions.set(sessionId, session); return session; @@ -649,6 +666,7 @@ export class SessionManager { // Remove from cache but preserve database storage const sessionKey = `session:${sessionId}`; await this.services.storageManager.getCache().delete(sessionKey); + this.evictSessionInteractionState(sessionId); this.logger.debug( `Ended session (removed from memory, chat history preserved): ${sessionId}` @@ -675,6 +693,7 @@ export class SessionManager { const sessionKey = `session:${sessionId}`; await this.services.storageManager.getDatabase().delete(sessionKey); await this.services.storageManager.getCache().delete(sessionKey); + await this.deleteSessionInteractionState(sessionId); const messagesKey = `messages:${sessionId}`; await this.services.storageManager.getDatabase().delete(messagesKey); @@ -683,7 +702,7 @@ export class SessionManager { } /** - * Resets the conversation history for a session while keeping the session alive. + * Resets conversation and session-scoped interaction state while keeping the session alive. * * @param sessionId The session ID to reset * @throws Error if session doesn't exist @@ -697,6 +716,16 @@ export class SessionManager { } await session.reset(); + await session.clearMessageQueue(); + await Promise.all([ + this.services.toolManager.deleteSessionState(sessionId), + this.services.approvalManager.deleteSessionState(sessionId), + ]); + + if (this.services.stateManager.hasSessionLLMOverride(sessionId)) { + this.services.stateManager.clearSessionOverride(sessionId); + await session.switchLLM(this.services.stateManager.getRuntimeConfig().llm); + } // Reset message count in metadata await this.runWithSessionDataLock(sessionId, async (sessionKey) => { @@ -709,6 +738,7 @@ export class SessionManager { sessionData.messageCount = 0; sessionData.lastActivity = Date.now(); + delete sessionData.llmOverride; await this.persistSessionData(sessionKey, sessionData); }); @@ -1125,6 +1155,7 @@ export class SessionManager { // Only dispose memory resources, don't delete chat history session.dispose(); this.sessions.delete(sessionId); + this.evictSessionInteractionState(sessionId); this.logger.debug( `Removed expired session from memory: ${sessionId} (chat history preserved)` ); @@ -1155,11 +1186,7 @@ export class SessionManager { const session = await this.getSession(sId); if (session) { try { - // Update state with validated config (validation already done by DextoAgent) - // Using exceptions here for session-specific runtime failures (corruption, disposal, etc.) - // This is different from input validation which uses Result pattern - this.services.stateManager.updateLLM(newLLMConfig, sId); - await session.switchLLM(newLLMConfig); + await this.applySessionLLMSwitch(sId, session, newLLMConfig); } catch (error) { // Session-level failure - continue processing other sessions (isolation) failedSessions.push(sId); @@ -1204,10 +1231,78 @@ export class SessionManager { throw SessionError.notFound(sessionId); } - await session.switchLLM(newLLMConfig); + await this.applySessionLLMSwitch(sessionId, session, newLLMConfig); + + this.services.agentEventBus.emit('llm:switched', { + newConfig: newLLMConfig, + historyRetained: true, + sessionIds: [sessionId], + }); + + const message = `Successfully switched to ${newLLMConfig.provider}/${newLLMConfig.model} for session ${sessionId}`; + + return { message, warnings: [] }; + } + + private async applySessionLLMSwitch( + sessionId: string, + session: ChatSession, + newLLMConfig: ValidatedLLMConfig + ): Promise { + const previousLLMConfig = this.services.stateManager.getRuntimeConfig(sessionId).llm; + const previousHadOverride = this.services.stateManager.hasSessionLLMOverride(sessionId); + const previousPersistedOverride = await this.getPersistedSessionLLMOverride(sessionId); - // Persist the LLM override to storage so it survives restarts - // SECURITY: Don't persist API keys - they should be resolved from environment variables + await this.setPersistedSessionLLMOverride( + sessionId, + this.toPersistedLLMConfig(newLLMConfig) + ); + + try { + this.services.stateManager.updateLLM(newLLMConfig, sessionId); + await session.switchLLM(newLLMConfig); + } catch (error) { + await this.setPersistedSessionLLMOverride(sessionId, previousPersistedOverride); + + if (previousHadOverride) { + this.services.stateManager.updateLLM(previousLLMConfig, sessionId); + } else { + this.services.stateManager.clearSessionOverride(sessionId); + } + + try { + await session.switchLLM(previousLLMConfig); + } catch (rollbackError) { + this.logger.error( + `Failed to roll back LLM switch for session ${sessionId}: ${ + rollbackError instanceof Error + ? rollbackError.message + : String(rollbackError) + }` + ); + } + + throw error; + } + } + + private async getPersistedSessionLLMOverride( + sessionId: string + ): Promise { + const sessionData = await this.getSessionData(sessionId); + return sessionData?.llmOverride; + } + + private toPersistedLLMConfig(newLLMConfig: ValidatedLLMConfig): PersistedLLMConfig { + // SECURITY: Don't persist API keys - they should be resolved from environment variables. + const { apiKey: _apiKey, ...configWithoutApiKey } = newLLMConfig; + return configWithoutApiKey; + } + + private async setPersistedSessionLLMOverride( + sessionId: string, + llmOverride: PersistedLLMConfig | undefined + ): Promise { await this.runWithSessionDataLock(sessionId, async (sessionKey) => { const sessionData = await this.services.storageManager .getDatabase() @@ -1216,20 +1311,27 @@ export class SessionManager { return; } - const { apiKey: _apiKey, ...configWithoutApiKey } = newLLMConfig; - sessionData.llmOverride = configWithoutApiKey; + if (llmOverride !== undefined) { + sessionData.llmOverride = llmOverride; + } else { + delete sessionData.llmOverride; + } await this.persistSessionData(sessionKey, sessionData); }); + } - this.services.agentEventBus.emit('llm:switched', { - newConfig: newLLMConfig, - historyRetained: true, - sessionIds: [sessionId], - }); - - const message = `Successfully switched to ${newLLMConfig.provider}/${newLLMConfig.model} for session ${sessionId}`; + private async deleteSessionInteractionState(sessionId: string): Promise { + this.services.stateManager.clearSessionOverride(sessionId); + await Promise.all([ + this.services.toolManager.deleteSessionState(sessionId), + this.services.approvalManager.deleteSessionState(sessionId), + this.services.messageQueueStore.delete(sessionId), + ]); + } - return { message, warnings: [] }; + private evictSessionInteractionState(sessionId: string): void { + this.services.toolManager.evictSessionState(sessionId); + this.services.approvalManager.evictSessionState(sessionId); } private async runWithSessionDataLock( diff --git a/packages/core/src/session/title-generator.test.ts b/packages/core/src/session/title-generator.test.ts index bef677793..4278764db 100644 --- a/packages/core/src/session/title-generator.test.ts +++ b/packages/core/src/session/title-generator.test.ts @@ -71,7 +71,9 @@ describe('generateSessionTitle', () => { expect.stringMatching(/^titlegen-/), mockResourceManager, logger, - {}, + expect.objectContaining({ + messageQueue: expect.any(Object), + }), languageModelFactory ); }); @@ -96,7 +98,9 @@ describe('generateSessionTitle', () => { expect.stringMatching(/^titlegen-/), mockResourceManager, logger, - {}, + expect.objectContaining({ + messageQueue: expect.any(Object), + }), undefined ); }); diff --git a/packages/core/src/session/title-generator.ts b/packages/core/src/session/title-generator.ts index 90664a40e..b6ca01b14 100644 --- a/packages/core/src/session/title-generator.ts +++ b/packages/core/src/session/title-generator.ts @@ -7,6 +7,7 @@ import type { CreateLLMServiceOptions, LanguageModelFactory } from '../llm/servi import { createLLMService } from '../llm/services/factory.js'; import { SessionEventBus } from '../events/index.js'; import { MemoryHistoryProvider } from './history/memory.js'; +import { MessageQueueService } from './message-queue.js'; export interface GenerateSessionTitleResult { title?: string; @@ -38,7 +39,9 @@ export async function generateSessionTitle( const history = new MemoryHistoryProvider(logger); const bus = new SessionEventBus(); const sessionId = `titlegen-${Math.random().toString(36).slice(2)}`; - const options: CreateLLMServiceOptions = {}; + const options: CreateLLMServiceOptions = { + messageQueue: MessageQueueService.createEphemeral(bus, logger, sessionId), + }; const tempService = createLLMService( config, toolManager, diff --git a/packages/core/src/test-utils/session-state-stores.ts b/packages/core/src/test-utils/session-state-stores.ts new file mode 100644 index 000000000..0e3f70f64 --- /dev/null +++ b/packages/core/src/test-utils/session-state-stores.ts @@ -0,0 +1,57 @@ +import type { Logger } from '../logger/v2/types.js'; +import type { StorageManager } from '../storage/index.js'; +import { SessionApprovalStore } from '../approval/session-approval-store.js'; +import type { MessageQueueStore } from '../session/message-queue-store.js'; +import type { QueuedMessage } from '../session/types.js'; +import { SessionToolPreferencesStore } from '../tools/session-tool-preferences-store.js'; +import { createInMemoryCache, createInMemoryDatabase } from './in-memory-storage.js'; + +type SessionStateStorage = Pick; + +export function createInMemorySessionStateStorage(): SessionStateStorage { + const cache = createInMemoryCache(); + const database = createInMemoryDatabase(); + + return { + getCache: () => cache, + getDatabase: () => database, + }; +} + +export function createInMemorySessionApprovalStore( + logger: Logger, + storageManager: SessionStateStorage = createInMemorySessionStateStorage() +): SessionApprovalStore { + return new SessionApprovalStore(storageManager as StorageManager, logger); +} + +export function createInMemorySessionToolPreferencesStore( + logger: Logger, + storageManager: SessionStateStorage = createInMemorySessionStateStorage() +): SessionToolPreferencesStore { + return new SessionToolPreferencesStore(storageManager as StorageManager, logger); +} + +export function createInMemoryMessageQueueStore(): Pick< + MessageQueueStore, + 'load' | 'save' | 'delete' +> { + const queues = new Map(); + + return { + async load(sessionId: string): Promise { + return structuredClone(queues.get(sessionId) ?? []); + }, + async save(sessionId: string, queue: QueuedMessage[]): Promise { + if (queue.length === 0) { + queues.delete(sessionId); + return; + } + + queues.set(sessionId, structuredClone(queue)); + }, + async delete(sessionId: string): Promise { + queues.delete(sessionId); + }, + }; +} diff --git a/packages/core/src/tools/session-tool-preferences-store.ts b/packages/core/src/tools/session-tool-preferences-store.ts new file mode 100644 index 000000000..82e4215ef --- /dev/null +++ b/packages/core/src/tools/session-tool-preferences-store.ts @@ -0,0 +1,79 @@ +import { z } from 'zod'; +import type { StorageManager } from '../storage/index.js'; +import type { Logger } from '../logger/v2/types.js'; + +const SessionToolPreferencesSchema = z + .object({ + userAutoApproveTools: z.array(z.string()).default([]), + disabledTools: z.array(z.string()).default([]), + }) + .strict(); + +export type SessionToolPreferences = z.output; + +const DEFAULT_SESSION_TOOL_PREFERENCES: SessionToolPreferences = { + userAutoApproveTools: [], + disabledTools: [], +}; + +export class SessionToolPreferencesStore { + private readonly cacheTtlSeconds: number; + + constructor( + private readonly storageManager: StorageManager, + private readonly logger: Logger, + options: { cacheTtlMs?: number } = {} + ) { + const cacheTtlMs = options.cacheTtlMs ?? 3600000; + this.cacheTtlSeconds = Math.max(1, Math.floor(cacheTtlMs / 1000)); + } + + private buildKey(sessionId: string): string { + return `session-tool-preferences:${sessionId}`; + } + + async load(sessionId: string): Promise { + const key = this.buildKey(sessionId); + const cached = await this.storageManager.getCache().get(key); + if (cached !== undefined) { + return this.parsePreferences(cached, key); + } + + const stored = await this.storageManager.getDatabase().get(key); + if (stored === undefined) { + return structuredClone(DEFAULT_SESSION_TOOL_PREFERENCES); + } + + const parsed = this.parsePreferences(stored, key); + await this.storageManager.getCache().set(key, parsed, this.cacheTtlSeconds); + return parsed; + } + + async save(sessionId: string, preferences: SessionToolPreferences): Promise { + const key = this.buildKey(sessionId); + const normalized = SessionToolPreferencesSchema.parse(preferences); + await this.storageManager.getDatabase().set(key, normalized); + await this.storageManager.getCache().set(key, normalized, this.cacheTtlSeconds); + } + + async delete(sessionId: string): Promise { + const key = this.buildKey(sessionId); + await Promise.all([ + this.storageManager.getDatabase().delete(key), + this.storageManager.getCache().delete(key), + ]); + } + + private parsePreferences(value: unknown, key: string): SessionToolPreferences { + const result = SessionToolPreferencesSchema.safeParse(value); + if (result.success) { + return result.data; + } + + this.logger.warn('Invalid persisted session tool preferences encountered; using defaults', { + key, + error: result.error.message, + }); + return structuredClone(DEFAULT_SESSION_TOOL_PREFERENCES); + } +} diff --git a/packages/core/src/tools/tool-manager.integration.test.ts b/packages/core/src/tools/tool-manager.integration.test.ts index 201c70064..732f3a17b 100644 --- a/packages/core/src/tools/tool-manager.integration.test.ts +++ b/packages/core/src/tools/tool-manager.integration.test.ts @@ -10,6 +10,46 @@ import { AgentEventBus } from '../events/index.js'; import { ApprovalManager } from '../approval/manager.js'; import type { AllowedToolsProvider } from './confirmation/allowed-tools-provider/types.js'; import { createMockLogger } from '../logger/v2/test-utils.js'; +import { + createInMemorySessionApprovalStore, + createInMemorySessionToolPreferencesStore, +} from '../test-utils/session-state-stores.js'; + +type ToolManagerFactoryArgs = + ConstructorParameters extends [ + infer McpManager, + infer ApprovalManager, + infer AllowedToolsProvider, + infer ApprovalMode, + infer AgentEventBus, + infer ToolPolicies, + infer Tools, + infer Logger, + infer _SessionToolPreferencesStore, + ] + ? [ + McpManager, + ApprovalManager, + AllowedToolsProvider, + ApprovalMode, + AgentEventBus, + ToolPolicies, + Tools, + Logger, + ] + : never; + +function createToolManager(...args: ToolManagerFactoryArgs): ToolManager { + const logger = args[7]; + return new ToolManager(...args, createInMemorySessionToolPreferencesStore(logger)); +} + +function createApprovalManager( + config: ConstructorParameters[0], + logger: ConstructorParameters[1] +): ApprovalManager { + return new ApprovalManager(config, logger, createInMemorySessionApprovalStore(logger)); +} // Mock logger vi.mock('../logger/index.js', () => ({ @@ -113,7 +153,7 @@ describe('ToolManager Integration Tests', () => { } as any; // Create ApprovalManager in auto-approve mode for integration tests - approvalManager = new ApprovalManager( + approvalManager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -166,7 +206,7 @@ describe('ToolManager Integration Tests', () => { await (mcpManager as any).updateClientCache('test-server', mockClient); // Create ToolManager with real components - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -191,7 +231,7 @@ describe('ToolManager Integration Tests', () => { it('should execute local tools through the complete pipeline', async () => { // Create ToolManager with local tools - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -245,7 +285,7 @@ describe('ToolManager Integration Tests', () => { await (mcpManager as any).updateClientCache('file-server', mockClient); // Create ToolManager with both MCP and local tools - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -304,7 +344,7 @@ describe('ToolManager Integration Tests', () => { describe('Confirmation Flow Integration', () => { it('should work with auto-approve mode', async () => { - const autoApproveManager = new ApprovalManager( + const autoApproveManager = createApprovalManager( { permissions: { mode: 'auto-approve', @@ -334,7 +374,7 @@ describe('ToolManager Integration Tests', () => { mcpMgr.registerClient('test-server', mockClient); await (mcpMgr as any).updateClientCache('test-server', mockClient); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpMgr, autoApproveManager, allowedToolsProvider, @@ -350,7 +390,7 @@ describe('ToolManager Integration Tests', () => { }); it('should work with auto-deny mode', async () => { - const autoDenyManager = new ApprovalManager( + const autoDenyManager = createApprovalManager( { permissions: { mode: 'auto-deny', @@ -380,7 +420,7 @@ describe('ToolManager Integration Tests', () => { mcpMgr.registerClient('test-server', mockClient); await (mcpMgr as any).updateClientCache('test-server', mockClient); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpMgr, autoDenyManager, allowedToolsProvider, @@ -414,7 +454,7 @@ describe('ToolManager Integration Tests', () => { mcpManager.registerClient('failing-server', failingClient); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -453,7 +493,7 @@ describe('ToolManager Integration Tests', () => { mcpManager.registerClient('failing-server', failingClient); await (mcpManager as any).updateClientCache('failing-server', failingClient); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -476,7 +516,7 @@ describe('ToolManager Integration Tests', () => { searchSessions: vi.fn().mockRejectedValue(new Error('Search service failed')), } as any; - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -522,7 +562,7 @@ describe('ToolManager Integration Tests', () => { expect(mockClient.getTools).toHaveBeenCalledTimes(1); vi.mocked(mockClient.getTools).mockClear(); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -564,7 +604,7 @@ describe('ToolManager Integration Tests', () => { expect(mockClient.getTools).toHaveBeenCalledTimes(1); vi.mocked(mockClient.getTools).mockClear(); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, @@ -610,7 +650,7 @@ describe('ToolManager Integration Tests', () => { mcpManager.registerClient('test-server', mockClient); await (mcpManager as any).updateClientCache('test-server', mockClient); - const toolManager = new ToolManager( + const toolManager = createToolManager( mcpManager, approvalManager, allowedToolsProvider, diff --git a/packages/core/src/tools/tool-manager.test.ts b/packages/core/src/tools/tool-manager.test.ts index 7f9850dfe..52ca5e18c 100644 --- a/packages/core/src/tools/tool-manager.test.ts +++ b/packages/core/src/tools/tool-manager.test.ts @@ -12,6 +12,47 @@ import type { AllowedToolsProvider } from './confirmation/allowed-tools-provider import { ApprovalStatus, ApprovalType } from '../approval/types.js'; import { createMockLogger } from '../logger/v2/test-utils.js'; import { SessionError } from '../session/errors.js'; +import { createInMemorySessionToolPreferencesStore } from '../test-utils/session-state-stores.js'; +import type { SessionToolPreferences } from './session-tool-preferences-store.js'; + +function createDeferred() { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject }; +} + +type ToolManagerFactoryArgs = + ConstructorParameters extends [ + infer McpManager, + infer ApprovalManager, + infer AllowedToolsProvider, + infer ApprovalMode, + infer AgentEventBus, + infer ToolPolicies, + infer Tools, + infer Logger, + infer _SessionToolPreferencesStore, + ] + ? [ + McpManager, + ApprovalManager, + AllowedToolsProvider, + ApprovalMode, + AgentEventBus, + ToolPolicies, + Tools, + Logger, + ] + : never; + +function createToolManager(...args: ToolManagerFactoryArgs): ToolManager { + const logger = args[7]; + return new ToolManager(...args, createInMemorySessionToolPreferencesStore(logger)); +} // Mock logger vi.mock('../logger/index.js', () => ({ @@ -77,7 +118,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { describe('Tool Source Detection Logic', () => { it('should correctly identify MCP tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -93,7 +134,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should correctly identify local tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -122,7 +163,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should identify unknown tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -139,7 +180,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should handle edge cases with empty tool names', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -156,7 +197,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { describe('Contributor Context', () => { it('includes session context when a session id is provided', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -175,7 +216,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('preserves session context when contributor overrides add environment data', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -205,7 +246,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('includes session prompt contributors when session manager support is configured', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -247,7 +288,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('gracefully ignores missing sessions when building contributor context', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -300,7 +341,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should return not found for unknown tools', async () => { mockMcpManager.getAllTools = vi.fn().mockResolvedValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -321,7 +362,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should reject MCP tools with prefix but no name', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -345,7 +386,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should return not found when local tool is not registered', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -370,7 +411,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should execute local tools provided to ToolManager', async () => { mockMcpManager.getAllTools = vi.fn().mockResolvedValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -417,7 +458,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { execute: vi.fn().mockResolvedValue('ok'), }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -462,7 +503,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { execute: vi.fn().mockResolvedValue('ok'), }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -523,7 +564,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should emit callDescription on llm:tool-call events when args.description is provided', async () => { mockMcpManager.executeTool = vi.fn().mockResolvedValue('result'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -576,7 +617,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { execute: vi.fn().mockResolvedValue('ok'), }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -669,7 +710,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { callOrder.push('addApprovedDirectory'); }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -846,7 +887,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { } ); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -871,7 +912,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should request approval via ApprovalManager with correct parameters', async () => { mockMcpManager.executeTool = vi.fn().mockResolvedValue('result'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -929,7 +970,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { status: ApprovalStatus.APPROVED, }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -962,7 +1003,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should fall back to args.description for approval description when __meta.callDescription is missing', async () => { mockMcpManager.executeTool = vi.fn().mockResolvedValue('result'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1000,7 +1041,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { const emitSpy = vi.fn(); mockAgentEventBus.emit = emitSpy as typeof mockAgentEventBus.emit; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1055,7 +1096,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { const emitSpy = vi.fn(); mockAgentEventBus.emit = emitSpy as typeof mockAgentEventBus.emit; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1092,7 +1133,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should request approval without sessionId when not provided', async () => { mockMcpManager.executeTool = vi.fn().mockResolvedValue('result'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1121,7 +1162,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { status: ApprovalStatus.DENIED, }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1146,7 +1187,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should proceed with execution when approval granted', async () => { mockMcpManager.executeTool = vi.fn().mockResolvedValue('success'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1181,7 +1222,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { mockAllowedToolsProvider.isToolAllowed = vi.fn().mockResolvedValue(true); mockMcpManager.executeTool = vi.fn().mockResolvedValue('success'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1209,7 +1250,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should auto-approve when mode is auto-approve', async () => { mockMcpManager.executeTool = vi.fn().mockResolvedValue('success'); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1232,7 +1273,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should auto-deny when mode is auto-deny', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1259,7 +1300,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { const getDescription = vi.fn().mockReturnValue('Dynamic description'); mockMcpManager.getAllTools = vi.fn().mockResolvedValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1296,7 +1337,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }; mockMcpManager.getAllTools = vi.fn().mockResolvedValue(tools); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1322,7 +1363,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { .mockResolvedValueOnce('Workspace agents: explore-agent'); mockMcpManager.getAllTools = vi.fn().mockResolvedValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1363,7 +1404,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { .mockResolvedValueOnce(' '); mockMcpManager.getAllTools = vi.fn().mockResolvedValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1401,7 +1442,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }; mockMcpManager.getAllTools = vi.fn().mockResolvedValue(tools); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1443,7 +1484,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { return mockAgentEventBus; }) as any; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1491,7 +1532,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { mockMcpManager.getAllTools = vi.fn().mockResolvedValue(mcpTools); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1514,7 +1555,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should handle empty tool sets', async () => { mockMcpManager.getAllTools = vi.fn().mockResolvedValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1537,7 +1578,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should handle MCP errors gracefully in statistics', async () => { mockMcpManager.getAllTools = vi.fn().mockRejectedValue(new Error('MCP failed')); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1562,7 +1603,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should check MCP tool existence correctly', async () => { mockMcpManager.getToolClient = vi.fn().mockReturnValue({}); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1582,7 +1623,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { it('should return false for non-existent MCP tools', async () => { mockMcpManager.getToolClient = vi.fn().mockReturnValue(undefined); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1599,7 +1640,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should return false for tools without proper prefix', async () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1621,7 +1662,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { const executionError = new Error('Tool execution failed'); mockMcpManager.executeTool = vi.fn().mockRejectedValue(executionError); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1641,7 +1682,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { const approvalError = new Error('Approval request failed'); mockApprovalManager.requestToolApproval = vi.fn().mockRejectedValue(approvalError); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1672,7 +1713,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--filesystem--delete_file'], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1702,7 +1743,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--filesystem--delete_file'], // Deny takes precedence }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1730,7 +1771,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1766,7 +1807,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1803,7 +1844,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--filesystem--delete_file'], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1838,7 +1879,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--filesystem--delete_file'], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1866,7 +1907,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--filesystem--delete_file'], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1895,7 +1936,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1922,7 +1963,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1953,7 +1994,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { data: {}, }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -1992,7 +2033,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2027,7 +2068,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2065,7 +2106,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2092,7 +2133,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2127,7 +2168,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--delete_file'], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2151,7 +2192,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--delete_file'], // Simple policy }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2186,7 +2227,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2228,7 +2269,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: [], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2265,7 +2306,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { alwaysDeny: ['mcp--delete_file', 'mcp--execute_script'], }; - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2312,7 +2353,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { describe('Session Auto-Approve Tools (Skill allowed-tools)', () => { describe('Basic CRUD Operations', () => { it('should set and get session auto-approve tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2333,7 +2374,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should normalize local tool aliases when setting auto-approve tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2359,7 +2400,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should return false/undefined for non-existent sessions', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2375,7 +2416,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should clear session auto-approve tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2398,7 +2439,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should handle multiple sessions independently', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2432,7 +2473,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should overwrite existing tools when setting again', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2454,7 +2495,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should clear auto-approvals when setting empty array', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2478,8 +2519,109 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { expect(toolManager.getSessionAutoApproveTools(sessionId)).toBeUndefined(); }); - it('should merge tools when adding to session auto-approve list', () => { + it('should not keep an empty user auto-approve key when restored state has no tools', async () => { + const emptyPreferencesStore = { + load: vi.fn().mockResolvedValue({ + userAutoApproveTools: [], + disabledTools: [], + } satisfies SessionToolPreferences), + save: vi.fn().mockResolvedValue(undefined), + delete: vi.fn().mockResolvedValue(undefined), + }; const toolManager = new ToolManager( + mockMcpManager, + mockApprovalManager, + mockAllowedToolsProvider, + 'manual', + mockAgentEventBus, + { alwaysAllow: [], alwaysDeny: [] }, + [], + mockLogger, + emptyPreferencesStore as unknown as ConstructorParameters[8] + ); + + await toolManager.restoreSessionState('restored-session'); + + expect(toolManager.hasSessionUserAutoApproveTools('restored-session')).toBe(false); + expect( + toolManager.getSessionUserAutoApproveTools('restored-session') + ).toBeUndefined(); + }); + + it('should serialize deleteSessionState with in-flight preference persistence', async () => { + const sessionId = 'locked-delete-session'; + const saveStarted = createDeferred(); + const releaseSave = createDeferred(); + const persistedPreferences = new Map(); + const emptyPreferences: SessionToolPreferences = { + userAutoApproveTools: [], + disabledTools: [], + }; + const controlledStore = { + load: vi.fn().mockImplementation(async (requestedSessionId: string) => { + return structuredClone( + persistedPreferences.get(requestedSessionId) ?? emptyPreferences + ); + }), + save: vi + .fn() + .mockImplementation( + async ( + requestedSessionId: string, + preferences: SessionToolPreferences + ) => { + saveStarted.resolve(); + await releaseSave.promise; + persistedPreferences.set( + requestedSessionId, + structuredClone(preferences) + ); + } + ), + delete: vi.fn().mockImplementation(async (requestedSessionId: string) => { + persistedPreferences.delete(requestedSessionId); + }), + }; + const toolManager = new ToolManager( + mockMcpManager, + mockApprovalManager, + mockAllowedToolsProvider, + 'manual', + mockAgentEventBus, + { alwaysAllow: [], alwaysDeny: [] }, + [], + mockLogger, + controlledStore as unknown as ConstructorParameters[8] + ); + + const setDisabledPromise = toolManager.setSessionDisabledTools(sessionId, [ + 'bash_exec', + ]); + await saveStarted.promise; + + let deleteFinished = false; + const deletePromise = toolManager.deleteSessionState(sessionId).then(() => { + deleteFinished = true; + }); + + await Promise.resolve(); + expect(deleteFinished).toBe(false); + + releaseSave.resolve(); + await setDisabledPromise; + await deletePromise; + + expect( + persistedPreferences.get(sessionId) ?? { + userAutoApproveTools: [], + disabledTools: [], + } + ).toEqual(emptyPreferences); + expect(toolManager.getDisabledTools(sessionId)).toEqual([]); + }); + + it('should merge tools when adding to session auto-approve list', () => { + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2501,7 +2643,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }); it('should normalize aliases and ignore duplicates when adding auto-approve tools', () => { - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2541,7 +2683,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { 'success' ); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2580,7 +2722,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { 'success' ); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2610,7 +2752,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { }, }); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2644,7 +2786,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { 'success' ); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, @@ -2677,7 +2819,7 @@ describe('ToolManager - Unit Tests (Pure Logic)', () => { 'success' ); - const toolManager = new ToolManager( + const toolManager = createToolManager( mockMcpManager, mockApprovalManager, mockAllowedToolsProvider, diff --git a/packages/core/src/tools/tool-manager.ts b/packages/core/src/tools/tool-manager.ts index 83c5e269b..ba81d0e94 100644 --- a/packages/core/src/tools/tool-manager.ts +++ b/packages/core/src/tools/tool-manager.ts @@ -31,6 +31,10 @@ import type { BeforeToolCallPayload, AfterToolResultPayload } from '../hooks/typ import type { WorkspaceManager } from '../workspace/manager.js'; import type { WorkspaceContext } from '../workspace/types.js'; import { InstrumentClass } from '../telemetry/decorators.js'; +import type { + SessionToolPreferences, + SessionToolPreferencesStore, +} from './session-tool-preferences-store.js'; import { extractToolCallMeta, wrapToolParametersSchema, @@ -112,6 +116,8 @@ export class ToolManager { // Session-level auto-approve tools set by users (UI) private sessionUserAutoApproveTools: Map = new Map(); private sessionDisabledTools: Map = new Map(); + private readonly restoredSessionPreferences = new Set(); + private readonly sessionPreferenceLocks = new Map>(); private globalDisabledTools: string[] = []; private cleanupHandlers: Set<() => Promise | void> = new Set(); private cleanupStarted = false; @@ -149,7 +155,8 @@ export class ToolManager { agentEventBus: AgentEventBus, toolPolicies: ToolPolicies, tools: Tool[], - logger: Logger + logger: Logger, + private readonly sessionToolPreferencesStore: SessionToolPreferencesStore ) { this.mcpManager = mcpManager; this.approvalManager = approvalManager; @@ -171,6 +178,94 @@ export class ToolManager { this.logger.debug('ToolManager initialized'); } + private async runWithSessionPreferenceLock( + sessionId: string, + fn: () => Promise + ): Promise { + const previousLock = this.sessionPreferenceLocks.get(sessionId) ?? Promise.resolve(); + const currentResult = previousLock.catch(() => {}).then(() => fn()); + const currentLock = currentResult.then( + () => undefined, + () => undefined + ); + + this.sessionPreferenceLocks.set(sessionId, currentLock); + + try { + return await currentResult; + } finally { + if (this.sessionPreferenceLocks.get(sessionId) === currentLock) { + this.sessionPreferenceLocks.delete(sessionId); + } + } + } + + private applySessionToolPreferences( + sessionId: string, + preferences: SessionToolPreferences + ): void { + if (preferences.userAutoApproveTools.length > 0) { + this.sessionUserAutoApproveTools.set(sessionId, [...preferences.userAutoApproveTools]); + } else { + this.sessionUserAutoApproveTools.delete(sessionId); + } + if (preferences.disabledTools.length > 0) { + this.sessionDisabledTools.set(sessionId, [...preferences.disabledTools]); + } else { + this.sessionDisabledTools.delete(sessionId); + } + } + + private getSessionToolPreferencesSnapshot(sessionId: string): SessionToolPreferences { + return { + userAutoApproveTools: [...(this.sessionUserAutoApproveTools.get(sessionId) ?? [])], + disabledTools: [...(this.sessionDisabledTools.get(sessionId) ?? [])], + }; + } + + async restoreSessionState(sessionId: string): Promise { + if (this.restoredSessionPreferences.has(sessionId)) { + return; + } + + await this.runWithSessionPreferenceLock(sessionId, async () => { + if (this.restoredSessionPreferences.has(sessionId)) { + return; + } + + const preferences = await this.sessionToolPreferencesStore.load(sessionId); + this.applySessionToolPreferences(sessionId, preferences); + this.restoredSessionPreferences.add(sessionId); + + this.logger.debug('Restored persisted session tool preferences', { + sessionId, + autoApproveCount: preferences.userAutoApproveTools.length, + disabledCount: preferences.disabledTools.length, + }); + }); + } + + evictSessionState(sessionId: string): void { + this.sessionAutoApproveTools.delete(sessionId); + this.sessionUserAutoApproveTools.delete(sessionId); + this.sessionDisabledTools.delete(sessionId); + this.restoredSessionPreferences.delete(sessionId); + } + + async deleteSessionState(sessionId: string): Promise { + await this.runWithSessionPreferenceLock(sessionId, async () => { + this.evictSessionState(sessionId); + await this.sessionToolPreferencesStore.delete(sessionId); + }); + } + + private async persistSessionToolPreferences(sessionId: string): Promise { + await this.sessionToolPreferencesStore.save( + sessionId, + this.getSessionToolPreferencesSnapshot(sessionId) + ); + } + /** * Initialize the ToolManager and its components */ @@ -348,15 +443,25 @@ export class ToolManager { /** * Set session-level auto-approve tools chosen by the user. */ - setSessionUserAutoApproveTools(sessionId: string, autoApproveTools: string[]): void { + async setSessionUserAutoApproveTools( + sessionId: string, + autoApproveTools: string[] + ): Promise { + await this.restoreSessionState(sessionId); if (autoApproveTools.length === 0) { - this.clearSessionUserAutoApproveTools(sessionId); + await this.clearSessionUserAutoApproveTools(sessionId); return; } + const normalized = autoApproveTools.map((pattern) => this.normalizeToolPolicyPattern(pattern) ); - this.sessionUserAutoApproveTools.set(sessionId, normalized); + + await this.runWithSessionPreferenceLock(sessionId, async () => { + this.sessionUserAutoApproveTools.set(sessionId, normalized); + await this.persistSessionToolPreferences(sessionId); + }); + this.logger.info( `Session user auto-approve tools set for '${sessionId}': ${autoApproveTools.length} tools` ); @@ -366,9 +471,16 @@ export class ToolManager { /** * Clear session-level auto-approve tools chosen by the user. */ - clearSessionUserAutoApproveTools(sessionId: string): void { - const hadAutoApprove = this.sessionUserAutoApproveTools.has(sessionId); - this.sessionUserAutoApproveTools.delete(sessionId); + async clearSessionUserAutoApproveTools(sessionId: string): Promise { + await this.restoreSessionState(sessionId); + + let hadAutoApprove = false; + await this.runWithSessionPreferenceLock(sessionId, async () => { + hadAutoApprove = this.sessionUserAutoApproveTools.has(sessionId); + this.sessionUserAutoApproveTools.delete(sessionId); + await this.persistSessionToolPreferences(sessionId); + }); + if (hadAutoApprove) { this.logger.info(`Session user auto-approve tools cleared for '${sessionId}'`); } @@ -416,12 +528,18 @@ export class ToolManager { /** * Set session-level disabled tools (overrides global list). */ - setSessionDisabledTools(sessionId: string, toolNames: string[]): void { + async setSessionDisabledTools(sessionId: string, toolNames: string[]): Promise { + await this.restoreSessionState(sessionId); if (toolNames.length === 0) { - this.clearSessionDisabledTools(sessionId); + await this.clearSessionDisabledTools(sessionId); return; } - this.sessionDisabledTools.set(sessionId, [...toolNames]); + + await this.runWithSessionPreferenceLock(sessionId, async () => { + this.sessionDisabledTools.set(sessionId, [...toolNames]); + await this.persistSessionToolPreferences(sessionId); + }); + this.logger.info('Session disabled tools updated', { sessionId, count: toolNames.length, @@ -436,9 +554,16 @@ export class ToolManager { /** * Clear session-level disabled tools. */ - clearSessionDisabledTools(sessionId: string): void { - const hadOverrides = this.sessionDisabledTools.has(sessionId); - this.sessionDisabledTools.delete(sessionId); + async clearSessionDisabledTools(sessionId: string): Promise { + await this.restoreSessionState(sessionId); + + let hadOverrides = false; + await this.runWithSessionPreferenceLock(sessionId, async () => { + hadOverrides = this.sessionDisabledTools.has(sessionId); + this.sessionDisabledTools.delete(sessionId); + await this.persistSessionToolPreferences(sessionId); + }); + if (hadOverrides) { this.logger.info('Session disabled tools cleared', { sessionId }); } @@ -1019,7 +1144,7 @@ export class ToolManager { } if (!patternKey) return false; - return this.approvalManager.matchesPattern(toolName, patternKey); + return this.approvalManager.matchesPattern(toolName, patternKey, sessionId); }, { rememberPattern: undefined } // Don't propagate pattern choice to auto-approved requests ); @@ -1055,7 +1180,10 @@ export class ToolManager { return false; } - return this.approvalManager.isDirectorySessionApproved(directoryAccess.parentDir); + return this.approvalManager.isDirectorySessionApproved( + directoryAccess.parentDir, + sessionId + ); }, { rememberDirectory: false } ); @@ -1921,7 +2049,11 @@ export class ToolManager { // This bypasses remembered-tool approvals and edit-mode auto-approvals for outside-root paths. if (directoryAccess) { if (this.approvalMode === 'auto-approve') { - this.approvalManager.addApprovedDirectory(directoryAccess.parentDir, 'once'); + await this.approvalManager.addApprovedDirectory( + directoryAccess.parentDir, + 'once', + sessionId + ); return { requireApproval: false }; } return null; @@ -1953,7 +2085,7 @@ export class ToolManager { // 6. Check tool approval patterns const patternKey = this.getToolPatternKey(toolName, args); - if (patternKey && this.approvalManager.matchesPattern(toolName, patternKey)) { + if (patternKey && this.approvalManager.matchesPattern(toolName, patternKey, sessionId)) { this.logger.info( `Tool '${toolName}' matched approved pattern key '${patternKey}' – skipping confirmation.` ); @@ -2113,7 +2245,7 @@ export class ToolManager { ); this.autoApprovePendingToolRequests(toolName, allowSessionId); } else if (rememberPattern && this.getToolApprovalPatternKeyFn(toolName)) { - this.approvalManager.addPattern(toolName, rememberPattern); + await this.approvalManager.addPattern(toolName, rememberPattern, sessionId); this.logger.info(`Pattern '${rememberPattern}' added for tool '${toolName}' approval`); this.autoApprovePendingPatternRequests(toolName, sessionId); } else if (rememberDirectory) { diff --git a/packages/core/src/utils/service-initializer.ts b/packages/core/src/utils/service-initializer.ts index d191e574e..9109a875f 100644 --- a/packages/core/src/utils/service-initializer.ts +++ b/packages/core/src/utils/service-initializer.ts @@ -24,10 +24,13 @@ import type { AgentRuntimeSettings } from '../agent/runtime-config.js'; import { AgentEventBus } from '../events/index.js'; import { ResourceManager } from '../resources/manager.js'; import { ApprovalManager } from '../approval/manager.js'; +import { SessionApprovalStore } from '../approval/session-approval-store.js'; import { MemoryManager } from '../memory/index.js'; import { HookManager } from '../hooks/manager.js'; import type { Hook } from '../hooks/types.js'; import type { CompactionStrategy } from '../context/compaction/types.js'; +import { MessageQueueStore } from '../session/message-queue-store.js'; +import { SessionToolPreferencesStore } from '../tools/session-tool-preferences-store.js'; import type { LanguageModelFactory } from '../llm/services/types.js'; /** @@ -144,6 +147,17 @@ export async function createAgentServices( logger.debug('Storage manager initialized', await storageManager.getInfo()); + const sessionCacheTtlMs = config.sessions?.sessionTTL ?? 3600000; + const sessionApprovalStore = new SessionApprovalStore(storageManager, logger, { + cacheTtlMs: sessionCacheTtlMs, + }); + const sessionToolPreferencesStore = new SessionToolPreferencesStore(storageManager, logger, { + cacheTtlMs: sessionCacheTtlMs, + }); + const messageQueueStore = new MessageQueueStore(storageManager, logger, { + cacheTtlMs: sessionCacheTtlMs, + }); + // 2.5 Initialize workspace manager (uses persistent database) const workspaceManager = new WorkspaceManager( storageManager.getDatabase(), @@ -170,7 +184,8 @@ export async function createAgentServices( }), }, }, - logger + logger, + sessionApprovalStore ); logger.debug('Approval system initialized'); @@ -250,7 +265,8 @@ export async function createAgentServices( agentEventBus, config.permissions.toolPolicies, [], - logger + logger, + sessionToolPreferencesStore ); await toolManager.setWorkspaceManager(workspaceManager); // NOTE: local tools + ToolExecutionContext are wired in DextoAgent.start() @@ -280,11 +296,13 @@ export async function createAgentServices( stateManager, systemPromptManager, toolManager, + approvalManager, agentEventBus, storageManager, // Add storage manager to session services resourceManager, // Add resource manager for blob storage hookManager, // Add hook manager for hook execution mcpManager, // Add MCP manager for ChatSession + messageQueueStore, compactionStrategy: compactionStrategy ?? null, workspaceManager, // Workspace context propagation }, diff --git a/packages/tools-filesystem/src/directory-approval.integration.test.ts b/packages/tools-filesystem/src/directory-approval.integration.test.ts index 51fb8a1c1..d8a8e3207 100644 --- a/packages/tools-filesystem/src/directory-approval.integration.test.ts +++ b/packages/tools-filesystem/src/directory-approval.integration.test.ts @@ -17,6 +17,7 @@ import * as fs from 'node:fs/promises'; import * as os from 'node:os'; import { ApprovalManager, + ApprovalStatus, DextoRuntimeError, type Logger, type ToolExecutionContext, @@ -25,8 +26,10 @@ import { FileSystemService } from './filesystem-service.js'; import { createReadFileTool } from './read-file-tool.js'; import { createWriteFileTool } from './write-file-tool.js'; import { createEditFileTool } from './edit-file-tool.js'; +import { fileSystemToolsFactory } from './tool-factory.js'; type ToolServices = NonNullable; +type SessionApprovalStore = ConstructorParameters[2]; const createMockLogger = (): Logger => { const noopAsync = async () => undefined; @@ -49,9 +52,47 @@ const createMockLogger = (): Logger => { return logger; }; -function createToolContext(logger: Logger, approval: ApprovalManager): ToolExecutionContext { +function createInMemorySessionApprovalStore(): SessionApprovalStore { + const states = new Map< + string, + { + toolPatterns: Record; + approvedDirectories: Array<{ path: string; type: 'session' | 'once' }>; + } + >(); + + return { + async load(sessionId?: string) { + return structuredClone( + states.get(sessionId ?? '__global__') ?? { + toolPatterns: {}, + approvedDirectories: [], + } + ); + }, + async save( + sessionId: string | undefined, + state: { + toolPatterns: Record; + approvedDirectories: Array<{ path: string; type: 'session' | 'once' }>; + } + ) { + states.set(sessionId ?? '__global__', structuredClone(state)); + }, + async delete(sessionId?: string) { + states.delete(sessionId ?? '__global__'); + }, + } as SessionApprovalStore; +} + +function createToolContext( + logger: Logger, + approval: ApprovalManager, + sessionId?: string +): ToolExecutionContext { return { logger, + ...(sessionId !== undefined ? { sessionId } : {}), services: { approval, search: {} as unknown as ToolServices['search'], @@ -98,7 +139,8 @@ describe('Directory Approval Integration Tests', () => { permissions: { mode: 'manual' }, elicitation: { enabled: true }, }, - mockLogger + mockLogger, + createInMemorySessionApprovalStore() ); toolContext = createToolContext(mockLogger, approvalManager); @@ -155,7 +197,7 @@ describe('Directory Approval Integration Tests', () => { }); it('should return null when external path is session-approved', async () => { - approvalManager.addApprovedDirectory('/external/project', 'session'); + await approvalManager.addApprovedDirectory('/external/project', 'session'); const tool = createReadFileTool(getFileSystemService); const overrideFn = tool.approval?.override; @@ -170,7 +212,7 @@ describe('Directory Approval Integration Tests', () => { }); it('should still return metadata when external path is once-approved (prompt again)', async () => { - approvalManager.addApprovedDirectory('/external/project', 'once'); + await approvalManager.addApprovedDirectory('/external/project', 'once'); const tool = createReadFileTool(getFileSystemService); const overrideFn = tool.approval?.override; @@ -183,6 +225,47 @@ describe('Directory Approval Integration Tests', () => { ); expect(metadata).not.toBeNull(); }); + + it('should remember directory approvals only for the granting session', async () => { + const tool = createReadFileTool(getFileSystemService); + const overrideFn = tool.approval?.override; + const onGrantedFn = tool.approval?.onGranted; + expect(overrideFn).toBeDefined(); + expect(onGrantedFn).toBeDefined(); + + const externalPath = '/external/project/file.ts'; + const sessionAContext = createToolContext(mockLogger, approvalManager, 'session-a'); + const sessionBContext = createToolContext(mockLogger, approvalManager, 'session-b'); + + const approvalRequest = await overrideFn!( + tool.inputSchema.parse({ file_path: externalPath }), + sessionAContext + ); + expect(approvalRequest).not.toBeNull(); + + await onGrantedFn!( + { + approvalId: 'approval-1', + status: ApprovalStatus.APPROVED, + data: { rememberDirectory: true }, + }, + sessionAContext, + approvalRequest! + ); + + expect( + await overrideFn!( + tool.inputSchema.parse({ file_path: externalPath }), + sessionAContext + ) + ).toBeNull(); + expect( + await overrideFn!( + tool.inputSchema.parse({ file_path: externalPath }), + sessionBContext + ) + ).not.toBeNull(); + }); }); describe('Different tool operations', () => { @@ -242,7 +325,7 @@ describe('Directory Approval Integration Tests', () => { const tool = createReadFileTool(getFileSystemService); const overrideFn = tool.approval?.override; expect(overrideFn).toBeDefined(); - approvalManager.addApprovedDirectory('/external/project', 'session'); + await approvalManager.addApprovedDirectory('/external/project', 'session'); const metadata1 = await overrideFn!( tool.inputSchema.parse({ file_path: '/external/project/file.ts' }), @@ -261,7 +344,7 @@ describe('Directory Approval Integration Tests', () => { const tool = createReadFileTool(getFileSystemService); const overrideFn = tool.approval?.override; expect(overrideFn).toBeDefined(); - approvalManager.addApprovedDirectory('/external/sub', 'session'); + await approvalManager.addApprovedDirectory('/external/sub', 'session'); const metadata1 = await overrideFn!( tool.inputSchema.parse({ file_path: '/external/sub/file.ts' }), @@ -277,6 +360,194 @@ describe('Directory Approval Integration Tests', () => { }); }); + describe('Execution approval scoping', () => { + it('should allow execution only for the session that holds the approved directory', async () => { + const tools = fileSystemToolsFactory.create({ + type: 'filesystem-tools', + allowedPaths: [tempDir], + blockedPaths: [], + blockedExtensions: [], + maxFileSize: 10 * 1024 * 1024, + workingDirectory: tempDir, + enableBackups: false, + backupRetentionDays: 7, + }); + const writeTool = tools.find((tool) => tool.id === 'write_file'); + expect(writeTool).toBeDefined(); + + const externalDir = await fs.mkdtemp(path.join(os.tmpdir(), 'dexto-fs-external-')); + try { + const externalFile = path.join(externalDir, 'approved.txt'); + + await approvalManager.addApprovedDirectory(externalDir, 'once', 'session-a'); + + await expect( + writeTool!.execute!( + writeTool!.inputSchema.parse({ + file_path: externalFile, + content: 'session-scoped write', + }), + createToolContext(mockLogger, approvalManager, 'session-a') + ) + ).resolves.toEqual( + expect.objectContaining({ + success: true, + path: path.resolve(externalFile), + }) + ); + + await expect(fs.readFile(externalFile, 'utf8')).resolves.toBe( + 'session-scoped write' + ); + + await expect( + writeTool!.execute!( + writeTool!.inputSchema.parse({ + file_path: path.join(externalDir, 'blocked.txt'), + content: 'should fail', + }), + createToolContext(mockLogger, approvalManager, 'session-b') + ) + ).rejects.toBeInstanceOf(DextoRuntimeError); + } finally { + await fs.rm(externalDir, { recursive: true, force: true }); + } + }); + }); + + describe('Factory service scoping', () => { + it('should keep working directories isolated across concurrent executions', async () => { + const workspaceA = path.join(tempDir, 'workspace-a'); + const workspaceB = path.join(tempDir, 'workspace-b'); + await fs.mkdir(workspaceA, { recursive: true }); + await fs.mkdir(workspaceB, { recursive: true }); + await fs.writeFile(path.join(workspaceA, 'same.txt'), 'from workspace A'); + await fs.writeFile(path.join(workspaceB, 'same.txt'), 'from workspace B'); + + const tools = fileSystemToolsFactory.create({ + type: 'filesystem-tools', + allowedPaths: [workspaceA, workspaceB], + blockedPaths: [], + blockedExtensions: [], + maxFileSize: 10 * 1024 * 1024, + workingDirectory: tempDir, + enableBackups: false, + backupRetentionDays: 7, + enabledTools: ['read_file'], + }); + const readTool = tools.find((tool) => tool.id === 'read_file'); + expect(readTool).toBeDefined(); + + const baseContext = createToolContext(mockLogger, approvalManager); + const contextA: ToolExecutionContext = { + ...baseContext, + sessionId: 'session-a', + workspace: { + id: 'workspace-a', + path: workspaceA, + createdAt: Date.now(), + lastActiveAt: Date.now(), + }, + }; + const contextB: ToolExecutionContext = { + ...baseContext, + sessionId: 'session-b', + workspace: { + id: 'workspace-b', + path: workspaceB, + createdAt: Date.now(), + lastActiveAt: Date.now(), + }, + }; + + const [resultA, resultB] = await Promise.all([ + readTool!.execute!( + readTool!.inputSchema.parse({ file_path: 'same.txt' }), + contextA + ), + readTool!.execute!( + readTool!.inputSchema.parse({ file_path: 'same.txt' }), + contextB + ), + ]); + + expect(resultA).toEqual( + expect.objectContaining({ + content: 'from workspace A', + }) + ); + expect(resultB).toEqual( + expect.objectContaining({ + content: 'from workspace B', + }) + ); + }); + + it('should not mutate an injected filesystem service with session-specific state', async () => { + const workspace = path.join(tempDir, 'workspace-injected'); + await fs.mkdir(workspace, { recursive: true }); + await fs.writeFile(path.join(workspace, 'same.txt'), 'from injected config'); + + const tools = fileSystemToolsFactory.create({ + type: 'filesystem-tools', + allowedPaths: [workspace], + blockedPaths: [], + blockedExtensions: [], + maxFileSize: 10 * 1024 * 1024, + workingDirectory: tempDir, + enableBackups: false, + backupRetentionDays: 7, + enabledTools: ['read_file'], + }); + const readTool = tools.find((tool) => tool.id === 'read_file'); + expect(readTool).toBeDefined(); + + const injectedService = { + getConfig: vi.fn().mockReturnValue({ + allowedPaths: [workspace], + blockedPaths: [], + blockedExtensions: [], + maxFileSize: 10 * 1024 * 1024, + workingDirectory: workspace, + enableBackups: false, + backupRetentionDays: 7, + }), + readFile: vi.fn(), + writeFile: vi.fn(), + setWorkingDirectory: vi.fn(), + setDirectoryApprovalChecker: vi.fn(), + }; + + const baseContext = createToolContext(mockLogger, approvalManager, 'session-a'); + const context: ToolExecutionContext = { + ...baseContext, + workspace: { + id: 'workspace-injected', + path: workspace, + createdAt: Date.now(), + lastActiveAt: Date.now(), + }, + services: { + ...baseContext.services, + filesystemService: injectedService, + } as ToolExecutionContext['services'], + }; + + const result = await readTool!.execute!( + readTool!.inputSchema.parse({ file_path: 'same.txt' }), + context + ); + + expect(result).toEqual( + expect.objectContaining({ + content: 'from injected config', + }) + ); + expect(injectedService.setWorkingDirectory).not.toHaveBeenCalled(); + expect(injectedService.setDirectoryApprovalChecker).not.toHaveBeenCalled(); + }); + }); + describe('Without ApprovalManager in context', () => { it('should throw for external paths', async () => { const tool = createReadFileTool(getFileSystemService); diff --git a/packages/tools-filesystem/src/directory-approval.ts b/packages/tools-filesystem/src/directory-approval.ts index 42ac6283c..b4b2b8397 100644 --- a/packages/tools-filesystem/src/directory-approval.ts +++ b/packages/tools-filesystem/src/directory-approval.ts @@ -63,7 +63,7 @@ export function createDirectoryAccessApprovalHandlers = { backupRetentionDays: config.backupRetentionDays, }; - let fileSystemService: FileSystemService | undefined; - const resolveWorkingDirectory = (context: ToolExecutionContext): string => context.workspace?.path ?? fileSystemConfig.workingDirectory ?? process.cwd(); - const applyWorkspace = (context: ToolExecutionContext, service: FileSystemService) => { - const workingDirectory = resolveWorkingDirectory(context); - service.setWorkingDirectory(workingDirectory); + const createScopedFileSystemService = ( + context: ToolExecutionContext, + baseConfig: FileSystemConfig + ): FileSystemService => { + const approvalManager = context.services?.approval; + if (!approvalManager) { + throw ToolError.configInvalid( + 'filesystem-tools requires ToolExecutionContext.services.approval' + ); + } + + const service = new FileSystemService( + { + ...baseConfig, + workingDirectory: resolveWorkingDirectory(context), + }, + context.logger + ); + service.setDirectoryApprovalChecker((filePath: string) => + approvalManager.isDirectoryApproved(filePath, context.sessionId) + ); + return service; }; - const resolveInjectedService = ( + const resolveInjectedServiceConfig = ( context: ToolExecutionContext - ): FileSystemService | null => { + ): FileSystemConfig | null => { const candidate = (context.services as unknown as { filesystemService?: unknown }) - ?.filesystemService as FileSystemService | undefined; + ?.filesystemService; if (!candidate) return null; - if (candidate instanceof FileSystemService) return candidate; - const hasMethods = - typeof (candidate as FileSystemService).readFile === 'function' && - typeof (candidate as FileSystemService).writeFile === 'function' && - typeof (candidate as FileSystemService).setWorkingDirectory === 'function' && - typeof (candidate as FileSystemService).setDirectoryApprovalChecker === 'function'; - return hasMethods ? (candidate as FileSystemService) : null; + if (candidate instanceof FileSystemService) return candidate.getConfig(); + + const getConfig = (candidate as { getConfig?: unknown }).getConfig; + if (typeof getConfig === 'function') { + return getConfig.call(candidate) as FileSystemConfig; + } + + return null; }; const getFileSystemService = async ( context: ToolExecutionContext ): Promise => { - const injectedService = resolveInjectedService(context); - if (injectedService) { - const approvalManager = context.services?.approval; - if (!approvalManager) { - throw ToolError.configInvalid( - 'filesystem-tools requires ToolExecutionContext.services.approval' - ); - } - injectedService.setDirectoryApprovalChecker((filePath: string) => - approvalManager.isDirectoryApproved(filePath) - ); - applyWorkspace(context, injectedService); - return injectedService; - } - - if (fileSystemService) { - const approvalManager = context.services?.approval; - if (!approvalManager) { - throw ToolError.configInvalid( - 'filesystem-tools requires ToolExecutionContext.services.approval' - ); - } - fileSystemService.setDirectoryApprovalChecker((filePath: string) => - approvalManager.isDirectoryApproved(filePath) - ); - applyWorkspace(context, fileSystemService); - return fileSystemService; - } - - const logger = context.logger; - - fileSystemService = new FileSystemService(fileSystemConfig, logger); - - const approvalManager = context.services?.approval; - if (!approvalManager) { - throw ToolError.configInvalid( - 'filesystem-tools requires ToolExecutionContext.services.approval' - ); - } - fileSystemService.setDirectoryApprovalChecker((filePath: string) => - approvalManager.isDirectoryApproved(filePath) + const scopedFileSystemService = createScopedFileSystemService( + context, + resolveInjectedServiceConfig(context) ?? fileSystemConfig ); - applyWorkspace(context, fileSystemService); - - fileSystemService.initialize().catch((error) => { - const message = error instanceof Error ? error.message : String(error); - logger.error(`Failed to initialize FileSystemService: ${message}`); - }); - - return fileSystemService; + await scopedFileSystemService.initialize(); + return scopedFileSystemService; }; const toolCreators: Record Tool> = { diff --git a/packages/tui/src/components/overlays/ToolBrowser.tsx b/packages/tui/src/components/overlays/ToolBrowser.tsx index 22f18ac73..0ace83c75 100644 --- a/packages/tui/src/components/overlays/ToolBrowser.tsx +++ b/packages/tui/src/components/overlays/ToolBrowser.tsx @@ -293,7 +293,7 @@ const ToolBrowser = forwardRef(function Too .map((tool) => tool.name); if (sessionId) { - agent.setSessionAutoApproveTools(sessionId, autoApprovedTools); + await agent.setSessionAutoApproveTools(sessionId, autoApprovedTools); } const disabledTools = updatedTools @@ -301,7 +301,7 @@ const ToolBrowser = forwardRef(function Too .map((tool) => tool.name); if (effectiveTarget === 'session' && sessionId) { - agent.setSessionDisabledTools(sessionId, disabledTools); + await agent.setSessionDisabledTools(sessionId, disabledTools); } else if (effectiveTarget === 'global') { try { const { updateAgentPreferences, saveAgentPreferences, agentPreferencesExist } = @@ -321,7 +321,7 @@ const ToolBrowser = forwardRef(function Too } catch (_error) { // If we can't persist, still keep session state so user sees effect if (sessionId) { - agent.setSessionDisabledTools(sessionId, disabledTools); + await agent.setSessionDisabledTools(sessionId, disabledTools); } else { setTools(previousTools); setSelectedTool( @@ -334,9 +334,11 @@ const ToolBrowser = forwardRef(function Too closeScopePrompt(); }; - const toggleAutoApprove = () => { + const toggleAutoApprove = async () => { if (!sessionId) return; + const previousTools = toolsRef.current; + const previousSelectedToolName = selectedToolRef.current?.name; const updatedTools = toolsRef.current.map((tool) => tool.name === selectedToolRef.current?.name ? { ...tool, isAutoApproved: !tool.isAutoApproved } @@ -354,7 +356,14 @@ const ToolBrowser = forwardRef(function Too .filter((tool) => tool.isAutoApproved) .map((tool) => tool.name); - agent.setSessionAutoApproveTools(sessionId, autoApprovedTools); + try { + await agent.setSessionAutoApproveTools(sessionId, autoApprovedTools); + } catch (_error) { + setTools(previousTools); + setSelectedTool( + previousTools.find((tool) => tool.name === previousSelectedToolName) ?? null + ); + } }; const closeConfigMenu = () => { @@ -430,7 +439,7 @@ const ToolBrowser = forwardRef(function Too openScopePrompt(tool); } else if (configIndexRef.current === 1) { if (sessionId && tool.isEnabled) { - toggleAutoApprove(); + void toggleAutoApprove(); } } else { closeConfigMenu();