From 19f4e3844640ba30df6d5bdbc83f9a16dc374569 Mon Sep 17 00:00:00 2001 From: "IM.codes" Date: Fri, 17 Apr 2026 18:34:32 +0800 Subject: [PATCH 001/151] fix(web): keep p2p discussion controls visible --- web/src/styles.css | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/web/src/styles.css b/web/src/styles.css index c368d6623..c24d33eb5 100644 --- a/web/src/styles.css +++ b/web/src/styles.css @@ -1513,6 +1513,7 @@ body { .discussions-nav-row { display: flex; align-items: center; + flex-wrap: wrap; gap: 8px; padding: 10px 16px; flex-shrink: 0; @@ -1524,8 +1525,10 @@ body { } .discussions-nav-controls { margin-left: auto; - display: inline-flex; + display: flex; align-items: center; + flex: 1 1 260px; + flex-wrap: wrap; justify-content: flex-end; gap: 10px; min-width: 0; @@ -1536,6 +1539,7 @@ body { display: inline-flex; align-items: center; gap: 8px; + flex: 1 1 180px; min-width: 0; color: #cbd5e1; font-size: 13px; @@ -1546,13 +1550,16 @@ body { accent-color: #38bdf8; } .discussions-follow-toggle span { - white-space: nowrap; + white-space: normal; + line-height: 1.2; } .discussions-scroll-arrows { display: flex; flex-direction: row; align-items: center; gap: 8px; + flex: 0 0 auto; + margin-left: auto; } .discussions-scroll-btn-floating { width: 34px; @@ -1592,11 +1599,10 @@ body { } .discussions-nav-row { align-items: flex-start; } .discussions-nav-controls { - flex-wrap: wrap; + flex-basis: 100%; row-gap: 8px; justify-content: flex-end; } - .discussions-follow-toggle span { white-space: normal; line-height: 1.2; } .discussions-scroll-btn-floating { width: 38px; height: 38px; } } From db1acef9afd91a07c3fc997e076d15c76321daf9 Mon Sep 17 00:00:00 2001 From: "IM.codes" Date: Fri, 17 Apr 2026 21:02:54 +0800 Subject: [PATCH 002/151] feat(memory): recall template filter + cap rule + per-session dedup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Template-prompt filter (recall-only): excludes 
built-in OpenSpec / P2P / slash-command / skill-template prompts from memory recall via the shared `isTemplatePrompt` / `isTemplateOriginSummary` predicates. Locale-aware across all 7 supported UI languages (en, zh-CN, zh-TW, es, ru, ja, ko) and covers every `openspec.*_prompt` + `p2p.*_prompt` built-in template, the `P2P_BASELINE_PROMPT`, `roundPrompt()` headers, harness `<command-name>` tags, and slash-command / plugin-namespaced skill invocations. Recall cap rule: `RECALL_MIN_FLOOR = 0.5`, `RECALL_DEFAULT_CAP = 3`, `RECALL_EXTEND_BAR = 0.6`, `RECALL_EXTEND_CAP = 5`. Drop below floor; take top 3; extend up to 5 iff every top-3 item clears 0.6 (each additional item must also clear 0.6). Applied at process `prependLocalMemory`, transport `buildTransportMessageRecall`, and server `POST /memory/recall`. Per-session de-dup: daemon-side LRU of 10 past injection events keyed by `sessionKey`; prevents re-injecting the same memory across consecutive turns of the same session. Cleared on `session.clear` (both transport and process paths) and on `TransportSessionRuntime.kill()`. Server endpoint does not apply this — it has no per-session context. Hit-count credit: only for items that actually entered the prompt (survived floor + LRU + cap). Items dropped upstream no longer receive a spaced-repetition credit. Intentional scope boundaries: - Ingestion / materialization is NOT filtered — template events remain part of the project's recorded history. - Startup bootstrap (`selectStartupMemoryItems`) is NOT filtered — it is a project-scoped memory load, not a query-driven recall. - CLI `imcodes memory` / WS `memory.search` / web UI browsing are NOT capped — they use client-supplied explicit limits. Tests: 158 added or updated (template patterns × 7 locales, recall cap rule, injection history LRU, server recall endpoint rewrites for new semantics, materialization coordinator reverse-pin asserting template content is still recorded). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- server/src/routes/shared-context.ts | 31 +- server/test/memory-recall.test.ts | 194 +++++++-- shared/memory-scoring.ts | 67 +++ shared/template-prompt-patterns.ts | 264 ++++++++++++ src/agent/transport-session-runtime.ts | 70 +++- src/context/recent-injection-history.ts | 115 ++++++ src/context/startup-memory.ts | 7 + src/daemon/command-handler.ts | 133 +++++- test/context/recent-injection-history.test.ts | 89 ++++ .../materialization-coordinator.test.ts | 19 + test/shared/recall-cap-rule.test.ts | 147 +++++++ test/shared/template-prompt-patterns.test.ts | 391 ++++++++++++++++++ 12 files changed, 1450 insertions(+), 77 deletions(-) create mode 100644 shared/template-prompt-patterns.ts create mode 100644 src/context/recent-injection-history.ts create mode 100644 test/context/recent-injection-history.test.ts create mode 100644 test/shared/recall-cap-rule.test.ts create mode 100644 test/shared/template-prompt-patterns.test.ts diff --git a/server/src/routes/shared-context.ts b/server/src/routes/shared-context.ts index 8016a518d..96e48bd50 100644 --- a/server/src/routes/shared-context.ts +++ b/server/src/routes/shared-context.ts @@ -7,7 +7,8 @@ import { parseRemoteUrl } from '../../../src/repo/detector.js'; import { parseCanonicalRepositoryKey } from '../../../src/agent/repository-identity-service.js'; import { classifyTimestampFreshness } from '../../../shared/context-freshness.js'; import type { ContextMemoryRecordView, ContextMemoryStatsView } from '../../../shared/context-types.js'; -import { computeRelevanceScore, type ProjectionClass } from '../../../shared/memory-scoring.js'; +import { computeRelevanceScore, applyRecallCapRule, type ProjectionClass } from '../../../shared/memory-scoring.js'; +import { isTemplatePrompt, isTemplateOriginSummary } from '../../../shared/template-prompt-patterns.js'; import { searchSemanticMemoryView } from '../util/semantic-memory-view.js'; type EnterpriseRole = 'owner' | 'admin' | 
'member'; @@ -915,6 +916,12 @@ sharedContextRoutes.post('/:id/shared-context/memory/recall', async (c) => { if (!query || typeof query !== 'string' || query.trim().length === 0) { return c.json({ error: 'query_required' }, 400); } + // Template-prompt skip: OpenSpec / slash-command / skill-template queries + // are not natural-language requests; a recall over them returns noise. + // See shared/template-prompt-patterns.ts. + if (isTemplatePrompt(query)) { + return c.json({ results: [], vectorSearch: false, skipped: 'template_prompt' }); + } const limit = typeof rawLimit === 'number' && rawLimit > 0 ? Math.min(rawLimit, 20) : 5; const candidateLimit = Math.max(limit * 4, 20); @@ -1015,13 +1022,16 @@ sharedContextRoutes.post('/:id/shared-context/memory/recall', async (c) => { ); } - // Merge, deduplicate by id, sort by composite relevance score + // Merge, deduplicate by id, sort by composite relevance score. + // Result-side template filter: legacy projections whose summary reflects + // a templated workflow origin must not leak back through recall. const seen = new Set(); const currentProjectId = projectId ?? 
'__unknown_current_project__'; const results: Array<{ id: string; projectId: string; class: string; summary: string; updatedAt: number; score: number; source: 'personal' | 'enterprise' }> = []; for (const row of personalRows) { if (seen.has(row.id)) continue; seen.add(row.id); + if (isTemplateOriginSummary(row.summary)) continue; results.push({ id: row.id, projectId: row.project_id, @@ -1042,6 +1052,7 @@ sharedContextRoutes.post('/:id/shared-context/memory/recall', async (c) => { for (const row of enterpriseRows) { if (seen.has(row.id)) continue; seen.add(row.id); + if (isTemplateOriginSummary(row.summary)) continue; results.push({ id: row.id, projectId: row.project_id, @@ -1061,10 +1072,20 @@ sharedContextRoutes.post('/:id/shared-context/memory/recall', async (c) => { source: 'enterprise', }); } - results.sort((a, b) => b.score - a.score); - const topResults = results.slice(0, limit); + // Cap rule: floor 0.5, top 3, extend to 5 iff all >= 0.6. + // See shared/memory-scoring.ts. The client-supplied `limit` is an upper + // bound on the extend cap — a client asking for <=3 shrinks defaultCap; + // a client asking for >=5 keeps the default extend cap. + const cappedDefault = Math.min(limit, 3); + const cappedExtend = Math.min(Math.max(limit, cappedDefault), 5); + const topResults = applyRecallCapRule(results, { + defaultCap: cappedDefault, + extendCap: cappedExtend, + }); - // Record hits for recalled projections (server-side spaced repetition) + // Record hits only for projections that actually survived the cap rule — + // items dropped by floor or session-side filtering never reached the + // user's prompt and should not receive a spaced-repetition credit. 
const hitIds = topResults.map((r) => r.id); if (hitIds.length > 0) { const now = Date.now(); diff --git a/server/test/memory-recall.test.ts b/server/test/memory-recall.test.ts index e78fa780f..d297004d0 100644 --- a/server/test/memory-recall.test.ts +++ b/server/test/memory-recall.test.ts @@ -172,40 +172,113 @@ describe('memory recall endpoint — I.5', () => { expect(json.error).toBe('invalid_json'); }); + it('returns empty with skipped:template_prompt when the query is a built-in template', async () => { + // Query-side filter: OpenSpec workflow prompts never hit the DB — the + // endpoint short-circuits with `skipped: 'template_prompt'`. + const { db, executeLog } = makeMockDb({ + personalRows: [ + { id: 'p1', project_id: 'proj', projection_class: 'recent_summary', summary: 'Irrelevant', updated_at: 1, score: 0.9 }, + ], + }); + const app = await buildTestApp(db); + + const res = await postRecall(app, { + query: 'Drive the implementation of openspec/changes/my-feature aggressively.', + }); + expect(res.status).toBe(200); + const json = await res.json() as { results: unknown[]; skipped?: string }; + expect(json.results).toEqual([]); + expect(json.skipped).toBe('template_prompt'); + // No query-side DB work and no hit_count update for skipped queries + const hit = executeLog.find((e) => e.sql.toLowerCase().includes('hit_count')); + expect(hit).toBeUndefined(); + }); + + it('short-circuits for localized template queries across supported languages', async () => { + const { db } = makeMockDb({ personalRows: [] }); + const app = await buildTestApp(db); + + const templates = [ + '强力推进 openspec/changes/foo 的实施。', + 'P2P 讨论已经完成。请直接落实原始请求。', + 'Проведи строгий аудит реализации.', + '厳格な実装監査を実施してください。', + '엄격한 구현 감사를 수행하세요.', + ]; + for (const q of templates) { + const res = await postRecall(app, { query: q }); + expect(res.status).toBe(200); + const json = await res.json() as { results: unknown[]; skipped?: string }; + expect(json.skipped).toBe('template_prompt'); + 
expect(json.results).toEqual([]); + } + }); + + it('drops template-origin rows from merged results even for a normal query', async () => { + const now = Date.now(); + const { db, executeLog } = makeMockDb({ + personalRows: [ + { id: 'ok-1', project_id: 'proj-a', projection_class: 'recent_summary', summary: '## Problem → Resolution: fixed retry', updated_at: now, score: 0.9 }, + { id: 'bad-1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'User orchestrated openspec/changes/feature-x via subagents.', updated_at: now, score: 0.85 }, + ], + enterpriseRows: [ + { id: 'bad-2', project_id: 'proj-b', projection_class: 'recent_summary', summary: 'Drive the implementation of change Y.', updated_at: now, score: 0.8, enterprise_id: 'ent-1' }, + ], + }); + const app = await buildTestApp(db); + + const res = await postRecall(app, { query: 'retry behavior', projectId: 'proj-a' }); + expect(res.status).toBe(200); + const json = await res.json() as { results: Array<{ id: string; summary: string }> }; + const ids = json.results.map((r) => r.id); + expect(ids).toContain('ok-1'); + expect(ids).not.toContain('bad-1'); + expect(ids).not.toContain('bad-2'); + // Hit-count update should reference only the surviving row + await new Promise((r) => setTimeout(r, 50)); + const hit = executeLog.find((e) => e.sql.toLowerCase().includes('hit_count = hit_count + 1')); + expect(hit).toBeDefined(); + expect(hit!.params).toContain('ok-1'); + expect(hit!.params).not.toContain('bad-1'); + expect(hit!.params).not.toContain('bad-2'); + }); + it('merges personal and enterprise results into a single response', async () => { + const now = Date.now(); const { db } = makeMockDb({ personalRows: [ - { id: 'p1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Personal memory A', updated_at: 1000, score: 0.9 }, - { id: 'p2', project_id: 'proj-a', projection_class: 'durable_memory_candidate', summary: 'Personal memory B', updated_at: 2000, score: 0.5 }, + { id: 'p1', 
project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Personal memory A', updated_at: now, score: 0.95 }, + { id: 'p2', project_id: 'proj-a', projection_class: 'durable_memory_candidate', summary: 'Personal memory B', updated_at: now, score: 0.85 }, ], enterpriseRows: [ - { id: 'e1', project_id: 'proj-b', projection_class: 'recent_summary', summary: 'Enterprise memory C', updated_at: 3000, score: 0.7, enterprise_id: 'ent-1' }, + { id: 'e1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Enterprise memory C', updated_at: now, score: 0.9, enterprise_id: 'ent-1' }, ], }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'memory test' }); + const res = await postRecall(app, { query: 'memory test', projectId: 'proj-a' }); expect(res.status).toBe(200); const json = await res.json() as { results: Array<{ id: string; source: string }> }; + // All 3 survive floor + cap (top 3, all well above 0.6 extend bar) expect(json.results).toHaveLength(3); - // Should contain both personal and enterprise const sources = json.results.map((r) => r.source); expect(sources).toContain('personal'); expect(sources).toContain('enterprise'); }); it('deduplicates results by id (personal wins over enterprise for same id)', async () => { + const now = Date.now(); const { db } = makeMockDb({ personalRows: [ - { id: 'shared-1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Personal version', updated_at: 1000, score: 0.8 }, + { id: 'shared-1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Personal version', updated_at: now, score: 0.85 }, ], enterpriseRows: [ - { id: 'shared-1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Enterprise version', updated_at: 2000, score: 0.9, enterprise_id: 'ent-1' }, + { id: 'shared-1', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Enterprise version', updated_at: now, score: 0.9, enterprise_id: 'ent-1' }, ], }); 
const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test' }); + const res = await postRecall(app, { query: 'test', projectId: 'proj-a' }); expect(res.status).toBe(200); const json = await res.json() as { results: Array<{ id: string; source: string; summary: string }> }; expect(json.results).toHaveLength(1); @@ -215,18 +288,19 @@ describe('memory recall endpoint — I.5', () => { }); it('sorts merged results by score descending', async () => { + const now = Date.now(); const { db } = makeMockDb({ personalRows: [ - { id: 'low', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Low score', updated_at: 1000, score: 0.3 }, - { id: 'high', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'High score', updated_at: 2000, score: 0.95 }, + { id: 'low', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Low score', updated_at: now, score: 0.75 }, + { id: 'high', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'High score', updated_at: now, score: 0.98 }, ], enterpriseRows: [ - { id: 'mid', project_id: 'proj-b', projection_class: 'recent_summary', summary: 'Mid score', updated_at: 3000, score: 0.6, enterprise_id: 'ent-1' }, + { id: 'mid', project_id: 'proj-a', projection_class: 'recent_summary', summary: 'Mid score', updated_at: now, score: 0.85, enterprise_id: 'ent-1' }, ], }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test' }); + const res = await postRecall(app, { query: 'test', projectId: 'proj-a' }); const json = await res.json() as { results: Array<{ id: string; score: number }> }; expect(json.results).toHaveLength(3); expect(json.results[0].id).toBe('high'); @@ -237,27 +311,29 @@ describe('memory recall endpoint — I.5', () => { expect(json.results[1].score).toBeGreaterThanOrEqual(json.results[2].score); }); - it('limits results to the requested count', async () => { + it('shrinks the default cap when client requests fewer than 3', 
async () => { + const now = Date.now(); const { db } = makeMockDb({ personalRows: [ - { id: 'p1', project_id: 'proj', projection_class: 'recent_summary', summary: 'A', updated_at: 1, score: 0.9 }, - { id: 'p2', project_id: 'proj', projection_class: 'recent_summary', summary: 'B', updated_at: 2, score: 0.8 }, - { id: 'p3', project_id: 'proj', projection_class: 'recent_summary', summary: 'C', updated_at: 3, score: 0.7 }, - { id: 'p4', project_id: 'proj', projection_class: 'recent_summary', summary: 'D', updated_at: 4, score: 0.6 }, - { id: 'p5', project_id: 'proj', projection_class: 'recent_summary', summary: 'E', updated_at: 5, score: 0.5 }, + { id: 'p1', project_id: 'proj', projection_class: 'recent_summary', summary: 'A', updated_at: now, score: 0.95 }, + { id: 'p2', project_id: 'proj', projection_class: 'recent_summary', summary: 'B', updated_at: now, score: 0.9 }, + { id: 'p3', project_id: 'proj', projection_class: 'recent_summary', summary: 'C', updated_at: now, score: 0.85 }, ], }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test', limit: 2 }); + const res = await postRecall(app, { query: 'test', limit: 2, projectId: 'proj' }); const json = await res.json() as { results: Array<{ id: string }> }; + // Client-supplied limit 2 shrinks defaultCap+extendCap to 2. expect(json.results).toHaveLength(2); - // Top 2 by score expect(json.results[0].id).toBe('p1'); expect(json.results[1].id).toBe('p2'); }); - it('defaults to limit 5 when not specified', async () => { + it('defaults to top 3 when no limit is specified', async () => { + // Under the recall cap rule, default behavior is 3 unless every top-3 + // item is above the extend bar (0.6 composite). 
+ const now = Date.now(); const rows: MockRow[] = []; for (let i = 0; i < 10; i++) { rows.push({ @@ -265,48 +341,75 @@ describe('memory recall endpoint — I.5', () => { project_id: 'proj', projection_class: 'recent_summary', summary: `Memory ${i}`, - updated_at: i, - score: 1 - i * 0.05, + updated_at: now, + score: 1 - i * 0.05, // 1.0, 0.95, 0.9, 0.85, 0.8, ... }); } const { db } = makeMockDb({ personalRows: rows }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test' }); + const res = await postRecall(app, { query: 'test', projectId: 'proj' }); const json = await res.json() as { results: Array<{ id: string }> }; + // All items are well above the extend bar → extend kicks in up to 5. expect(json.results).toHaveLength(5); + expect(json.results.map((r) => r.id)).toEqual(['p0', 'p1', 'p2', 'p3', 'p4']); }); - it('caps limit at 20 even if client requests more', async () => { - const rows: MockRow[] = []; - for (let i = 0; i < 25; i++) { - rows.push({ - id: `p${i}`, - project_id: 'proj', - projection_class: 'recent_summary', - summary: `Memory ${i}`, - updated_at: i, - score: 1 - i * 0.01, - }); - } + it('extends up to 5 only when every top-3 item is above the extend bar', async () => { + // Build a set where the top 3 include one at exactly 0.59 composite + // (below 0.6 extend bar) — extension must NOT kick in. 
+ const now = Date.now(); + const rows: MockRow[] = [ + { id: 'strong-1', project_id: 'proj', projection_class: 'recent_summary', summary: 'A', updated_at: now, score: 0.98 }, + { id: 'strong-2', project_id: 'proj', projection_class: 'recent_summary', summary: 'B', updated_at: now, score: 0.95 }, + // similarity 0.5 + project-boost 0.2 + recency ~0.225 → ~0.625 (borderline; we pick 0.35 to stay under) + { id: 'borderline', project_id: 'proj', projection_class: 'recent_summary', summary: 'C', updated_at: now, score: 0.35 }, + { id: 'extra-1', project_id: 'proj', projection_class: 'recent_summary', summary: 'D', updated_at: now, score: 0.9 }, + { id: 'extra-2', project_id: 'proj', projection_class: 'recent_summary', summary: 'E', updated_at: now, score: 0.88 }, + ]; const { db } = makeMockDb({ personalRows: rows }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test', limit: 100 }); - const json = await res.json() as { results: Array<{ id: string }> }; - expect(json.results).toHaveLength(20); + const res = await postRecall(app, { query: 'test', projectId: 'proj' }); + const json = await res.json() as { results: Array<{ id: string; score: number }> }; + // Top 3 by composite: strong-1, strong-2, extra-1 (all >= 0.6) → extend, + // then extra-2 (>= 0.6) → 4th, then borderline (< 0.6) → stop. + // So we get 4 results: strong-1, strong-2, extra-1, extra-2. + const ids = json.results.map((r) => r.id); + expect(ids).not.toContain('borderline'); + expect(ids).toContain('strong-1'); + expect(ids).toContain('strong-2'); + expect(ids).toContain('extra-1'); + }); + + it('drops rows that fail the 0.5 composite floor even for a normal query', async () => { + // Ancient timestamps + no project match → composite scores collapse + // below floor regardless of raw similarity. 
+ const { db } = makeMockDb({ + personalRows: [ + { id: 'old-1', project_id: 'unrelated', projection_class: 'recent_summary', summary: 'Old memory', updated_at: 1000, score: 0.9 }, + { id: 'old-2', project_id: 'unrelated', projection_class: 'recent_summary', summary: 'Another old memory', updated_at: 1000, score: 0.85 }, + ], + }); + const app = await buildTestApp(db); + + // No matching projectId → projectBoost = 0.1, old updated_at → recency ≈ 0 + const res = await postRecall(app, { query: 'test' }); + const json = await res.json() as { results: unknown[] }; + expect(json.results).toEqual([]); }); it('fires hit_count UPDATE for recalled projection ids', async () => { + const now = Date.now(); const { db, executeLog } = makeMockDb({ personalRows: [ - { id: 'hit-a', project_id: 'proj', projection_class: 'recent_summary', summary: 'A', updated_at: 1, score: 0.9 }, - { id: 'hit-b', project_id: 'proj', projection_class: 'recent_summary', summary: 'B', updated_at: 2, score: 0.8 }, + { id: 'hit-a', project_id: 'proj', projection_class: 'recent_summary', summary: 'A', updated_at: now, score: 0.9 }, + { id: 'hit-b', project_id: 'proj', projection_class: 'recent_summary', summary: 'B', updated_at: now, score: 0.85 }, ], }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test' }); + const res = await postRecall(app, { query: 'test', projectId: 'proj' }); expect(res.status).toBe(200); // The hit_count UPDATE is fire-and-forget (catch-ignored), but it should @@ -341,14 +444,15 @@ describe('memory recall endpoint — I.5', () => { }); it('returns correct shape for each result item', async () => { + const now = Date.now(); const { db } = makeMockDb({ personalRows: [ - { id: 'shape-1', project_id: 'my-proj', projection_class: 'durable_memory_candidate', summary: 'A durable memory', updated_at: 1700000000000, score: 0.75 }, + { id: 'shape-1', project_id: 'my-proj', projection_class: 'durable_memory_candidate', summary: 'A durable memory', 
updated_at: now, score: 0.9 }, ], }); const app = await buildTestApp(db); - const res = await postRecall(app, { query: 'test' }); + const res = await postRecall(app, { query: 'test', projectId: 'my-proj' }); const json = await res.json() as { results: Array> }; expect(json.results).toHaveLength(1); const item = json.results[0]; @@ -356,7 +460,7 @@ describe('memory recall endpoint — I.5', () => { expect(item).toHaveProperty('projectId', 'my-proj'); expect(item).toHaveProperty('class', 'durable_memory_candidate'); expect(item).toHaveProperty('summary', 'A durable memory'); - expect(item).toHaveProperty('updatedAt', 1700000000000); + expect(item).toHaveProperty('updatedAt', now); expect(typeof item.score).toBe('number'); expect(item).toHaveProperty('source', 'personal'); }); diff --git a/shared/memory-scoring.ts b/shared/memory-scoring.ts index 077ee92f1..e5cacf769 100644 --- a/shared/memory-scoring.ts +++ b/shared/memory-scoring.ts @@ -74,3 +74,70 @@ export function computeRelevanceScore(input: MemoryScoringInput): number { const project = computeProjectBoost(input); return W_SIMILARITY * input.similarity + W_RECENCY * recency + W_FREQUENCY * frequency + W_PROJECT * project; } + +// ── Recall cap rule ──────────────────────────────────────────────────────── +// +// Tuning rationale: +// - MIN_FLOOR = 0.5 → excludes matches that clear 0.4+ purely on +// project + recency without real semantic or frequency signal. +// A same-project, fresh, never-recalled item with similarity 0 still +// scores only 0.425 and will be correctly dropped. +// - DEFAULT_CAP = 3 → tight default; noise-resistant. +// - EXTEND_BAR = 0.6, EXTEND_CAP = 5 → if the top 3 are ALL strong, +// keep absorbing equally-strong items up to 5. Mediocre 4th items +// do not get promoted. 
+ +export const RECALL_MIN_FLOOR = 0.5; +export const RECALL_DEFAULT_CAP = 3; +export const RECALL_EXTEND_BAR = 0.6; +export const RECALL_EXTEND_CAP = 5; + +export interface RecallCapOptions { + minFloor?: number; + defaultCap?: number; + extendBar?: number; + extendCap?: number; +} + +/** + * Apply the recall cap rule to a list of scored candidates. + * + * Input SHOULD already be sorted by `score` descending; if not, this + * function sorts defensively without mutating the caller's array. + * + * Rule: + * 1. Drop anything with `score < minFloor` (default 0.5). + * 2. Take the first `defaultCap` (default 3). + * 3. If those `defaultCap` are ALL at or above `extendBar` (default 0.6), + * keep absorbing subsequent items that are also at or above `extendBar`, + * up to `extendCap` items total (default 5). + */ +export function applyRecallCapRule( + scored: readonly T[], + options: RecallCapOptions = {}, +): T[] { + const minFloor = options.minFloor ?? RECALL_MIN_FLOOR; + const defaultCap = options.defaultCap ?? RECALL_DEFAULT_CAP; + const extendBar = options.extendBar ?? RECALL_EXTEND_BAR; + const extendCap = options.extendCap ?? RECALL_EXTEND_CAP; + + // Defensive sort copy — callers that already sort pay only O(n) scan. 
+ const sorted = [...scored].sort((a, b) => b.score - a.score); + + const floored = sorted.filter((item) => item.score >= minFloor); + if (floored.length === 0) return []; + + const base = floored.slice(0, defaultCap); + if (base.length < defaultCap) return base; + + const allStrong = base.every((item) => item.score >= extendBar); + if (!allStrong) return base; + + const extended: T[] = [...base]; + for (let i = defaultCap; i < floored.length && extended.length < extendCap; i++) { + const candidate = floored[i]; + if (candidate.score < extendBar) break; + extended.push(candidate); + } + return extended; +} diff --git a/shared/template-prompt-patterns.ts b/shared/template-prompt-patterns.ts new file mode 100644 index 000000000..b9eba722a --- /dev/null +++ b/shared/template-prompt-patterns.ts @@ -0,0 +1,264 @@ +/** + * Template-prompt detection shared across daemon and server. + * + * IM.codes' shared-context memory system stages and materializes chat events + * into `recent_summary` / `durable_memory_candidate` projections that later + * feed back into `prependLocalMemory` (process agents), the transport recall + * step (Phase K), `selectStartupMemoryItems`, and the server + * `memory/recall` endpoint. + * + * That pipeline produces noise for built-in / templated prompts: + * - OpenSpec workflow invocations (`Drive the implementation of + * @openspec/changes/...`, archive/propose/apply/explore skills) + * - Slash-command / skill preambles (`/loop`, `/schedule`, `/review`, + * `claude-mem:*`, `opsx:*`, `openspec-*`, `update-config`, ...) + * - Harness-injected `` templates + * + * Memories derived from those prompts are irrelevant to later user work: + * cross-project OpenSpec references pollute recall hits for unrelated + * projects. This module is the single source of truth for detecting them + * at every ingestion and recall site. 
+ * + * Design goals: + * - Cheap: pure string/regex, no allocation beyond trimming + * - Conservative: a pattern must be a high-signal marker, not merely a + * keyword that could appear in normal prose + * - Shared: daemon (`src/context/*`, `src/daemon/*`, `src/agent/*`) and + * server (`server/src/routes/shared-context.ts`) import the same + * predicate so query-side and result-side filtering stay consistent + */ + +/** + * Raw user prompt or staged-event `content`. + * + * True when the text is obviously a templated workflow invocation — the kind + * of prompt whose resulting assistant turn should not become recallable + * memory, and whose text should not be used as a recall query. + */ +export function isTemplatePrompt(text: string | null | undefined): boolean { + if (!text || typeof text !== 'string') return false; + const trimmed = text.trim(); + if (trimmed.length === 0) return false; + + // OpenSpec change references — any `@openspec/changes/` or bare + // `openspec/changes/` path is a strong marker. The workflow skills + // (propose/apply/archive/explore) all emit these references. + if (/(^|[\s@/`"'])openspec\/changes\/[a-z0-9][\w./-]*/i.test(trimmed)) { + return true; + } + + // Harness-injected command invocation tags (Claude Code slash commands + // render as `foo` in the transcript). + if (/[^<]+<\/command-name>/i.test(trimmed)) { + return true; + } + if (/[^<]*<\/command-message>/i.test(trimmed)) { + return true; + } + if (/[^<]*<\/command-args>/i.test(trimmed)) { + return true; + } + + // OpenSpec + P2P workflow imperative phrases emitted by built-in skill + // preambles and quick-actions. Each is a high-signal anchor per language — + // see `web/src/i18n/locales/*.json` keys `openspec.*_prompt` and + // `p2p.*_prompt`, plus `shared/p2p-modes.ts` (`P2P_BASELINE_PROMPT`, + // `roundPrompt`). These MUST stay in sync with those templates across all + // 7 locales (en, zh-CN, zh-TW, es, ru, ja, ko). 
+ for (const marker of MULTILINGUAL_TEMPLATE_MARKERS) { + if (marker.test(trimmed)) return true; + } + + // Leading slash-command dispatch for well-known built-in skills. We only + // match the first token to avoid swallowing legitimate prose that happens + // to contain a slash path. + const firstToken = trimmed.split(/\s/, 1)[0] ?? ''; + if (SLASH_COMMAND_NAMES.has(firstToken.toLowerCase())) return true; + + // Plugin-namespaced skill invocations like `claude-mem:do`, `opsx:apply`. + if (/^(?:claude-mem|claude-hud|claude-api|opsx|openspec-[a-z-]+|update-config|less-permission-prompts|keybindings-help|simplify|statusline-setup|init|review|security-review|loop|schedule):/i.test(firstToken)) { + return true; + } + + return false; +} + +/** + * Processed projection `summary` text. + * + * True when a stored memory summary clearly originated from a templated + * prompt — e.g. summaries that mention orchestrating subagents for an + * OpenSpec change, archiving a change, or running a skill. This catches + * legacy projections written before ingestion-side filtering existed, and + * guards against any content that slipped through because the templated + * prompt leaked into the assistant's final message verbatim. + */ +export function isTemplateOriginSummary(summary: string | null | undefined): boolean { + if (!summary || typeof summary !== 'string') return false; + const trimmed = summary.trim(); + if (trimmed.length === 0) return false; + + // The OpenSpec change path is the most common and highest-signal leak. + if (/openspec\/changes\//i.test(trimmed)) return true; + + // Reuse the multilingual workflow anchors so legacy summaries written + // before ingestion-side filtering existed are also filtered at recall. + for (const marker of MULTILINGUAL_TEMPLATE_MARKERS) { + if (marker.test(trimmed)) return true; + } + + // Harness `` tag fragments sometimes survive into summary + // compression output. 
+ if (/||/i.test(trimmed)) return true; + + return false; +} + +/** + * Multilingual anchor regexes for every built-in prompt template IM.codes + * auto-sends on behalf of the user. Each marker is a short, distinctive + * substring chosen to not collide with ordinary prose in its language. + * + * Grouped by template for auditability; when a template is added or its + * wording changes in `web/src/i18n/locales/*.json`, update the matching + * group here. Add a test case in + * `test/shared/template-prompt-patterns.test.ts` for each new language. + */ +const MULTILINGUAL_TEMPLATE_MARKERS: readonly RegExp[] = [ + // ── openspec.implement_prompt ───────────────────────────────────────── + /\bDrive the implementation of\b/i, // en + /强力推进/, // zh-CN + /強力推進/, // zh-TW + /\bImpulsa con firmeza la implementación\b/i, // es + /Жестко доведи реализацию/i, // ru + /の実装を強力に前進させてください/, // ja + /구현을 강하게 밀어붙이세요/, // ko + + // ── openspec.audit_implementation_prompt ────────────────────────────── + /\bPerform a strict implementation audit\b/i, // en + /执行严格的实现审计/, // zh-CN + /執行嚴格的實作審計/, // zh-TW + /\bRealiza una auditoría estricta de la implementación\b/i, // es + /Проведи строгий аудит реализации/i, // ru + /厳格な実装監査を実施してください/, // ja + /엄격한 구현 감사를 수행하세요/, // ko + + // ── openspec.audit_spec_prompt ──────────────────────────────────────── + /\bPerform a strict specification audit\b/i, // en + /执行严格的规范审计/, // zh-CN + /執行嚴格的規格審計/, // zh-TW + /\bRealiza una auditoría estricta de la especificación\b/i, // es + /Проведи строгий аудит спецификации/i, // ru + /厳格な仕様監査を実施してください/, // ja + /엄격한 명세 감사를 수행하세요/, // ko + + // ── openspec.propose_from_discussion_prompt ─────────────────────────── + /\bGenerate an OpenSpec change from the recent discussion\b/i, // en + /根据最近的讨论生成一个 OpenSpec 变更/, // zh-CN + /根據最近的討論生成一個 OpenSpec 變更/, // zh-TW + /\bGenera un cambio de OpenSpec a partir de la discusión\b/i, // es + /Сгенерируй изменение OpenSpec на основе недавнего обсуждения/i, // ru + /直近の議論から 
OpenSpec 変更を生成してください/, // ja + /최근 논의를 바탕으로 OpenSpec 변경을 생성하세요/, // ko + + // ── openspec.propose_from_description_prompt ────────────────────────── + /\bGenerate an OpenSpec change from the description\b/i, // en + /根据下面的描述生成一个 OpenSpec 变更/, // zh-CN + /根據下面的描述生成一個 OpenSpec 變更/, // zh-TW + /\bGenera un cambio de OpenSpec a partir de la descripción\b/i, // es + /Сгенерируй изменение OpenSpec на основе описания/i, // ru + /OpenSpec 変更を生成してください/, // ja + /설명을 바탕으로 OpenSpec 변경을 생성하세요/, // ko + + // ── openspec.achieve_prompt ─────────────────────────────────────────── + /\busing the full OpenSpec workflow\b/i, // en + /按完整 OpenSpec 工作流/, // zh-CN + /依照完整 OpenSpec 工作流程/, // zh-TW + /\busando el flujo completo de OpenSpec\b/i, // es + /по полному процессу OpenSpec/i, // ru + /完全な OpenSpec ワークフロー/, // ja + /전체 OpenSpec 워크플로/, // ko + + // ── p2p.post_summary_execute_prompt ─────────────────────────────────── + /\bThe P2P discussion is complete\b/i, // en + /P2P 讨论已经完成/, // zh-CN + /P2P 討論已完成/, // zh-TW + /\bLa discusión P2P ha terminado\b/i, // es + /P2P-обсуждение завершено/i, // ru + /P2P議論は完了しました/, // ja + /P2P 토론이 완료되었습니다/, // ko + + // ── p2p.final_original_request_reminder ─────────────────────────────── + /\bAfter synthesizing the discussion\b/i, // en + /在完成讨论综合后/, // zh-CN + /在完成討論綜合後/, // zh-TW + /\bNo te quedes solo en el resumen de la discusión\b/i, // es + /Не ограничивайся только сводкой обсуждения/i, // ru + /議論の要約だけで終わらせず/, // ja + /토론 요약으로 끝내지 말고/, // ko + + // ── shared/p2p-modes.ts — P2P_BASELINE_PROMPT ───────────────────────── + /\bstaff-level engineer participating in a multi-agent\b/i, + + // ── shared/p2p-modes.ts — roundPrompt() output ──────────────────────── + /\[Round \d+\/\d+\b/, // round phase header + /\bProvide your initial analysis based on the original request\b/i, + /\bReview ALL previous rounds' findings above\b/i, + + // ── Generic explicit workflow phrases (non-locale-specific fallbacks) ─ + /\bArchive(?:s|d)? 
(?:a |the )?completed (?:OpenSpec )?change\b/i, + /\bPropose a new (?:OpenSpec )?change\b/i, + /\bImplement tasks from an? OpenSpec change\b/i, + /\bEnter explore mode\b/i, +]; + +/** + * First-token slash command names to treat as template invocations. + * Kept as a `Set` for O(1) membership checks. + */ +const SLASH_COMMAND_NAMES: ReadonlySet = new Set([ + '/loop', + '/schedule', + '/review', + '/security-review', + '/init', + '/doctor', + '/clear', + '/compact', + '/config', + '/model', + '/help', + '/status', + '/exit', + '/plan', + '/hooks', + '/mcp', + '/agents', + '/cost', + '/memory', + '/permissions', + '/rewind', + '/resume', + '/export', + '/statusline', + '/ide', + '/pr_comments', + '/upgrade', + '/output-style', + '/compactify', + '/bashes', + '/add-dir', + '/bug', + '/feedback', + '/release-notes', + '/vim', + '/migrate-installer', + '/install-github-app', +]); + +/** + * Exposed for tests that want to extend or audit the slash-command allowlist. + */ +export function listKnownSlashCommands(): readonly string[] { + return Array.from(SLASH_COMMAND_NAMES); +} diff --git a/src/agent/transport-session-runtime.ts b/src/agent/transport-session-runtime.ts index 781db54a9..3ab93d8b2 100644 --- a/src/agent/transport-session-runtime.ts +++ b/src/agent/transport-session-runtime.ts @@ -4,7 +4,9 @@ import { RUNTIME_TYPES } from './session-runtime.js'; import type { AgentStatus } from './detect.js'; import type { AgentMessage, MessageDelta } from '../../shared/agent-message.js'; import type { TransportProvider, ProviderError, SessionConfig, SessionInfoUpdate } from './transport-provider.js'; +import type { ApprovalRequest } from './transport-provider.js'; import type { TransportEffortLevel } from '../../shared/effort-levels.js'; +import type { TransportAttachment } from '../../shared/transport-attachments.js'; import { SharedContextDispatchError, dispatchSharedContextSend, @@ -20,6 +22,13 @@ import type { import { buildMemoryContextTimelinePayload } from 
'../daemon/memory-context-timeline.js'; import { timelineEmitter } from '../daemon/timeline-emitter.js'; import { searchLocalMemorySemantic, type MemorySearchResultItem } from '../context/memory-search.js'; +import { isTemplatePrompt, isTemplateOriginSummary } from '../../shared/template-prompt-patterns.js'; +import { applyRecallCapRule } from '../../shared/memory-scoring.js'; +import { + filterRecentlyInjected, + recordRecentInjection, + clearRecentInjectionHistory, +} from '../context/recent-injection-history.js'; import { resolveRuntimeAuthoredContext } from '../context/shared-context-runtime.js'; import { buildTransportStartupMemory, type TransportContextBootstrap } from './runtime-context-bootstrap.js'; import { recordMemoryHits } from '../store/context-store.js'; @@ -28,6 +37,7 @@ import logger from '../util/logger.js'; export interface PendingTransportMessage { clientMessageId: string; text: string; + attachments?: TransportAttachment[]; } /** @@ -85,6 +95,7 @@ export class TransportSessionRuntime implements SessionRuntime { /** Callback fired when pending messages are drained into a new turn. 
*/ private _onDrain?: (messages: PendingTransportMessage[], mergedMessage: string, count: number) => void; private _onSessionInfoChange?: (info: SessionInfoUpdate) => void; + private _onApprovalRequest?: (request: ApprovalRequest) => void; constructor( private readonly provider: TransportProvider, @@ -123,6 +134,12 @@ export class TransportSessionRuntime implements SessionRuntime { this._onSessionInfoChange?.(info); })] : []), ); + if (this.provider.onApprovalRequest) { + this.provider.onApprovalRequest((sid: string, req: ApprovalRequest) => { + if (sid !== this._providerSessionId) return; + this._onApprovalRequest?.(req); + }); + } } // ── Public API ────────────────────────────────────────────────────────────── @@ -134,6 +151,7 @@ export class TransportSessionRuntime implements SessionRuntime { set onDrain(cb: (messages: PendingTransportMessage[], mergedMessage: string, count: number) => void) { this._onDrain = cb; } /** Register a callback for provider session metadata updates. */ set onSessionInfoChange(cb: (info: SessionInfoUpdate) => void) { this._onSessionInfoChange = cb; } + set onApprovalRequest(cb: (request: ApprovalRequest) => void) { this._onApprovalRequest = cb; } /** Set providerSessionId directly (restore from store without initialize). */ setProviderSessionId(id: string): void { this._providerSessionId = id; } @@ -196,7 +214,7 @@ export class TransportSessionRuntime implements SessionRuntime { * * Returns 'sent' if dispatched immediately, 'queued' if enqueued. */ - send(message: string, clientMessageId?: string): 'sent' | 'queued' { + send(message: string, clientMessageId?: string, attachments?: TransportAttachment[]): 'sent' | 'queued' { if (!this._providerSessionId) { throw new Error('TransportSessionRuntime not initialized — call initialize() first'); } @@ -205,11 +223,12 @@ export class TransportSessionRuntime implements SessionRuntime { this._pendingMessages.push({ clientMessageId: clientMessageId ?? 
randomUUID(), text: message, + ...(attachments?.length ? { attachments } : {}), }); return 'queued'; } - this._dispatchTurn(message, clientMessageId); + this._dispatchTurn(message, clientMessageId, attachments); return 'sent'; } @@ -255,6 +274,9 @@ export class TransportSessionRuntime implements SessionRuntime { this._sending = false; this._activeTurn = null; this._pendingMessages = []; + // Per-session memory injection history is daemon-scoped to this session; + // a kill ends that scope. clear() is called on session.clear separately. + clearRecentInjectionHistory(this.sessionKey); } getHistory(): AgentMessage[] { return [...this._history]; } @@ -268,7 +290,7 @@ export class TransportSessionRuntime implements SessionRuntime { } /** Dispatch a single turn to the provider. Assumes _sending is false. */ - private _dispatchTurn(message: string, clientMessageId?: string): void { + private _dispatchTurn(message: string, clientMessageId?: string, attachments?: TransportAttachment[]): void { this._history.push({ id: randomUUID(), sessionId: this._providerSessionId!, @@ -312,6 +334,7 @@ export class TransportSessionRuntime implements SessionRuntime { userMessage: message, description: this._description, systemPrompt: this._systemPrompt, + attachments, namespace: this._contextNamespace, namespaceDiagnostics: this._contextNamespaceDiagnostics, remoteProcessedFreshness: this._contextRemoteProcessedFreshness, @@ -377,8 +400,13 @@ export class TransportSessionRuntime implements SessionRuntime { const messages = this._pendingMessages.splice(0); const merged = messages.map((entry) => entry.text).join('\n\n'); + const attachments = messages.flatMap((entry) => entry.attachments ?? []); this._onDrain?.(messages, merged, messages.length); - this._dispatchTurn(merged, messages.length === 1 ? messages[0]?.clientMessageId : undefined); + this._dispatchTurn( + merged, + messages.length === 1 ? messages[0]?.clientMessageId : undefined, + attachments.length > 0 ? 
attachments : undefined, + ); return true; } @@ -429,22 +457,40 @@ export class TransportSessionRuntime implements SessionRuntime { logger.debug({ sessionKey: this.sessionKey, length: trimmed.length }, 'transport message recall skipped: short message'); return null; } + if (isTemplatePrompt(trimmed)) { + logger.debug({ sessionKey: this.sessionKey }, 'transport message recall skipped: template prompt'); + return null; + } try { const query = trimmed.slice(0, 200); + // Broaden candidate pool — the cap rule trims to 3 (up to 5 if all + // results are strong). See shared/memory-scoring.ts. const result = await searchLocalMemorySemantic({ query, namespace: this._contextNamespace, currentEnterpriseId: this._contextNamespace?.enterpriseId, repo: this._contextNamespace?.projectId ?? this.resolveAuthoredContextRepository(), - limit: 5, + limit: 10, }); - const items = result.items + // 1) Template-origin legacy summaries never surface through recall. + const processed = result.items .filter((item): item is MemorySearchResultItem => item.type === 'processed') - .map(toTransportMemoryRecallItem); + .filter((item) => !isTemplateOriginSummary(item.summary)); + // 2) Per-session dedup: skip items injected in this session's last + // 10 turns. Cleared on session.clear. + const procIds = processed.map((item) => item.id); + const keepIds = new Set(filterRecentlyInjected(this.sessionKey, procIds)); + const deduped = processed.filter((item) => keepIds.has(item.id)); + // 3) Cap rule: floor 0.5, top 3, extend to 5 iff all >= 0.6. + const scored = deduped.map((item) => ({ item, score: item.relevanceScore ?? 0 })); + const finalScored = applyRecallCapRule(scored); + const items = finalScored.map((s) => toTransportMemoryRecallItem(s.item)); if (items.length === 0) { logger.debug({ sessionKey: this.sessionKey, query }, 'transport message recall skipped: no processed matches'); return null; } + // 4) Record injection into the per-session ring buffer. 
+ recordRecentInjection(this.sessionKey, items.map((it) => it.id)); const supportClass = this.provider.capabilities.contextSupport ?? 'full-normalized-context-injection'; const injectionSurface = supportClass === 'full-normalized-context-injection' ? 'normalized-payload' @@ -509,6 +555,16 @@ export class TransportSessionRuntime implements SessionRuntime { { source: 'daemon', confidence: 'high' }, ); } + + async respondApproval(requestId: string, approved: boolean): Promise { + if (!this._providerSessionId) { + throw new Error('TransportSessionRuntime not initialized — call initialize() first'); + } + if (!this.provider.respondApproval) { + throw new Error(`Provider ${this.provider.id} does not support approval responses`); + } + await this.provider.respondApproval(this._providerSessionId, requestId, approved); + } } function toTransportMemoryRecallItem(item: MemorySearchResultItem): TransportMemoryRecallItem { diff --git a/src/context/recent-injection-history.ts b/src/context/recent-injection-history.ts new file mode 100644 index 000000000..bd3cf1030 --- /dev/null +++ b/src/context/recent-injection-history.ts @@ -0,0 +1,115 @@ +/** + * Per-session recent-injection history. + * + * Purpose: prevent the same memory items from being re-injected into prompts + * on consecutive turns of the same session. Once a memory has been included + * in a recall-injected prompt, it becomes low-value to inject again in the + * immediate follow-up turns — the model already saw it, and repeating it + * is noise. + * + * Scope: + * - Per session (keyed by `sessionKey` — e.g. `deck__`). + * - Daemon-only, in-memory. Cleared on session `clear` and on daemon + * restart (a restart is effectively a clear from the user's POV). + * - Does NOT apply to startup bootstrap (which is project-scoped memory + * load, not a query-driven recall) or to server-side recall endpoint + * (no per-session context). 
+ * + * Semantics: + * - "Last 10 turns" = the last 10 successful injection events, where + * each event carries the set of memory IDs that were injected on + * that turn. Unit is "turn", not "memory id": 1 event with 5 ids + * consumes 1 slot, not 5. + * - A candidate is considered "already injected recently" if its id + * appears in ANY of the retained injection events for this session. + * - The history is a ring buffer: recording the 11th event evicts + * the oldest. + */ + +const HISTORY_SIZE = 10; + +/** + * One past injection turn — the set of memory IDs that entered the prompt + * on that turn. + */ +type InjectionEvent = ReadonlySet; + +/** + * Keyed by `sessionKey`. Each value is an array of up to `HISTORY_SIZE` + * injection events, most recent first. + */ +const sessionHistory: Map = new Map(); + +/** + * Drop `memoryIds` that appear in any of the last `HISTORY_SIZE` injection + * events for this session. Returns a new array; does not mutate input. + * + * When `sessionKey` is falsy (e.g. anonymous WS lookup), no dedup is + * performed and all ids pass through. + */ +export function filterRecentlyInjected( + sessionKey: string | undefined, + memoryIds: readonly string[], +): string[] { + if (!sessionKey) return [...memoryIds]; + const events = sessionHistory.get(sessionKey); + if (!events || events.length === 0) return [...memoryIds]; + const seen = new Set(); + for (const ev of events) for (const id of ev) seen.add(id); + return memoryIds.filter((id) => !seen.has(id)); +} + +/** + * Record that `memoryIds` were injected into this session's prompt on the + * current turn. Pushes a new event onto the ring buffer; evicts the oldest + * event when the buffer exceeds `HISTORY_SIZE`. + * + * Empty id lists are ignored (no event recorded) — we don't want the ring + * buffer filled with no-op turns. 
+ */ +export function recordRecentInjection( + sessionKey: string | undefined, + memoryIds: readonly string[], +): void { + if (!sessionKey) return; + if (memoryIds.length === 0) return; + const event: InjectionEvent = new Set(memoryIds); + const existing = sessionHistory.get(sessionKey) ?? []; + // Most-recent-first ordering — unshift then trim. + existing.unshift(event); + if (existing.length > HISTORY_SIZE) existing.length = HISTORY_SIZE; + sessionHistory.set(sessionKey, existing); +} + +/** + * Clear all injection history for this session. Called from session + * `clear` / fresh-conversation paths. + */ +export function clearRecentInjectionHistory(sessionKey: string | undefined): void { + if (!sessionKey) return; + sessionHistory.delete(sessionKey); +} + +/** + * Drop all session histories. Mainly for tests. + */ +export function resetAllRecentInjectionHistories(): void { + sessionHistory.clear(); +} + +/** + * Snapshot the current history for inspection/testing. Returns a copy. + */ +export function getRecentInjectionHistory( + sessionKey: string | undefined, +): readonly (readonly string[])[] { + if (!sessionKey) return []; + const events = sessionHistory.get(sessionKey); + if (!events) return []; + return events.map((ev) => Array.from(ev)); +} + +/** + * Exposed for tests that want to assert the ring-buffer bound. + */ +export const RECENT_INJECTION_HISTORY_SIZE = HISTORY_SIZE; diff --git a/src/context/startup-memory.ts b/src/context/startup-memory.ts index 17d2e1aa0..9727b9dee 100644 --- a/src/context/startup-memory.ts +++ b/src/context/startup-memory.ts @@ -20,6 +20,13 @@ export function selectStartupMemoryItems( const recentLimit = options.recentLimit ?? STARTUP_MEMORY_RECENT_LIMIT; const totalLimit = options.totalLimit ?? STARTUP_MEMORY_TOTAL_LIMIT; + // Startup bootstrap is project-scoped memory loading, NOT a query-driven + // recall. 
Any memory that belongs to the project's timeline is valid + // context for session startup, including entries whose source turn was a + // templated workflow prompt — the user still worked on this project and + // the resulting summary is part of the project's history. Template-prompt + // filtering is applied only on the recall/search paths. + const durable = searchLocalMemory({ namespace, projectionClass: 'durable_memory_candidate', diff --git a/src/daemon/command-handler.ts b/src/daemon/command-handler.ts index 778673231..4729f97ff 100644 --- a/src/daemon/command-handler.ts +++ b/src/daemon/command-handler.ts @@ -45,6 +45,13 @@ import { buildWindowsCleanupScript, buildWindowsCleanupVbs, buildWindowsUpgradeB import { UPGRADE_LOCK_FILE, encodeVbsAsUtf16, encodeCmdAsUtf8Bom } from '../util/windows-launch-artifacts.js'; import { registerTempFile, removeTrackedTempFile } from '../store/temp-file-store.js'; import { sanitizeProjectName } from '../../shared/sanitize-project-name.js'; +import { isTemplatePrompt, isTemplateOriginSummary } from '../../shared/template-prompt-patterns.js'; +import { applyRecallCapRule } from '../../shared/memory-scoring.js'; +import { + filterRecentlyInjected, + recordRecentInjection, + clearRecentInjectionHistory, +} from '../context/recent-injection-history.js'; import { CODEX_MODEL_IDS, normalizeClaudeCodeModelId } from '../shared/models/options.js'; import { getClaudeSdkRuntimeConfig, normalizeClaudeSdkModelForProvider } from '../agent/sdk-runtime-config.js'; import { getCodexRuntimeConfig } from '../agent/codex-runtime-config.js'; @@ -55,6 +62,7 @@ import { DAEMON_COMMAND_TYPES } from '../../shared/daemon-command-types.js'; import { CLAUDE_SDK_EFFORT_LEVELS, CODEX_SDK_EFFORT_LEVELS, + COPILOT_SDK_EFFORT_LEVELS, DEFAULT_TRANSPORT_EFFORT, OPENCLAW_THINKING_LEVELS, QWEN_EFFORT_LEVELS, @@ -220,12 +228,21 @@ async function handleSubSessionTransportConfigUpdate(cmd: Record, serverLink: ServerLink) try { serverLink.send({ type: 
'session.error', project, message }); } catch { /* ignore */ } return; } - if (agentType === 'claude-code-sdk' || agentType === 'codex-sdk') { + if (agentType === 'claude-code-sdk' || agentType === 'codex-sdk' || agentType === 'copilot-sdk' || agentType === 'cursor-headless') { logger.info({ project, agentType }, 'SDK fresh session.start removing stale main-session store record'); removeSession(`deck_${project}_brain`); } @@ -1118,6 +1141,18 @@ async function handleStart(cmd: Record, serverLink: ServerLink) label, effort, }); + } else if (agentType === 'copilot-sdk' || agentType === 'cursor-headless') { + logger.info({ project, agentType }, 'SDK fresh session.start launching new transport main session'); + await launchTransportSession({ + name: `deck_${project}_brain`, + projectName: project, + role: 'brain', + agentType: agentType as 'copilot-sdk' | 'cursor-headless', + projectDir: dir, + fresh: true, + label, + effort, + }); } else { await startProject(config); } @@ -1531,6 +1566,7 @@ async function handleSend(cmd: Record, serverLink: ServerLink): // Transport sessions — route directly to the provider runtime, bypassing tmux. const transportRuntime = getTransportRuntime(sessionName); const record = (await import('../store/session-store.js')).getSession(sessionName); + const attachments: TransportAttachment[] = []; const transportUserEventId = (clientMessageId: string) => `transport-user:${clientMessageId}`; const emitTransportUserMessage = (payloadText: string, extra?: Record, eventId?: string) => { timelineEmitter.emit( @@ -1593,6 +1629,9 @@ async function handleSend(cmd: Record, serverLink: ServerLink): await runExclusiveSessionRelaunch(sessionName, async () => { await relaunchFreshTransportConversation(record); }); + // Reset per-session memory injection history — fresh conversation + // should be allowed to re-inject previously-shown memories again. 
+ clearRecentInjectionHistory(sessionName); await handleGetSessions(serverLink); await syncSubSessionIfNeeded(sessionName, serverLink); timelineEmitter.emit(sessionName, 'assistant.text', { @@ -1785,7 +1824,9 @@ async function handleSend(cmd: Record, serverLink: ServerLink): // send() is synchronous: dispatches immediately if idle, queues if busy. // Status changes come from transport runtime's onStatusChange callback. - const result = transportRuntime.send(text, effectiveId); + const result = attachments.length > 0 + ? transportRuntime.send(text, effectiveId, attachments) + : transportRuntime.send(text, effectiveId); if (shouldTrackSupervisionTaskRun) { if (result === 'queued') { supervisionAutomation.queueTaskIntent(sessionName, effectiveId, text, supervisionSnapshot); @@ -1796,7 +1837,10 @@ async function handleSend(cmd: Record, serverLink: ServerLink): if (result === 'sent') { emitTransportUserMessage( text, - { clientMessageId: effectiveId }, + { + clientMessageId: effectiveId, + ...(attachments.length > 0 ? { attachments } : {}), + }, transportUserEventId(effectiveId), ); } @@ -1838,6 +1882,9 @@ async function handleSend(cmd: Record, serverLink: ServerLink): await runExclusiveSessionRelaunch(sessionName, async () => { await relaunchSessionWithSettings(record, { fresh: true }); }); + // Reset per-session memory injection history — fresh conversation + // should be allowed to re-inject previously-shown memories again. 
+ clearRecentInjectionHistory(sessionName); await handleGetSessions(serverLink); await syncSubSessionIfNeeded(sessionName, serverLink); timelineEmitter.emit(sessionName, 'assistant.text', { @@ -1859,7 +1906,6 @@ async function handleSend(cmd: Record, serverLink: ServerLink): } // Build attachment refs for any uploaded files referenced in the message - const attachments: Array<{ id: string; originalName?: string; mime?: string; size?: number; daemonPath: string }> = []; if (tokens.files.length > 0) { const record = getSession(sessionName); const projectDir = record?.projectDir ?? ''; @@ -3981,6 +4027,30 @@ async function handleServerDelete(): Promise { // ── Transport chat history replay ───────────────────────────────────────────── +async function handleTransportApprovalResponse(cmd: Record, serverLink: ServerLink): Promise { + const sessionId = typeof cmd.sessionId === 'string' ? cmd.sessionId : undefined; + const requestId = typeof cmd.requestId === 'string' ? cmd.requestId : undefined; + const approved = typeof cmd.approved === 'boolean' ? 
cmd.approved : undefined; + if (!sessionId || !requestId || approved === undefined) return; + const runtime = getTransportRuntime(sessionId); + if (!runtime) return; + try { + await runtime.respondApproval(requestId, approved); + try { + serverLink.send({ + type: TRANSPORT_MSG.APPROVAL_RESPONSE, + sessionId, + requestId, + approved, + }); + } catch { + // ignore — daemon link disconnected + } + } catch (err) { + logger.warn({ err, sessionId, requestId }, 'transport approval response failed'); + } +} + async function handleChatSubscribeReplay(cmd: Record, serverLink: ServerLink): Promise { const sessionId = cmd.sessionId as string | undefined; if (!sessionId) return; @@ -4249,22 +4319,45 @@ async function prependLocalMemory( hitIds?: string[]; }> { if (prompt.length < 10) return { text: prompt }; // skip greetings / confirmations + // Template-prompt skip: OpenSpec / slash-command / skill-template prompts + // are not natural-language questions; a recall over them returns noise. + // See shared/template-prompt-patterns.ts. + if (isTemplatePrompt(prompt)) return { text: prompt }; try { const { searchLocalMemorySemantic } = await import('../context/memory-search.js'); const record = getSession(sessionName); const query = prompt.slice(0, 200); - const result = await searchLocalMemorySemantic({ + // Broaden the candidate pool — the cap rule trims to 3 (or up to 5 for + // all-strong results). We need enough candidates to survive filtering. + const searchResult = await searchLocalMemorySemantic({ query, namespace: record?.projectName ? { scope: 'personal', projectId: record.projectName } : undefined, repo: record?.projectName ?? 
undefined, - limit: 5, + limit: 10, }); - if (result.items.length === 0) return { text: prompt }; - const hitIds = result.items.filter((item) => item.type === 'processed').map((item) => item.id); - const injectedText = buildRelatedPastWorkText(result.items); - const timelinePayload = buildMemoryContextTimelinePayload(query, result.items); + // 1) Template-origin legacy summaries never surface through recall. + const notTemplate = searchResult.items.filter( + (item) => !isTemplateOriginSummary(item.summary), + ); + // 2) Per-session dedup: drop items already injected in the last 10 turns + // of THIS session. Cleared on `session.clear`. + const ids = notTemplate.map((item) => item.id); + const keepIds = new Set(filterRecentlyInjected(sessionName, ids)); + const deduped = notTemplate.filter((item) => keepIds.has(item.id)); + // 3) Cap rule: floor 0.5, top 3, extend to 5 iff all >= 0.6. + // See shared/memory-scoring.ts. + const scored = deduped.map((item) => ({ item, score: item.relevanceScore ?? 0 })); + const finalScored = applyRecallCapRule(scored); + const finalItems = finalScored.map((s) => s.item); + if (finalItems.length === 0) return { text: prompt }; + const hitIds = finalItems.filter((item) => item.type === 'processed').map((item) => item.id); + const injectedText = buildRelatedPastWorkText(finalItems); + const timelinePayload = buildMemoryContextTimelinePayload(query, finalItems); + // 4) Record the injection into the per-session ring buffer so these + // same items do not re-inject on the next 10 turns. 
+ recordRecentInjection(sessionName, hitIds); return { text: `${injectedText}\n\n${prompt}`, timelinePayload: timelinePayload diff --git a/test/context/recent-injection-history.test.ts b/test/context/recent-injection-history.test.ts new file mode 100644 index 000000000..4f2ad17e9 --- /dev/null +++ b/test/context/recent-injection-history.test.ts @@ -0,0 +1,89 @@ +import { beforeEach, describe, expect, it } from 'vitest'; +import { + filterRecentlyInjected, + recordRecentInjection, + clearRecentInjectionHistory, + resetAllRecentInjectionHistories, + getRecentInjectionHistory, + RECENT_INJECTION_HISTORY_SIZE, +} from '../../src/context/recent-injection-history.js'; + +describe('recent-injection-history', () => { + beforeEach(() => { + resetAllRecentInjectionHistories(); + }); + + it('passes all ids through when no history exists yet', () => { + const out = filterRecentlyInjected('deck_a_brain', ['mem-1', 'mem-2']); + expect(out).toEqual(['mem-1', 'mem-2']); + }); + + it('drops ids injected on a previous turn of the same session', () => { + recordRecentInjection('deck_a_brain', ['mem-1', 'mem-2']); + const out = filterRecentlyInjected('deck_a_brain', ['mem-1', 'mem-2', 'mem-3']); + expect(out).toEqual(['mem-3']); + }); + + it('isolates history per sessionKey — other sessions see a clean history', () => { + recordRecentInjection('deck_a_brain', ['mem-1']); + const sameSession = filterRecentlyInjected('deck_a_brain', ['mem-1', 'mem-2']); + const differentSession = filterRecentlyInjected('deck_b_brain', ['mem-1', 'mem-2']); + expect(sameSession).toEqual(['mem-2']); + expect(differentSession).toEqual(['mem-1', 'mem-2']); + }); + + it('retains up to RECENT_INJECTION_HISTORY_SIZE (10) events per session', () => { + expect(RECENT_INJECTION_HISTORY_SIZE).toBe(10); + for (let i = 0; i < 12; i++) { + recordRecentInjection('deck_a_brain', [`mem-${i}`]); + } + const hist = getRecentInjectionHistory('deck_a_brain'); + // Ring buffer keeps the 10 most recent — events 2..11. 
+ expect(hist).toHaveLength(10); + expect(hist[0]).toEqual(['mem-11']); // most recent first + expect(hist[9]).toEqual(['mem-2']); // oldest retained + }); + + it('evicts the oldest event when the 11th is recorded', () => { + for (let i = 0; i < 10; i++) recordRecentInjection('deck_a_brain', [`mem-${i}`]); + // mem-0..mem-9 are all in the history + expect(filterRecentlyInjected('deck_a_brain', ['mem-0'])).toEqual([]); + expect(filterRecentlyInjected('deck_a_brain', ['mem-9'])).toEqual([]); + + recordRecentInjection('deck_a_brain', ['mem-new']); + // mem-0 (oldest) is evicted; mem-new replaces its slot + expect(filterRecentlyInjected('deck_a_brain', ['mem-0'])).toEqual(['mem-0']); + expect(filterRecentlyInjected('deck_a_brain', ['mem-9'])).toEqual([]); + expect(filterRecentlyInjected('deck_a_brain', ['mem-new'])).toEqual([]); + }); + + it('treats one injection event as one slot, regardless of how many ids it contains', () => { + recordRecentInjection('deck_a_brain', ['a', 'b', 'c', 'd', 'e']); // 1 event, 5 ids + recordRecentInjection('deck_a_brain', ['f']); // 1 event, 1 id + const hist = getRecentInjectionHistory('deck_a_brain'); + expect(hist).toHaveLength(2); + // All 6 ids are still dedup-protected + expect(filterRecentlyInjected('deck_a_brain', ['a', 'b', 'c', 'd', 'e', 'f', 'g'])).toEqual([ + 'g', + ]); + }); + + it('does not record empty injection events', () => { + recordRecentInjection('deck_a_brain', []); + expect(getRecentInjectionHistory('deck_a_brain')).toEqual([]); + }); + + it('clearRecentInjectionHistory wipes history for the given session only', () => { + recordRecentInjection('deck_a_brain', ['mem-1']); + recordRecentInjection('deck_b_brain', ['mem-1']); + clearRecentInjectionHistory('deck_a_brain'); + expect(filterRecentlyInjected('deck_a_brain', ['mem-1'])).toEqual(['mem-1']); + expect(filterRecentlyInjected('deck_b_brain', ['mem-1'])).toEqual([]); + }); + + it('no-ops for falsy sessionKey (passes all ids through)', () => { + 
recordRecentInjection(undefined, ['mem-1']); + expect(filterRecentlyInjected(undefined, ['mem-1', 'mem-2'])).toEqual(['mem-1', 'mem-2']); + expect(filterRecentlyInjected('', ['mem-1'])).toEqual(['mem-1']); + }); +}); diff --git a/test/daemon/materialization-coordinator.test.ts b/test/daemon/materialization-coordinator.test.ts index 101e1cbb6..6da34a263 100644 --- a/test/daemon/materialization-coordinator.test.ts +++ b/test/daemon/materialization-coordinator.test.ts @@ -197,6 +197,25 @@ describe('MaterializationCoordinator', () => { expect(coordinator.canMaterializeTarget(target, 10_200)).toBe(true); }); + it('records template-prompt content at ingestion (filtering is a recall-side concern, not ingestion)', async () => { + // Built-in / templated prompts (OpenSpec workflow invocations, slash + // commands, harness command tags) are still written to memory — the + // template filter applies only on the recall path, not at record time. + // See shared/template-prompt-patterns.ts and Phase L. 
+ const coordinator = new MaterializationCoordinator({ compressor: localOnlyCompressor, + thresholds: { eventCount: 1, idleMs: 1000, scheduleMs: 10_000 }, + }); + + const openspec = coordinator.ingestEvent({ + target, + eventType: 'assistant.text', + content: 'Drove the implementation of @openspec/changes/my-feature by orchestrating subagents.', + createdAt: 100, + }); + expect(openspec.filtered).toBeUndefined(); + expect(openspec.queuedJob).toEqual(expect.objectContaining({ trigger: 'threshold' })); + }); + it('pairs final assistant.text output with the user request in structured summaries', async () => { const coordinator = new MaterializationCoordinator({ compressor: localOnlyCompressor, thresholds: { eventCount: 99, idleMs: 50, scheduleMs: 200 }, diff --git a/test/shared/recall-cap-rule.test.ts b/test/shared/recall-cap-rule.test.ts new file mode 100644 index 000000000..a4057ae8f --- /dev/null +++ b/test/shared/recall-cap-rule.test.ts @@ -0,0 +1,147 @@ +import { describe, expect, it } from 'vitest'; +import { + applyRecallCapRule, + RECALL_MIN_FLOOR, + RECALL_DEFAULT_CAP, + RECALL_EXTEND_BAR, + RECALL_EXTEND_CAP, +} from '../../shared/memory-scoring.js'; + +const mk = (id: string, score: number) => ({ id, score }); + +describe('applyRecallCapRule — defaults', () => { + it('uses the documented constants', () => { + expect(RECALL_MIN_FLOOR).toBe(0.5); + expect(RECALL_DEFAULT_CAP).toBe(3); + expect(RECALL_EXTEND_BAR).toBe(0.6); + expect(RECALL_EXTEND_CAP).toBe(5); + }); + + it('returns [] when every candidate scores below the 0.5 floor', () => { + const items = [mk('a', 0.49), mk('b', 0.3), mk('c', 0.1)]; + expect(applyRecallCapRule(items)).toEqual([]); + }); + + it('keeps items at or above the 0.5 floor, drops those below', () => { + const items = [ + mk('pass-1', 0.9), + mk('pass-2', 0.5), + mk('drop-1', 0.49), + mk('drop-2', 0.2), + ]; + const out = applyRecallCapRule(items); + expect(out.map((i) => i.id)).toEqual(['pass-1', 'pass-2']); + }); + + it('caps at 3 
when not all of the top 3 are >= 0.6', () => { + const items = [mk('a', 0.9), mk('b', 0.7), mk('c', 0.55), mk('d', 0.7), mk('e', 0.65)]; + // `items` is an unused first draft: its two 0.7 scores tie, so the sorted order is ambiguous. + // The tie-free `cleaner` fixture below is the one actually passed to the rule. + const cleaner = [mk('a', 0.9), mk('b', 0.7), mk('c', 0.55), mk('d', 0.75), mk('e', 0.65)]; + const out = applyRecallCapRule(cleaner); + // Sorted: 0.9, 0.75, 0.7, 0.65, 0.55 → top 3 are 0.9/0.75/0.7 (all >= 0.6), + // so extension kicks in — 0.65 joins; 0.55 clears the 0.5 floor + // but fails the 0.6 extend bar, so extension stops after 0.65 (4 results, not 3). + expect(out.map((i) => i.score)).toEqual([0.9, 0.75, 0.7, 0.65]); + }); + + it('caps strictly at 3 when the 3rd-ranked item is below 0.6', () => { + const items = [mk('a', 0.9), mk('b', 0.8), mk('c', 0.55), mk('d', 0.95), mk('e', 0.92)]; + // Sorted by score, descending (input order is not preserved): + // 0.95 (d), 0.92 (e), 0.9 (a), 0.8 (b), 0.55 (c) + // Top 3: 0.95, 0.92, 0.9 — all >= 0.6 → extend kicks in + // Next candidate: 0.8 (b) — >= 0.6 → include → now have 4 + // Next: 0.55 (c) — < 0.6 → stop + // Final: [d, e, a, b] — 4 results; note the 3rd-ranked item (0.9) is NOT below 0.6 here + const out = applyRecallCapRule(items); + expect(out.map((i) => i.id)).toEqual(['d', 'e', 'a', 'b']); + }); + + it('returns exactly the top 3 when the top 3 are not all >= 0.6', () => { + const items = [mk('a', 0.9), mk('b', 0.7), mk('c', 0.55), mk('d', 0.7)]; + // `items` is an unused first draft — sorted: 0.9, 0.7, 0.7, 0.55, so its top 3 + // are all >= 0.6 and the rule would try to extend (next is 0.55 < 0.6 → stop). + // `real` below puts a < 0.6 item inside the top 3, which is the case under test. + const real = [mk('a', 0.9), mk('b', 0.7), mk('c', 0.55), mk('d', 0.55)]; + // Sorted: 0.9, 0.7, 0.55, 0.55 — top 3 = 0.9/0.7/0.55 — NOT all >= 0.6 → no extend.
+ const out = applyRecallCapRule(real); + expect(out.map((i) => i.score)).toEqual([0.9, 0.7, 0.55]); + }); + + it('caps extend at 5 even when more items qualify', () => { + const items = [ + mk('a', 0.95), + mk('b', 0.92), + mk('c', 0.88), + mk('d', 0.82), + mk('e', 0.75), + mk('f', 0.72), + mk('g', 0.65), + ]; + // Top 3 all >= 0.6 → extend. But hard cap at 5. + const out = applyRecallCapRule(items); + expect(out).toHaveLength(5); + expect(out.map((i) => i.id)).toEqual(['a', 'b', 'c', 'd', 'e']); + }); + + it('stops extending when the next candidate drops below 0.6', () => { + const items = [ + mk('a', 0.95), + mk('b', 0.92), + mk('c', 0.88), + mk('d', 0.58), // just below bar + mk('e', 0.75), + ]; + // Sorted: 0.95, 0.92, 0.88, 0.75, 0.58 → top 3 all >= 0.6, extend: + // next = 0.75 (>= 0.6) → include → 4 items + // next = 0.58 (< 0.6) → stop + const out = applyRecallCapRule(items); + expect(out.map((i) => i.id)).toEqual(['a', 'b', 'c', 'e']); + }); + + it('handles fewer than 3 candidates by returning whatever survived the floor', () => { + const two = [mk('a', 0.9), mk('b', 0.7)]; + expect(applyRecallCapRule(two).map((i) => i.id)).toEqual(['a', 'b']); + + const one = [mk('a', 0.9)]; + expect(applyRecallCapRule(one).map((i) => i.id)).toEqual(['a']); + + const zero: { id: string; score: number }[] = []; + expect(applyRecallCapRule(zero)).toEqual([]); + }); + + it('does not mutate the input array', () => { + const items = [mk('c', 0.55), mk('a', 0.95), mk('b', 0.75)]; + const snapshot = items.map((i) => i.id).join(','); + applyRecallCapRule(items); + expect(items.map((i) => i.id).join(',')).toBe(snapshot); + }); + + it('accepts custom caps for call sites that need tighter/looser behavior', () => { + const items = [mk('a', 0.9), mk('b', 0.85), mk('c', 0.8), mk('d', 0.75), mk('e', 0.7)]; + // Custom: defaultCap=2, extendCap=3. Top 2 both >= 0.6, extend one more. 
+ const out = applyRecallCapRule(items, { defaultCap: 2, extendCap: 3 }); + expect(out.map((i) => i.id)).toEqual(['a', 'b', 'c']); + }); + + it('accepts custom floor', () => { + const items = [mk('a', 0.55), mk('b', 0.52), mk('c', 0.45)]; + // Default floor 0.5 → a, b pass. Custom floor 0.6 → all drop. + expect(applyRecallCapRule(items).map((i) => i.id)).toEqual(['a', 'b']); + expect(applyRecallCapRule(items, { minFloor: 0.6 })).toEqual([]); + }); + + it('calibration example: project+recency alone cannot pass (similarity=0 pure-boost case)', () => { + // From design.md: same project, fresh, never recalled, sim=0 + // 0.4*0 + 0.25*~0.9 + 0.15*0 + 0.2*1.0 = 0.425 < 0.5 floor → dropped + const items = [mk('pure-boost', 0.425)]; + expect(applyRecallCapRule(items)).toEqual([]); + }); + + it('calibration example: same project + decent semantic match passes floor', () => { + // Same project, fresh, never recalled, sim=0.3 → ~0.545 → passes floor, below extend bar + const items = [mk('decent-sim', 0.545)]; + const out = applyRecallCapRule(items); + expect(out.map((i) => i.id)).toEqual(['decent-sim']); + }); +}); diff --git a/test/shared/template-prompt-patterns.test.ts b/test/shared/template-prompt-patterns.test.ts new file mode 100644 index 000000000..2a982f73a --- /dev/null +++ b/test/shared/template-prompt-patterns.test.ts @@ -0,0 +1,391 @@ +import { describe, expect, it } from 'vitest'; +import { + isTemplatePrompt, + isTemplateOriginSummary, + listKnownSlashCommands, +} from '../../shared/template-prompt-patterns.js'; + +describe('isTemplatePrompt', () => { + // ── OpenSpec references ────────────────────────────────────────────── + it('flags @openspec/changes/ references', () => { + expect(isTemplatePrompt('Drive @openspec/changes/my-feature to completion')).toBe(true); + }); + + it('flags bare openspec/changes/ paths', () => { + expect(isTemplatePrompt('See openspec/changes/shared-agent-context/proposal.md')).toBe(true); + }); + + it('flags openspec/changes 
references embedded in longer text', () => { + expect( + isTemplatePrompt(`Please drive the implementation of openspec/changes/x. +Many sub-tasks ahead.`), + ).toBe(true); + }); + + // ── Workflow imperatives ───────────────────────────────────────────── + it('flags "Drive the implementation of" workflow preamble', () => { + expect(isTemplatePrompt('Drive the implementation of my-change aggressively.')).toBe(true); + }); + + it('flags "Archive a completed change" workflow preamble', () => { + expect(isTemplatePrompt('Archive a completed change in the experimental workflow.')).toBe(true); + }); + + it('flags "Propose a new change" workflow preamble', () => { + expect(isTemplatePrompt('Propose a new change for the memory filter.')).toBe(true); + }); + + it('flags "Implement tasks from an OpenSpec change" workflow preamble', () => { + expect(isTemplatePrompt('Implement tasks from an OpenSpec change.')).toBe(true); + }); + + it('flags "Enter explore mode" workflow preamble', () => { + expect(isTemplatePrompt('Enter explore mode - think through ideas')).toBe(true); + }); + + // ── Harness command tags ───────────────────────────────────────────── + it('flags tags', () => { + expect(isTemplatePrompt('Some text with foo embedded')).toBe(true); + }); + + it('flags tags', () => { + expect(isTemplatePrompt('bar')).toBe(true); + }); + + it('flags tags', () => { + expect(isTemplatePrompt('test')).toBe(true); + }); + + // ── Slash commands ─────────────────────────────────────────────────── + it('flags /loop as a slash command', () => { + expect(isTemplatePrompt('/loop 5m /foo')).toBe(true); + }); + + it('flags /schedule as a slash command', () => { + expect(isTemplatePrompt('/schedule list')).toBe(true); + }); + + it('flags /review as a slash command', () => { + expect(isTemplatePrompt('/review')).toBe(true); + }); + + it('flags /init as a slash command', () => { + expect(isTemplatePrompt('/init')).toBe(true); + }); + + it('flags case-insensitive slash commands', () => { + 
expect(isTemplatePrompt('/Review extra args')).toBe(true); + }); + + // ── Multilingual built-in quick-action templates ──────────────────── + // These are sent verbatim by the web UI (see `web/src/i18n/locales/*.json` + // keys `openspec.*_prompt` and `p2p.*_prompt`). Every locale must be + // caught or the filter leaks in non-English contexts. + + describe('openspec.implement_prompt across 7 locales', () => { + it('en', () => { + expect(isTemplatePrompt('Drive the implementation of my-change aggressively.')).toBe(true); + }); + it('zh-CN', () => { + expect( + isTemplatePrompt('强力推进 openspec/changes/foo 的实施。把工作拆成明确子任务。'), + ).toBe(true); + }); + it('zh-TW', () => { + expect( + isTemplatePrompt('強力推進 openspec/changes/foo 的實作。把工作拆成明確子任務。'), + ).toBe(true); + }); + it('es', () => { + expect( + isTemplatePrompt('Impulsa con firmeza la implementación de la propuesta.'), + ).toBe(true); + }); + it('ru', () => { + expect(isTemplatePrompt('Жестко доведи реализацию изменения до конца.')).toBe(true); + }); + it('ja', () => { + expect(isTemplatePrompt('この変更の実装を強力に前進させてください。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('이 변경의 구현을 강하게 밀어붙이세요.')).toBe(true); + }); + }); + + describe('openspec.audit_implementation_prompt across 7 locales', () => { + it('en', () => { + expect(isTemplatePrompt('Perform a strict implementation audit for x.')).toBe(true); + }); + it('zh-CN', () => { + expect(isTemplatePrompt('对 x 执行严格的实现审计,逐项对照。')).toBe(true); + }); + it('zh-TW', () => { + expect(isTemplatePrompt('對 x 執行嚴格的實作審計,逐項對照。')).toBe(true); + }); + it('es', () => { + expect(isTemplatePrompt('Realiza una auditoría estricta de la implementación.')).toBe(true); + }); + it('ru', () => { + expect(isTemplatePrompt('Проведи строгий аудит реализации.')).toBe(true); + }); + it('ja', () => { + expect(isTemplatePrompt('厳格な実装監査を実施してください。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('엄격한 구현 감사를 수행하세요.')).toBe(true); + }); + }); + + 
describe('openspec.audit_spec_prompt across 7 locales', () => { + it('en', () => { + expect(isTemplatePrompt('Perform a strict specification audit for y.')).toBe(true); + }); + it('zh-CN', () => { + expect(isTemplatePrompt('对 y 执行严格的规范审计。')).toBe(true); + }); + it('zh-TW', () => { + expect(isTemplatePrompt('對 y 執行嚴格的規格審計。')).toBe(true); + }); + it('es', () => { + expect(isTemplatePrompt('Realiza una auditoría estricta de la especificación.')).toBe(true); + }); + it('ru', () => { + expect(isTemplatePrompt('Проведи строгий аудит спецификации.')).toBe(true); + }); + it('ja', () => { + expect(isTemplatePrompt('厳格な仕様監査を実施してください。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('엄격한 명세 감사를 수행하세요.')).toBe(true); + }); + }); + + describe('openspec.propose_from_discussion_prompt across 7 locales', () => { + it('en', () => { + expect(isTemplatePrompt('Generate an OpenSpec change from the recent discussion.')).toBe( + true, + ); + }); + it('zh-CN', () => { + expect(isTemplatePrompt('根据最近的讨论生成一个 OpenSpec 变更。')).toBe(true); + }); + it('zh-TW', () => { + expect(isTemplatePrompt('根據最近的討論生成一個 OpenSpec 變更。')).toBe(true); + }); + it('es', () => { + expect(isTemplatePrompt('Genera un cambio de OpenSpec a partir de la discusión reciente.')).toBe( + true, + ); + }); + it('ru', () => { + expect( + isTemplatePrompt('Сгенерируй изменение OpenSpec на основе недавнего обсуждения.'), + ).toBe(true); + }); + it('ja', () => { + expect(isTemplatePrompt('直近の議論から OpenSpec 変更を生成してください。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('최근 논의를 바탕으로 OpenSpec 변경을 생성하세요.')).toBe(true); + }); + }); + + describe('openspec.achieve_prompt across 7 locales', () => { + it('en', () => { + expect( + isTemplatePrompt('Take my-change to done using the full OpenSpec workflow.'), + ).toBe(true); + }); + it('zh-CN', () => { + expect(isTemplatePrompt('按完整 OpenSpec 工作流把变更推到完成。')).toBe(true); + }); + it('zh-TW', () => { + expect(isTemplatePrompt('依照完整 OpenSpec 
工作流程把變更推到完成。')).toBe(true); + }); + it('es', () => { + expect(isTemplatePrompt('Lleva el cambio hasta completarlo usando el flujo completo de OpenSpec.')).toBe( + true, + ); + }); + it('ru', () => { + expect(isTemplatePrompt('Доведи изменение до состояния done по полному процессу OpenSpec.')).toBe( + true, + ); + }); + it('ja', () => { + expect(isTemplatePrompt('完全な OpenSpec ワークフローで変更を done まで持っていってください。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('전체 OpenSpec 워크플로로 변경을 완료 상태까지 밀어붙이세요.')).toBe(true); + }); + }); + + describe('p2p.post_summary_execute_prompt across 7 locales', () => { + it('en', () => { + expect(isTemplatePrompt('The P2P discussion is complete. Use the discussion file.')).toBe( + true, + ); + }); + it('zh-CN', () => { + expect(isTemplatePrompt('P2P 讨论已经完成。请把讨论文件作为上下文。')).toBe(true); + }); + it('zh-TW', () => { + expect(isTemplatePrompt('P2P 討論已完成。請把討論檔案作為上下文。')).toBe(true); + }); + it('es', () => { + expect(isTemplatePrompt('La discusión P2P ha terminado.')).toBe(true); + }); + it('ru', () => { + expect(isTemplatePrompt('P2P-обсуждение завершено.')).toBe(true); + }); + it('ja', () => { + expect(isTemplatePrompt('P2P議論は完了しました。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('P2P 토론이 완료되었습니다.')).toBe(true); + }); + }); + + describe('p2p.final_original_request_reminder across 7 locales', () => { + it('en', () => { + expect( + isTemplatePrompt( + "After synthesizing the discussion, directly address the user's original request.", + ), + ).toBe(true); + }); + it('zh-CN', () => { + expect(isTemplatePrompt('在完成讨论综合后,务必直接落实。')).toBe(true); + }); + it('zh-TW', () => { + expect(isTemplatePrompt('在完成討論綜合後,務必直接落實。')).toBe(true); + }); + it('es', () => { + expect(isTemplatePrompt('No te quedes solo en el resumen de la discusión.')).toBe(true); + }); + it('ru', () => { + expect(isTemplatePrompt('Не ограничивайся только сводкой обсуждения.')).toBe(true); + }); + it('ja', () => { + 
expect(isTemplatePrompt('議論の要約だけで終わらせず、実行してください。')).toBe(true); + }); + it('ko', () => { + expect(isTemplatePrompt('토론 요약으로 끝내지 말고 실행하세요.')).toBe(true); + }); + }); + + describe('P2P baseline prompt + round headers', () => { + it('flags the shared P2P baseline prompt', () => { + expect( + isTemplatePrompt( + 'You are a staff-level engineer participating in a multi-agent technical discussion.', + ), + ).toBe(true); + }); + it('flags [Round N/M — Phase — Initial Analysis] headers', () => { + expect( + isTemplatePrompt( + '[Round 1/3 — Audit Phase — Initial Analysis]\nProvide your initial analysis based on the original request.', + ), + ).toBe(true); + }); + it('flags [Round N/M — Deepening] round headers', () => { + expect(isTemplatePrompt("[Round 2/3 — Deepening]\nReview ALL previous rounds' findings above.")).toBe( + true, + ); + }); + }); + + // ── Plugin-namespaced skills ──────────────────────────────────────── + it('flags claude-mem:do', () => { + expect(isTemplatePrompt('claude-mem:do run the plan')).toBe(true); + }); + + it('flags opsx:apply', () => { + expect(isTemplatePrompt('opsx:apply the change')).toBe(true); + }); + + it('flags openspec-archive-change', () => { + expect(isTemplatePrompt('openspec-archive-change:run')).toBe(true); + }); + + // ── Negative cases ─────────────────────────────────────────────────── + it('accepts normal natural-language questions', () => { + expect(isTemplatePrompt('How do I fix the download bug?')).toBe(false); + }); + + it('accepts Chinese natural-language questions', () => { + expect(isTemplatePrompt('帮我修一下下载的 bug 好不好')).toBe(false); + }); + + it('accepts prose that mentions "change" without the workflow phrase', () => { + expect(isTemplatePrompt('I want to change the color of this button.')).toBe(false); + }); + + it('accepts prose that mentions "implement" without the workflow phrase', () => { + expect(isTemplatePrompt('Please implement the sorting algorithm we discussed.')).toBe(false); + }); + + it('accepts prose with 
/path/like/slashes that are not slash commands', () => { + expect(isTemplatePrompt('look at /src/agent/detect.ts for the answer')).toBe(false); + }); + + it('accepts empty / null / undefined without throwing', () => { + expect(isTemplatePrompt('')).toBe(false); + expect(isTemplatePrompt(null)).toBe(false); + expect(isTemplatePrompt(undefined)).toBe(false); + expect(isTemplatePrompt(' \n \t ')).toBe(false); + }); + + it('accepts prose that references a repo path containing "changes"', () => { + expect(isTemplatePrompt('look at changes/not-openspec/foo.ts')).toBe(false); + }); +}); + +describe('isTemplateOriginSummary', () => { + it('flags summaries that reference openspec/changes/', () => { + expect( + isTemplateOriginSummary('User orchestrated openspec/changes/feature-x via subagents.'), + ).toBe(true); + }); + + it('flags summaries with "Drive the implementation of"', () => { + expect(isTemplateOriginSummary('## Summary\n- Drive the implementation of change X')).toBe( + true, + ); + }); + + it('flags summaries with "Archived a completed change"', () => { + expect(isTemplateOriginSummary('Archived the completed change.')).toBe(true); + }); + + it('flags summaries with residual fragments', () => { + expect(isTemplateOriginSummary('Resolved loop request.')).toBe( + true, + ); + }); + + it('accepts normal problem→solution summaries', () => { + expect( + isTemplateOriginSummary( + '## codedeck\n- User problem: download cancel dropped connection.\n- Resolution: added AbortController pass-through.', + ), + ).toBe(false); + }); + + it('accepts empty / null / undefined without throwing', () => { + expect(isTemplateOriginSummary('')).toBe(false); + expect(isTemplateOriginSummary(null)).toBe(false); + expect(isTemplateOriginSummary(undefined)).toBe(false); + }); +}); + +describe('listKnownSlashCommands', () => { + it('exposes a non-empty list for auditing', () => { + const list = listKnownSlashCommands(); + expect(Array.isArray(list)).toBe(true); + 
expect(list.length).toBeGreaterThan(0); + expect(list).toContain('/loop'); + expect(list).toContain('/schedule'); + }); +}); From c0f818f5fc282369538576575eca40942b86abb8 Mon Sep 17 00:00:00 2001 From: "IM.codes" Date: Fri, 17 Apr 2026 21:43:27 +0800 Subject: [PATCH 003/151] Harden cursor and copilot transport providers --- package-lock.json | 137 +++ package.json | 1 + server/test/bridge.test.ts | 31 + shared/agent-types.ts | 11 +- shared/context-types.ts | 3 +- shared/effort-levels.ts | 1 + shared/transport-attachments.ts | 8 + shared/transport-events.ts | 72 +- src/agent/detect.ts | 16 +- src/agent/provider-registry.ts | 8 + src/agent/providers/_template.ts | 3 +- src/agent/providers/claude-code-sdk.ts | 3 +- src/agent/providers/codex-sdk.ts | 3 +- src/agent/providers/copilot-sdk.ts | 950 +++++++++++++++ src/agent/providers/cursor-headless-stream.ts | 329 +++++ src/agent/providers/cursor-headless.ts | 761 ++++++++++++ src/agent/providers/openclaw.ts | 3 +- src/agent/providers/qwen.ts | 3 +- src/agent/session-manager.ts | 32 +- src/agent/transport-paths.ts | 18 + src/agent/transport-provider.ts | 5 +- src/agent/transport-runtime-assembly.ts | 3 +- src/daemon/lifecycle.ts | 2 +- src/daemon/transport-relay.ts | 20 +- src/store/session-store.ts | 2 + test/agent/provider-registry.test.ts | 81 +- test/agent/providers/copilot-sdk-harness.ts | 210 ++++ test/agent/providers/copilot-sdk.test.ts | 384 ++++++ .../providers/cursor-headless-stream.test.ts | 135 +++ test/agent/providers/cursor-headless.test.ts | 207 ++++ test/cursor-headless-fixture.ts | 102 ++ .../command-handler-transport-queue.test.ts | 33 + test/daemon/copilot-sdk-runtime.test.ts | 83 ++ .../cursor-copilot-transport-restore.test.ts | 379 ++++++ test/daemon/transport-relay.test.ts | 32 +- test/daemon/transport-session-runtime.test.ts | 38 +- test/daemon/transport-types.test.ts | 213 ++-- test/e2e/copilot-sdk-live.test.ts | 192 +++ test/e2e/cursor-headless-live.test.ts | 104 ++ 
test/e2e/cursor-headless-transport.test.ts | 200 ++++ test/shared/transport-types-contract.test.ts | 148 ++- web/src/components/NewSessionDialog.tsx | 1060 +++++++++++++---- web/src/components/QuickInputPanel.tsx | 2 + web/src/components/SessionControls.tsx | 120 +- web/src/components/SessionSettingsDialog.tsx | 4 + web/src/components/StartSubSessionDialog.tsx | 24 +- web/src/i18n/locales/en.json | 10 +- web/src/i18n/locales/es.json | 10 +- web/src/i18n/locales/ja.json | 10 +- web/src/i18n/locales/ko.json | 10 +- web/src/i18n/locales/ru.json | 10 +- web/src/i18n/locales/zh-CN.json | 10 +- web/src/i18n/locales/zh-TW.json | 10 +- web/src/pages/AddProject.tsx | 2 +- web/src/pages/ProjectSettings.tsx | 2 +- web/src/ws-client.ts | 39 + web/test/components/QuickInputPanel.test.tsx | 58 + web/test/components/SessionControls.test.tsx | 55 +- web/test/ws-client.test.ts | 44 + 59 files changed, 5992 insertions(+), 454 deletions(-) create mode 100644 shared/transport-attachments.ts create mode 100644 src/agent/providers/copilot-sdk.ts create mode 100644 src/agent/providers/cursor-headless-stream.ts create mode 100644 src/agent/providers/cursor-headless.ts create mode 100644 test/agent/providers/copilot-sdk-harness.ts create mode 100644 test/agent/providers/copilot-sdk.test.ts create mode 100644 test/agent/providers/cursor-headless-stream.test.ts create mode 100644 test/agent/providers/cursor-headless.test.ts create mode 100644 test/cursor-headless-fixture.ts create mode 100644 test/daemon/copilot-sdk-runtime.test.ts create mode 100644 test/daemon/cursor-copilot-transport-restore.test.ts create mode 100644 test/e2e/copilot-sdk-live.test.ts create mode 100644 test/e2e/cursor-headless-live.test.ts create mode 100644 test/e2e/cursor-headless-transport.test.ts diff --git a/package-lock.json b/package-lock.json index 53eb14254..76d092146 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,6 +10,7 @@ "license": "MIT", "dependencies": { "@anthropic-ai/claude-agent-sdk": 
"^0.2.92", + "@github/copilot-sdk": "^0.2.2", "@huggingface/transformers": "^4.1.0", "@openai/codex-sdk": "^0.118.0", "commander": "^12.1.0", @@ -981,6 +982,133 @@ } } }, + "node_modules/@github/copilot": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot/-/copilot-1.0.31.tgz", + "integrity": "sha512-AfoVW9pHsKQGtLCpPcvQ8TOwBVF8meo5srle/8cqRSsx882CpIQx5C4uNs6zwrCtqMTo8M8D6zlDIbXkLudrXw==", + "license": "SEE LICENSE IN LICENSE.md", + "bin": { + "copilot": "npm-loader.js" + }, + "optionalDependencies": { + "@github/copilot-darwin-arm64": "1.0.31", + "@github/copilot-darwin-x64": "1.0.31", + "@github/copilot-linux-arm64": "1.0.31", + "@github/copilot-linux-x64": "1.0.31", + "@github/copilot-win32-arm64": "1.0.31", + "@github/copilot-win32-x64": "1.0.31" + } + }, + "node_modules/@github/copilot-darwin-arm64": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot-darwin-arm64/-/copilot-darwin-arm64-1.0.31.tgz", + "integrity": "sha512-DnAbe87U55/egBu/SFdMniQfhnYjfP3ZXXhrba3DZMXQI+91iRAGfPFKAsSlekl0zfNFw8toOkiafr9Hu2lHvA==", + "cpu": [ + "arm64" + ], + "license": "SEE LICENSE IN LICENSE.md", + "optional": true, + "os": [ + "darwin" + ], + "bin": { + "copilot-darwin-arm64": "copilot" + } + }, + "node_modules/@github/copilot-darwin-x64": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot-darwin-x64/-/copilot-darwin-x64-1.0.31.tgz", + "integrity": "sha512-mFmuYT3N1JE3zRIwCAPaXGDstL8Npa62Jey3vT4Lo003NfzQrBzvZ4ObAVMTmFQ6pRZzj39rTTKp1vLYGg+K0w==", + "cpu": [ + "x64" + ], + "license": "SEE LICENSE IN LICENSE.md", + "optional": true, + "os": [ + "darwin" + ], + "bin": { + "copilot-darwin-x64": "copilot" + } + }, + "node_modules/@github/copilot-linux-arm64": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot-linux-arm64/-/copilot-linux-arm64-1.0.31.tgz", + "integrity": 
"sha512-R5V7EIqn92f9YMe3zbQkW++Mw8WErDy6hA8Rr95bSJGiTVyWdj5kqPWSAPH6MLjFbC1T5cJQm/1we+QP3XO3Cw==", + "cpu": [ + "arm64" + ], + "license": "SEE LICENSE IN LICENSE.md", + "optional": true, + "os": [ + "linux" + ], + "bin": { + "copilot-linux-arm64": "copilot" + } + }, + "node_modules/@github/copilot-linux-x64": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot-linux-x64/-/copilot-linux-x64-1.0.31.tgz", + "integrity": "sha512-LmcCGmYP9QLim/YMu5e1UlVeqCt/cuMI0fIqkdHs68h+0FGreSnHpn7nA9RbjAbQuPq9HFWeFjG5UpbAHM71Xg==", + "cpu": [ + "x64" + ], + "license": "SEE LICENSE IN LICENSE.md", + "optional": true, + "os": [ + "linux" + ], + "bin": { + "copilot-linux-x64": "copilot" + } + }, + "node_modules/@github/copilot-sdk": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/@github/copilot-sdk/-/copilot-sdk-0.2.2.tgz", + "integrity": "sha512-VZCqS08YlUM90bUKJ7VLeIxgTTEHtfXBo84T1IUMNvXRREX2csjPH6Z+CPw3S2468RcCLvzBXcc9LtJJTLIWFw==", + "license": "MIT", + "dependencies": { + "@github/copilot": "^1.0.21", + "vscode-jsonrpc": "^8.2.1", + "zod": "^4.3.6" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@github/copilot-win32-arm64": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot-win32-arm64/-/copilot-win32-arm64-1.0.31.tgz", + "integrity": "sha512-OlMPsQYFbl1hzrE0t703BwB9k8lQauQ4ETiiKpXSV4FxUb3DAU9PqWcy1pZoBjmLCni9h1ASQQKmPQ9ERJPm3g==", + "cpu": [ + "arm64" + ], + "license": "SEE LICENSE IN LICENSE.md", + "optional": true, + "os": [ + "win32" + ], + "bin": { + "copilot-win32-arm64": "copilot.exe" + } + }, + "node_modules/@github/copilot-win32-x64": { + "version": "1.0.31", + "resolved": "https://registry.npmjs.org/@github/copilot-win32-x64/-/copilot-win32-x64-1.0.31.tgz", + "integrity": "sha512-nK8uRdlKH6TNk1cjBqEPTvzWQxwnDPgNN3M5bB7TBXL6EsaFdUJePz4tqutUPoPbSKQqo+DtmJGT3/+A30ZcXg==", + "cpu": [ + "x64" + ], + "license": "SEE LICENSE IN LICENSE.md", + "optional": true, + "os": [ + 
"win32" + ], + "bin": { + "copilot-win32-x64": "copilot.exe" + } + }, "node_modules/@hono/node-server": { "version": "1.19.12", "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.12.tgz", @@ -7538,6 +7666,15 @@ } } }, + "node_modules/vscode-jsonrpc": { + "version": "8.2.1", + "resolved": "https://registry.npmjs.org/vscode-jsonrpc/-/vscode-jsonrpc-8.2.1.tgz", + "integrity": "sha512-kdjOSJ2lLIn7r1rtrMbbNCHjyMPfRnowdKjBQ+mGq6NAW5QY2bEZC/khaC5OR8svbbjvLEaIXkOq45e2X9BIbQ==", + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/w3c-xmlserializer": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-5.0.0.tgz", diff --git a/package.json b/package.json index f39a4862b..3aba38d8b 100644 --- a/package.json +++ b/package.json @@ -39,6 +39,7 @@ }, "dependencies": { "@anthropic-ai/claude-agent-sdk": "^0.2.92", + "@github/copilot-sdk": "^0.2.2", "@huggingface/transformers": "^4.1.0", "@openai/codex-sdk": "^0.118.0", "commander": "^12.1.0", diff --git a/server/test/bridge.test.ts b/server/test/bridge.test.ts index 046329496..8f1e89211 100644 --- a/server/test/bridge.test.ts +++ b/server/test/bridge.test.ts @@ -2027,6 +2027,37 @@ describe('WsBridge', () => { expect(msg.description).toBe('Write to file /etc/passwd'); }); + it('relays chat.approval_response only to subscribed browsers', async () => { + const bridge = WsBridge.get(serverId); + const daemonWs = new MockWs(); + bridge.handleDaemonConnection(daemonWs as never, makeDb('valid-hash'), {} as never); + daemonWs.emit('message', JSON.stringify({ type: 'auth', serverId, token: 't' })); + await flushAsync(); + + const subscribedBrowser = new MockWs(); + const unsubscribedBrowser = new MockWs(); + bridge.handleBrowserConnection(subscribedBrowser as never, 'user-sub', makeDb('valid-hash')); + bridge.handleBrowserConnection(unsubscribedBrowser as never, 'user-unsub', makeDb('valid-hash')); + subscribedBrowser.emit('message', 
JSON.stringify({ type: 'chat.subscribe', sessionId: 'ts-approval-response' })); + await flushAsync(); + subscribedBrowser.sent.length = 0; + unsubscribedBrowser.sent.length = 0; + + daemonWs.emit('message', JSON.stringify({ + type: 'chat.approval_response', + sessionId: 'ts-approval-response', + requestId: 'req-2', + approved: true, + })); + await flushAsync(); + + expect(subscribedBrowser.sentStrings.some((raw) => { + const msg = JSON.parse(raw); + return msg.type === 'chat.approval_response' && msg.requestId === 'req-2' && msg.approved === true; + })).toBe(true); + expect(unsubscribedBrowser.sentStrings.some((raw) => JSON.parse(raw).type === 'chat.approval_response')).toBe(false); + }); + it('isolates transport subscriptions between browsers', async () => { const bridge = WsBridge.get(serverId); const daemonWs = new MockWs(); diff --git a/shared/agent-types.ts b/shared/agent-types.ts index 5967d8a5d..30096e1ae 100644 --- a/shared/agent-types.ts +++ b/shared/agent-types.ts @@ -3,6 +3,8 @@ export const SESSION_AGENT_TYPES = [ 'claude-code', 'codex-sdk', 'codex', + 'copilot-sdk', + 'cursor-headless', 'opencode', 'gemini', 'qwen', @@ -15,7 +17,14 @@ export type SessionAgentType = typeof SESSION_AGENT_TYPES[number]; export const CLAUDE_CODE_FAMILY = ['claude-code-sdk', 'claude-code'] as const; export const CODEX_FAMILY = ['codex-sdk', 'codex'] as const; -export const TRANSPORT_SESSION_AGENT_TYPES = ['claude-code-sdk', 'codex-sdk', 'qwen', 'openclaw'] as const; +export const TRANSPORT_SESSION_AGENT_TYPES = [ + 'claude-code-sdk', + 'codex-sdk', + 'copilot-sdk', + 'cursor-headless', + 'qwen', + 'openclaw', +] as const; export const PROCESS_SESSION_AGENT_TYPES = ['claude-code', 'codex', 'opencode', 'gemini', 'shell', 'script'] as const; export function isSessionAgentType(value: string): value is SessionAgentType { diff --git a/shared/context-types.ts b/shared/context-types.ts index 471e19db8..5c008a49c 100644 --- a/shared/context-types.ts +++ b/shared/context-types.ts @@ 
-129,7 +129,7 @@ export interface ProviderContextPayload { assembledMessage: string; systemText?: string; messagePreamble?: string; - attachments?: unknown[]; + attachments?: TransportAttachment[]; startupMemory?: TransportMemoryRecallArtifact; memoryRecall?: TransportMemoryRecallArtifact; context: CompiledAgentContextArtifact; @@ -287,3 +287,4 @@ export interface ProcessedContextReplicationBody { namespace: ContextNamespace; projections: ProcessedContextProjection[]; } +import type { TransportAttachment } from './transport-attachments.js'; diff --git a/shared/effort-levels.ts b/shared/effort-levels.ts index 5dbf12c88..3650f56ff 100644 --- a/shared/effort-levels.ts +++ b/shared/effort-levels.ts @@ -6,6 +6,7 @@ export const DEFAULT_TRANSPORT_EFFORT: TransportEffortLevel = 'high'; export const CLAUDE_SDK_EFFORT_LEVELS = ['low', 'medium', 'high', 'max'] as const satisfies readonly TransportEffortLevel[]; export const CODEX_SDK_EFFORT_LEVELS = ['minimal', 'low', 'medium', 'high'] as const satisfies readonly TransportEffortLevel[]; +export const COPILOT_SDK_EFFORT_LEVELS = ['low', 'medium', 'high', 'max'] as const satisfies readonly TransportEffortLevel[]; export const QWEN_EFFORT_LEVELS = ['off', 'low', 'medium', 'high'] as const satisfies readonly TransportEffortLevel[]; export const OPENCLAW_THINKING_LEVELS = ['off', 'minimal', 'low', 'medium', 'high', 'adaptive'] as const satisfies readonly TransportEffortLevel[]; diff --git a/shared/transport-attachments.ts b/shared/transport-attachments.ts new file mode 100644 index 000000000..ee3c24680 --- /dev/null +++ b/shared/transport-attachments.ts @@ -0,0 +1,8 @@ +export interface TransportAttachment { + id: string; + daemonPath: string; + originalName?: string; + mime?: string; + size?: number; + type?: 'file' | 'image'; +} diff --git a/shared/transport-events.ts b/shared/transport-events.ts index c7e8df7a7..02c61222e 100644 --- a/shared/transport-events.ts +++ b/shared/transport-events.ts @@ -9,7 +9,7 @@ * that uniquely 
identifies the message kind. */ -import type { ToolCallEvent } from './agent-message.js'; +import type { ToolCallEvent } from "./agent-message.js"; // ── Agent status ────────────────────────────────────────────────────────────── @@ -28,22 +28,30 @@ import type { ToolCallEvent } from './agent-message.js'; * - `unknown` — status cannot be determined */ export type TransportAgentStatus = - | 'idle' - | 'streaming' - | 'thinking' - | 'tool_running' - | 'permission' - | 'error' - | 'unknown'; + | "idle" + | "streaming" + | "thinking" + | "tool_running" + | "permission" + | "error" + | "unknown"; /** All valid TransportAgentStatus values for runtime validation. */ export const TRANSPORT_AGENT_STATUSES = new Set([ - 'idle', 'streaming', 'thinking', 'tool_running', 'permission', 'error', 'unknown', + "idle", + "streaming", + "thinking", + "tool_running", + "permission", + "error", + "unknown", ]); /** Statuses that indicate the agent is actively doing work. */ export const TRANSPORT_ACTIVE_STATUSES = new Set([ - 'streaming', 'thinking', 'tool_running', + "streaming", + "thinking", + "tool_running", ]); // ── Event type constant object ──────────────────────────────────────────────── @@ -57,21 +65,22 @@ export const TRANSPORT_ACTIVE_STATUSES = new Set([ */ export const TRANSPORT_EVENT = { /** Incremental token/tool delta from the agent. */ - CHAT_DELTA: 'chat.delta', + CHAT_DELTA: "chat.delta", /** A message has finished streaming (no more deltas). */ - CHAT_COMPLETE: 'chat.complete', + CHAT_COMPLETE: "chat.complete", /** A non-recoverable error occurred for a message. */ - CHAT_ERROR: 'chat.error', + CHAT_ERROR: "chat.error", /** Agent status changed (idle / streaming / tool_running / …). */ - CHAT_STATUS: 'chat.status', + CHAT_STATUS: "chat.status", /** A tool call started or completed. */ - CHAT_TOOL: 'chat.tool', + CHAT_TOOL: "chat.tool", /** Agent is requesting user approval before proceeding. 
*/ - CHAT_APPROVAL: 'chat.approval', + CHAT_APPROVAL: "chat.approval", } as const; /** Union of all TRANSPORT_EVENT values (for exhaustive type checks). */ -export type TransportEventType = (typeof TRANSPORT_EVENT)[keyof typeof TRANSPORT_EVENT]; +export type TransportEventType = + (typeof TRANSPORT_EVENT)[keyof typeof TRANSPORT_EVENT]; // ── Browser relay message name constant object ──────────────────────────────── @@ -84,19 +93,24 @@ export type TransportEventType = (typeof TRANSPORT_EVENT)[keyof typeof TRANSPORT */ export const TRANSPORT_MSG = { /** Browser → Bridge: subscribe to transport events for a session. */ - CHAT_SUBSCRIBE: 'chat.subscribe', + CHAT_SUBSCRIBE: "chat.subscribe", /** Browser → Bridge: stop receiving transport events for a session. */ - CHAT_UNSUBSCRIBE: 'chat.unsubscribe', + CHAT_UNSUBSCRIBE: "chat.unsubscribe", + /** Bridge → Browser: agent is requesting approval before continuing. */ + CHAT_APPROVAL: "chat.approval", + /** Browser → Daemon: answer a pending transport approval request. */ + APPROVAL_RESPONSE: "chat.approval_response", /** Bridge → Browser: broadcast of agent/provider availability status. */ - PROVIDER_STATUS: 'provider.status', + PROVIDER_STATUS: "provider.status", /** Browser → Daemon: request list of remote sessions from a provider. */ - LIST_SESSIONS: 'provider.list_sessions', + LIST_SESSIONS: "provider.list_sessions", /** Daemon → Browser: response with remote sessions list. */ - SESSIONS_RESPONSE:'provider.sessions_response', + SESSIONS_RESPONSE: "provider.sessions_response", } as const; /** Union of all TRANSPORT_MSG values. */ -export type TransportMsgType = (typeof TRANSPORT_MSG)[keyof typeof TRANSPORT_MSG]; +export type TransportMsgType = + (typeof TRANSPORT_MSG)[keyof typeof TRANSPORT_MSG]; /** All relay message types that should be forwarded from bridge to browser. 
*/ export const TRANSPORT_RELAY_TYPES = new Set([ @@ -106,6 +120,7 @@ export const TRANSPORT_RELAY_TYPES = new Set([ TRANSPORT_EVENT.CHAT_STATUS, TRANSPORT_EVENT.CHAT_TOOL, TRANSPORT_EVENT.CHAT_APPROVAL, + TRANSPORT_MSG.APPROVAL_RESPONSE, TRANSPORT_MSG.PROVIDER_STATUS, ]); @@ -124,7 +139,7 @@ export type TransportEvent = /** The incremental text fragment. */ delta: string; /** Whether this delta is a plain text fragment or tool-use input fragment. */ - deltaType?: 'text' | 'tool_use'; + deltaType?: "text" | "tool_use"; } | { /** The message has finished — no more deltas will follow. */ @@ -162,4 +177,13 @@ export type TransportEvent = requestId: string; /** Human-readable description of what the agent is asking permission to do. */ description: string; + /** Tool name that triggered the approval request, if available. */ + tool?: string; + } + | { + /** Browser-originated approval response broadcast back to transport subscribers. */ + type: typeof TRANSPORT_MSG.APPROVAL_RESPONSE; + sessionId: string; + requestId: string; + approved: boolean; }; diff --git a/src/agent/detect.ts b/src/agent/detect.ts index 55dc324c5..34a72541a 100644 --- a/src/agent/detect.ts +++ b/src/agent/detect.ts @@ -5,6 +5,12 @@ * Status: 'idle' | 'streaming' | 'thinking' | 'tool_running' | 'permission' | 'unknown' */ +import { + PROCESS_SESSION_AGENT_TYPES, + TRANSPORT_SESSION_AGENT_TYPES, + type SessionAgentType, +} from '../../shared/agent-types.js'; + export type AgentStatus = | 'idle' | 'streaming' @@ -15,19 +21,19 @@ export type AgentStatus = | 'unknown'; /** Process-backed agents — controlled via tmux sessions */ -export type ProcessAgent = 'claude-code' | 'codex' | 'opencode' | 'shell' | 'script' | 'gemini'; +export type ProcessAgent = typeof PROCESS_SESSION_AGENT_TYPES[number]; /** Transport-backed agents — controlled via network protocols */ -export type TransportAgent = 'openclaw' | 'qwen' | 'claude-code-sdk' | 'codex-sdk'; +export type TransportAgent = typeof 
TRANSPORT_SESSION_AGENT_TYPES[number]; /** All agent types */ -export type AgentType = ProcessAgent | TransportAgent; +export type AgentType = SessionAgentType; /** Set of all transport agent type strings */ -export const TRANSPORT_AGENTS = new Set(['openclaw', 'qwen', 'claude-code-sdk', 'codex-sdk']); +export const TRANSPORT_AGENTS = new Set(TRANSPORT_SESSION_AGENT_TYPES); /** Set of all process agent type strings */ -export const PROCESS_AGENTS = new Set(['claude-code', 'codex', 'opencode', 'shell', 'script', 'gemini']); +export const PROCESS_AGENTS = new Set(PROCESS_SESSION_AGENT_TYPES); /** Check if an agent type is transport-backed */ export function isTransportAgent(agentType: string): agentType is TransportAgent { diff --git a/src/agent/provider-registry.ts b/src/agent/provider-registry.ts index f90632532..cce8f0ee8 100644 --- a/src/agent/provider-registry.ts +++ b/src/agent/provider-registry.ts @@ -96,6 +96,14 @@ async function createProvider(id: string): Promise { const { CodexSdkProvider } = await import('./providers/codex-sdk.js'); return new CodexSdkProvider(); } + case 'cursor-headless': { + const { CursorHeadlessProvider } = await import('./providers/cursor-headless.js'); + return new CursorHeadlessProvider(); + } + case 'copilot-sdk': { + const { CopilotSdkProvider } = await import('./providers/copilot-sdk.js'); + return new CopilotSdkProvider(); + } default: throw new Error(`Unknown provider: ${id}`); } diff --git a/src/agent/providers/_template.ts b/src/agent/providers/_template.ts index e8ae8bf0c..68030669a 100644 --- a/src/agent/providers/_template.ts +++ b/src/agent/providers/_template.ts @@ -33,6 +33,7 @@ import { PROVIDER_ERROR_CODES, } from '../transport-provider.js'; import type { AgentMessage, MessageDelta } from '../../../shared/agent-message.js'; +import type { TransportAttachment } from '../../../shared/transport-attachments.js'; import logger from '../../util/logger.js'; // TODO: Replace 'your-provider' with the unique stable id for 
your provider. @@ -133,7 +134,7 @@ export class YourProvider implements TransportProvider { * @param message - Plain string or ProviderContextPayload. * @param attachments - Only present when capabilities.attachments is true. */ - async send(sessionId: string, _message: string, _attachments?: unknown[]): Promise { + async send(sessionId: string, _message: string, _attachments?: TransportAttachment[]): Promise { if (!this.config) { throw this.makeError(PROVIDER_ERROR_CODES.CONNECTION_LOST, 'Not connected', false); } diff --git a/src/agent/providers/claude-code-sdk.ts b/src/agent/providers/claude-code-sdk.ts index 8800a5573..e3b537068 100644 --- a/src/agent/providers/claude-code-sdk.ts +++ b/src/agent/providers/claude-code-sdk.ts @@ -21,6 +21,7 @@ import { } from '../transport-provider.js'; import type { AgentMessage, MessageDelta } from '../../../shared/agent-message.js'; import type { ProviderContextPayload } from '../../../shared/context-types.js'; +import type { TransportAttachment } from '../../../shared/transport-attachments.js'; import logger from '../../util/logger.js'; import { CLAUDE_SDK_EFFORT_LEVELS, type TransportEffortLevel } from '../../../shared/effort-levels.js'; import { normalizeTransportCwd, resolveExecutableForSpawn } from '../transport-paths.js'; @@ -232,7 +233,7 @@ export class ClaudeCodeSdkProvider implements TransportProvider { this.emitSessionInfo(sessionId, { effort }); } - async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, _attachments?: unknown[], extraSystemPrompt?: string): Promise { + async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, _attachments?: TransportAttachment[], extraSystemPrompt?: string): Promise { if (!this.config) { throw this.makeError(PROVIDER_ERROR_CODES.CONNECTION_LOST, 'Claude Code SDK provider not connected', false); } diff --git a/src/agent/providers/codex-sdk.ts b/src/agent/providers/codex-sdk.ts index a5f34b6f9..00c75531f 100644 --- 
a/src/agent/providers/codex-sdk.ts +++ b/src/agent/providers/codex-sdk.ts @@ -20,6 +20,7 @@ import { } from '../transport-provider.js'; import type { AgentMessage, MessageDelta } from '../../../shared/agent-message.js'; import type { ProviderContextPayload } from '../../../shared/context-types.js'; +import type { TransportAttachment } from '../../../shared/transport-attachments.js'; import logger from '../../util/logger.js'; import { CODEX_SDK_EFFORT_LEVELS, type TransportEffortLevel } from '../../../shared/effort-levels.js'; import { normalizeTransportCwd, resolveExecutableForSpawn } from '../transport-paths.js'; @@ -304,7 +305,7 @@ export class CodexSdkProvider implements TransportProvider { this.emitSessionInfo(sessionId, { effort }); } - async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, attachments?: unknown[], extraSystemPrompt?: string): Promise { + async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, attachments?: TransportAttachment[], extraSystemPrompt?: string): Promise { if (!this.config || !this.child) { throw this.makeError(PROVIDER_ERROR_CODES.CONNECTION_LOST, 'Codex app-server not connected', false); } diff --git a/src/agent/providers/copilot-sdk.ts b/src/agent/providers/copilot-sdk.ts new file mode 100644 index 000000000..ecafcfaad --- /dev/null +++ b/src/agent/providers/copilot-sdk.ts @@ -0,0 +1,950 @@ +import { randomUUID } from 'node:crypto'; +import type { + TransportProvider, + ProviderCapabilities, + ProviderConfig, + ProviderError, + SessionConfig, + SessionInfoUpdate, + ProviderStatusUpdate, + ToolCallEvent, + ApprovalRequest, + RemoteSessionInfo, +} from '../transport-provider.js'; +import { + CONNECTION_MODES, + normalizeProviderPayload, + SESSION_OWNERSHIP, + PROVIDER_ERROR_CODES, +} from '../transport-provider.js'; +import type { AgentMessage, MessageDelta } from '../../../shared/agent-message.js'; +import type { ProviderContextPayload } from '../../../shared/context-types.js'; 
+import type { TransportAttachment } from '../../../shared/transport-attachments.js'; +import logger from '../../util/logger.js'; +import { resolveBinaryWithWindowsFallbacks } from '../transport-paths.js'; +import { type TransportEffortLevel } from '../../../shared/effort-levels.js'; + +const COPILOT_BIN = 'copilot'; +const MIN_PROTOCOL_VERSION = 3; +const COMPATIBLE_CLI_RANGE = '^1.0.31'; +const DEFAULT_APPROVAL_TIMEOUT_MS = 30_000; + +export interface CopilotSdkRuntimeHooks { + loadSdk(): Promise; +} + +export const copilotSdkRuntimeHooks: CopilotSdkRuntimeHooks = { + loadSdk: async () => import('@github/copilot-sdk'), +}; + +type CopilotSessionLike = { + sessionId: string; + send(options: Record): Promise; + abort(): Promise; + setModel(model: string, options?: Record): Promise; + on(handler: (event: Record) => void): () => void; + disconnect?(): Promise; +}; + +type CopilotClientLike = { + start(): Promise; + stop(): Promise; + getStatus(): Promise<{ version: string; protocolVersion: number }>; + getAuthStatus(): Promise<{ isAuthenticated: boolean; statusMessage?: string }>; + createSession(config: Record): Promise; + resumeSession(sessionId: string, config: Record): Promise; + listSessions(filter?: Record): Promise>; + deleteSession(sessionId: string): Promise; + listModels(): Promise>; +}; + +interface PendingApproval { + routeId: string; + requestId: string; + generation: number; + timer: ReturnType | null; + resolve: (result: Record) => void; +} + +interface CopilotSessionState { + routeId: string; + sessionId: string; + session: CopilotSessionLike; + cwd: string; + model?: string; + effort?: TransportEffortLevel; + currentMessageId: string | null; + currentText: string; + completionEmittedForCurrentTurn: boolean; + currentOutputTokens?: number; + currentInteractionId?: string; + busy: boolean; + backgroundTainted: boolean; + cancelRequested: boolean; + cancelErrorEmitted: boolean; + rotationInProgress: boolean; + generation: number; + lastStatusSignature: 
string | null; + pendingApprovals: Map; + unsubscribes: Array<() => void>; +} + +function isNonEmptyString(value: unknown): value is string { + return typeof value === 'string' && value.trim().length > 0; +} + +function mapEffortToCopilot(effort: TransportEffortLevel | undefined): 'low' | 'medium' | 'high' | 'xhigh' | undefined { + switch (effort) { + case 'low': return 'low'; + case 'medium': return 'medium'; + case 'high': return 'high'; + case 'max': return 'xhigh'; + default: return undefined; + } +} + +function isCompatibleCopilotCliVersion(version: string | undefined): boolean { + if (!isNonEmptyString(version)) return false; + const match = version.trim().match(/^(\d+)\.(\d+)\.(\d+)$/); + if (!match) return false; + const major = Number(match[1]); + const minor = Number(match[2]); + const patch = Number(match[3]); + if (major !== 1) return false; + return minor > 0 || patch >= 31; +} + +function stringifyUnknown(value: unknown): string | undefined { + if (value == null) return undefined; + if (typeof value === 'string') return value; + try { + return JSON.stringify(value); + } catch { + return String(value); + } +} + +function toAttachmentPayload(attachments: TransportAttachment[] | undefined): Array> | undefined { + if (!attachments?.length) return undefined; + return attachments.map((attachment) => ({ + type: 'file', + path: attachment.daemonPath, + ...(attachment.originalName ? { displayName: attachment.originalName } : {}), + })); +} + +function buildApprovalDescription(request: Record): string { + const kind = isNonEmptyString(request.kind) ? request.kind : 'tool'; + switch (kind) { + case 'shell': { + const command = isNonEmptyString(request.fullCommandText) + ? request.fullCommandText + : isNonEmptyString(request.command) + ? request.command + : stringifyUnknown(request); + return command ? `Allow shell command: ${command}` : 'Allow shell command'; + } + case 'write': { + const filePath = isNonEmptyString(request.filePath) ? 
request.filePath : undefined; + return filePath ? `Allow file write: ${filePath}` : 'Allow file write'; + } + case 'url': { + const url = isNonEmptyString(request.url) ? request.url : undefined; + return url ? `Allow URL access: ${url}` : 'Allow URL access'; + } + case 'mcp': { + const serverName = isNonEmptyString(request.serverName) ? request.serverName : 'mcp'; + const toolName = isNonEmptyString(request.toolName) ? request.toolName : 'tool'; + return `Allow MCP tool ${serverName}:${toolName}`; + } + case 'custom-tool': { + const toolName = isNonEmptyString(request.toolName) ? request.toolName : 'custom-tool'; + return `Allow custom tool ${toolName}`; + } + case 'read': { + const filePath = isNonEmptyString(request.filePath) ? request.filePath : undefined; + return filePath ? `Allow file read: ${filePath}` : 'Allow file read'; + } + default: + return `Allow ${kind} permission request`; + } +} + +function toolFromEvent(event: Record): ToolCallEvent | null { + if (event.type === 'tool.execution_start') { + return { + id: String(event.data?.toolCallId ?? randomUUID()), + name: String(event.data?.toolName ?? 'tool'), + status: 'running', + ...(event.data?.arguments !== undefined ? { input: event.data.arguments } : {}), + detail: { + kind: 'tool.execution_start', + summary: String(event.data?.toolName ?? 'tool'), + input: event.data?.arguments, + meta: { + ...(event.data?.mcpServerName ? { mcpServerName: event.data.mcpServerName } : {}), + ...(event.data?.mcpToolName ? { mcpToolName: event.data.mcpToolName } : {}), + }, + raw: event, + }, + }; + } + if (event.type === 'tool.execution_complete') { + return { + id: String(event.data?.toolCallId ?? randomUUID()), + name: String(event.data?.toolName ?? 'tool'), + status: event.data?.success === false ? 'error' : 'complete', + ...(event.data?.result ? { output: stringifyUnknown(event.data.result.detailedContent ?? event.data.result.content ?? 
event.data.result.contents) } : {}), + detail: { + kind: 'tool.execution_complete', + summary: String(event.data?.toolName ?? 'tool'), + output: event.data?.result?.detailedContent ?? event.data?.result?.content ?? event.data?.result?.contents, + meta: { + success: event.data?.success, + model: event.data?.model, + interactionId: event.data?.interactionId, + isUserRequested: event.data?.isUserRequested, + }, + raw: event, + }, + }; + } + return null; +} + +export class CopilotSdkProvider implements TransportProvider { + readonly id = 'copilot-sdk'; + readonly connectionMode = CONNECTION_MODES.LOCAL_SDK; + readonly sessionOwnership = SESSION_OWNERSHIP.SHARED; + readonly capabilities: ProviderCapabilities = { + streaming: true, + toolCalling: true, + approval: true, + sessionRestore: true, + multiTurn: true, + attachments: true, + reasoningEffort: true, + supportedEffortLevels: ['low', 'medium', 'high', 'max'], + contextSupport: 'degraded-message-side-context-mapping', + }; + + private config: ProviderConfig | null = null; + private approvalTimeoutMs = DEFAULT_APPROVAL_TIMEOUT_MS; + private sdk: typeof import('@github/copilot-sdk') | null = null; + private client: CopilotClientLike | null = null; + private sessions = new Map(); + private poisonedSessionIds = new Set(); + private deltaCallbacks: Array<(sessionId: string, delta: MessageDelta) => void> = []; + private completeCallbacks: Array<(sessionId: string, message: AgentMessage) => void> = []; + private errorCallbacks: Array<(sessionId: string, error: ProviderError) => void> = []; + private toolCallCallbacks: Array<(sessionId: string, tool: ToolCallEvent) => void> = []; + private sessionInfoCallbacks: Array<(sessionId: string, info: SessionInfoUpdate) => void> = []; + private statusCallbacks: Array<(sessionId: string, status: ProviderStatusUpdate) => void> = []; + private approvalCallbacks: Array<(sessionId: string, req: ApprovalRequest) => void> = []; + + async connect(config: ProviderConfig): Promise { + const 
sdk = await copilotSdkRuntimeHooks.loadSdk(); + const resolvedBinary = this.resolveBinaryPath(config); + const client = new sdk.CopilotClient({ + ...(resolvedBinary ? { cliPath: resolvedBinary } : {}), + autoStart: false, + }) as unknown as CopilotClientLike; + try { + await client.start(); + const status = await client.getStatus(); + if (!isCompatibleCopilotCliVersion(status.version)) { + throw this.makeError( + PROVIDER_ERROR_CODES.CONFIG_ERROR, + `Copilot CLI ${status.version ?? 'unknown'} is outside supported range ${COMPATIBLE_CLI_RANGE}`, + false, + status, + ); + } + if (typeof status.protocolVersion !== 'number' || status.protocolVersion < MIN_PROTOCOL_VERSION) { + throw this.makeError( + PROVIDER_ERROR_CODES.CONFIG_ERROR, + `Copilot SDK protocol ${status.protocolVersion ?? 'unknown'} is below required ${MIN_PROTOCOL_VERSION} (tested with CLI ${COMPATIBLE_CLI_RANGE})`, + false, + status, + ); + } + const auth = await client.getAuthStatus(); + if (!auth.isAuthenticated) { + throw this.makeError( + PROVIDER_ERROR_CODES.AUTH_FAILED, + auth.statusMessage || 'Copilot is not authenticated', + false, + auth, + ); + } + try { + await client.listModels(); + } catch (error) { + logger.warn({ provider: this.id, error }, 'Copilot listModels probe failed — continuing with connect'); + } + this.sdk = sdk; + this.client = client; + this.config = config; + this.approvalTimeoutMs = this.resolveApprovalTimeoutMs(config); + logger.info({ provider: this.id, binary: resolvedBinary ?? 
'default' }, 'Copilot SDK provider connected'); + } catch (error) { + try { await client.stop(); } catch {} + if (this.isProviderError(error)) throw error; + throw this.normalizeConnectError(error); + } + } + + async disconnect(): Promise { + for (const state of this.sessions.values()) { + state.unsubscribes.forEach((fn) => fn()); + try { await state.session.disconnect?.(); } catch {} + for (const pending of state.pendingApprovals.values()) { + if (pending.timer) clearTimeout(pending.timer); + pending.resolve({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + } + state.pendingApprovals.clear(); + } + this.sessions.clear(); + this.poisonedSessionIds.clear(); + if (this.client) { + try { await this.client.stop(); } catch {} + } + this.client = null; + this.sdk = null; + this.config = null; + } + + async createSession(config: SessionConfig): Promise { + this.assertConnected(); + const routeId = config.bindExistingKey ?? config.sessionKey; + const existing = this.sessions.get(routeId); + if (existing && !config.fresh) { + if (isNonEmptyString(config.agentId)) existing.model = config.agentId; + if (isNonEmptyString(config.resumeId) && config.resumeId !== existing.sessionId) { + await this.replaceSession(existing, config.resumeId); + } + this.emitSessionInfo(routeId, { + resumeId: existing.sessionId, + ...(existing.model ? { model: existing.model } : {}), + ...(existing.effort ? { effort: existing.effort } : {}), + }); + return routeId; + } + if (existing && config.fresh) { + await this.endSession(routeId); + } + + const model = isNonEmptyString(config.agentId) ? config.agentId : this.resolveDefaultModel(); + const effort = config.effort; + const session = config.skipCreate && isNonEmptyString(config.resumeId) + ? 
await this.resumeSdkSession(config.resumeId, config, model, effort) + : await this.createSdkSession(config, model, effort); + const state: CopilotSessionState = { + routeId, + sessionId: session.sessionId, + session, + cwd: isNonEmptyString(config.cwd) ? config.cwd : process.cwd(), + model, + effort, + currentMessageId: null, + currentText: '', + completionEmittedForCurrentTurn: false, + currentOutputTokens: undefined, + currentInteractionId: undefined, + busy: false, + backgroundTainted: false, + cancelRequested: false, + cancelErrorEmitted: false, + rotationInProgress: false, + generation: 0, + lastStatusSignature: null, + pendingApprovals: new Map(), + unsubscribes: [], + }; + this.sessions.set(routeId, state); + this.attachSession(state); + this.emitSessionInfo(routeId, { + resumeId: session.sessionId, + ...(model ? { model } : {}), + ...(effort ? { effort } : {}), + }); + return routeId; + } + + async endSession(sessionId: string): Promise { + const state = this.getSessionState(sessionId); + if (!state) return; + state.unsubscribes.forEach((fn) => fn()); + state.unsubscribes = []; + for (const pending of state.pendingApprovals.values()) { + if (pending.timer) clearTimeout(pending.timer); + pending.resolve({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + } + state.pendingApprovals.clear(); + try { await state.session.disconnect?.(); } catch {} + this.sessions.delete(state.routeId); + } + + onDelta(cb: (sessionId: string, delta: MessageDelta) => void): () => void { + this.deltaCallbacks.push(cb); + return () => { + const idx = this.deltaCallbacks.indexOf(cb); + if (idx >= 0) this.deltaCallbacks.splice(idx, 1); + }; + } + + onComplete(cb: (sessionId: string, message: AgentMessage) => void): () => void { + this.completeCallbacks.push(cb); + return () => { + const idx = this.completeCallbacks.indexOf(cb); + if (idx >= 0) this.completeCallbacks.splice(idx, 1); + }; + } + + onError(cb: (sessionId: string, error: ProviderError) => void): () => 
void { + this.errorCallbacks.push(cb); + return () => { + const idx = this.errorCallbacks.indexOf(cb); + if (idx >= 0) this.errorCallbacks.splice(idx, 1); + }; + } + + onToolCall(cb: (sessionId: string, tool: ToolCallEvent) => void): void { + this.toolCallCallbacks.push(cb); + } + + onSessionInfo(cb: (sessionId: string, info: SessionInfoUpdate) => void): () => void { + this.sessionInfoCallbacks.push(cb); + return () => { + const idx = this.sessionInfoCallbacks.indexOf(cb); + if (idx >= 0) this.sessionInfoCallbacks.splice(idx, 1); + }; + } + + onStatus(cb: (sessionId: string, status: ProviderStatusUpdate) => void): () => void { + this.statusCallbacks.push(cb); + return () => { + const idx = this.statusCallbacks.indexOf(cb); + if (idx >= 0) this.statusCallbacks.splice(idx, 1); + }; + } + + onApprovalRequest(cb: (sessionId: string, req: ApprovalRequest) => void): void { + this.approvalCallbacks.push(cb); + } + + async respondApproval(sessionId: string, requestId: string, approved: boolean): Promise { + const state = this.getSessionState(sessionId); + if (!state) { + throw this.makeError(PROVIDER_ERROR_CODES.SESSION_NOT_FOUND, `Unknown Copilot session: ${sessionId}`, false); + } + const pending = state.pendingApprovals.get(requestId); + if (!pending) { + throw this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, `Unknown approval request: ${requestId}`, true); + } + state.pendingApprovals.delete(requestId); + if (pending.timer) clearTimeout(pending.timer); + pending.resolve(approved + ? { kind: 'approved' } + : { kind: 'denied-interactively-by-user' }); + this.emitStatus(state.routeId, { status: null, label: null }); + } + + setSessionAgentId(sessionId: string, agentId: string): void { + const state = this.getSessionState(sessionId); + if (!state) return; + state.model = agentId; + this.emitSessionInfo(state.routeId, { resumeId: state.sessionId, model: agentId }); + void state.session.setModel(agentId, { + ...(mapEffortToCopilot(state.effort) ? 
{ reasoningEffort: mapEffortToCopilot(state.effort) } : {}), + }).catch((error) => { + logger.warn({ err: error, provider: this.id, sessionId: state.routeId }, 'Failed to update Copilot session model'); + }); + } + + setSessionEffort(sessionId: string, effort: TransportEffortLevel): void { + const state = this.getSessionState(sessionId); + if (!state) return; + state.effort = effort; + this.emitSessionInfo(state.routeId, { resumeId: state.sessionId, effort }); + if (!state.model) return; + void state.session.setModel(state.model, { + ...(mapEffortToCopilot(effort) ? { reasoningEffort: mapEffortToCopilot(effort) } : {}), + }).catch((error) => { + logger.warn({ err: error, provider: this.id, sessionId: state.routeId }, 'Failed to update Copilot session effort'); + }); + } + + async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, attachments?: TransportAttachment[], extraSystemPrompt?: string): Promise { + const state = this.getSessionState(sessionId); + if (!state) { + throw this.makeError(PROVIDER_ERROR_CODES.SESSION_NOT_FOUND, `Unknown Copilot session: ${sessionId}`, false); + } + if (state.busy) { + throw this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, 'Copilot session is already busy', true); + } + const payload = normalizeProviderPayload(payloadOrMessage, attachments, extraSystemPrompt); + const prompt = [payload.systemText?.trim(), payload.assembledMessage?.trim()].filter(Boolean).join('\n\n'); + const sdkAttachments = toAttachmentPayload(payload.attachments); + state.currentMessageId = null; + state.currentText = ''; + state.completionEmittedForCurrentTurn = false; + state.currentOutputTokens = undefined; + state.currentInteractionId = undefined; + state.backgroundTainted = false; + state.cancelRequested = false; + state.cancelErrorEmitted = false; + state.rotationInProgress = false; + state.busy = true; + try { + if (state.model) { + await state.session.setModel(state.model, { + ...(mapEffortToCopilot(state.effort) ? 
{ reasoningEffort: mapEffortToCopilot(state.effort) } : {}), + }); + } + await state.session.send({ + prompt, + ...(sdkAttachments ? { attachments: sdkAttachments } : {}), + mode: 'immediate', + }); + } catch (error) { + state.busy = false; + throw error; + } + } + + async cancel(sessionId: string): Promise { + const state = this.getSessionState(sessionId); + if (!state) return; + state.cancelRequested = true; + try { + await state.session.abort(); + } finally { + state.busy = false; + if (!state.cancelErrorEmitted) { + state.cancelErrorEmitted = true; + this.emitError(state.routeId, this.makeError(PROVIDER_ERROR_CODES.CANCELLED, 'Copilot turn cancelled', true)); + } + } + if (!state.backgroundTainted) return; + await this.rotatePoisonedSession(state); + } + + async restoreSession(sessionId: string): Promise { + if (this.poisonedSessionIds.has(sessionId)) return false; + if (this.getSessionState(sessionId)) return true; + const sessions = await this.listSessions(); + return sessions.some((session) => session.key === sessionId); + } + + async listSessions(): Promise { + const client = this.assertConnected(); + const sessions = await client.listSessions(); + return sessions + .filter((session) => !this.poisonedSessionIds.has(session.sessionId)) + .map((session) => ({ + key: session.sessionId, + ...(session.summary ? { displayName: session.summary } : {}), + ...(session.modifiedTime ? 
{ updatedAt: new Date(session.modifiedTime).getTime() } : {}), + })); + } + + private async createSdkSession(config: SessionConfig, model?: string, effort?: TransportEffortLevel): Promise { + const client = this.assertConnected(); + return client.createSession(this.buildSessionConfig(config, model, effort)); + } + + private async resumeSdkSession(sessionId: string, config: SessionConfig, model?: string, effort?: TransportEffortLevel): Promise { + const client = this.assertConnected(); + return client.resumeSession(sessionId, this.buildSessionConfig(config, model, effort)); + } + + private buildSessionConfig(config: SessionConfig, model?: string, effort?: TransportEffortLevel): Record { + return { + workingDirectory: config.cwd, + ...(model ? { model } : {}), + ...(mapEffortToCopilot(effort) ? { reasoningEffort: mapEffortToCopilot(effort) } : {}), + onPermissionRequest: (request: Record) => this.handlePermissionRequest(config.bindExistingKey ?? config.sessionKey, request), + }; + } + + private attachSession(state: CopilotSessionState): void { + state.unsubscribes.forEach((fn) => fn()); + state.unsubscribes = []; + const generation = ++state.generation; + const unsubscribe = state.session.on((event: Record) => { + if (!this.isCurrentGeneration(state, generation)) return; + this.handleSessionEvent(state, generation, event); + }); + state.unsubscribes.push(unsubscribe); + } + + private handleSessionEvent(state: CopilotSessionState, generation: number, event: Record): void { + if (!this.isCurrentGeneration(state, generation)) return; + const routeId = state.routeId; + if (state.cancelRequested && this.shouldIgnoreCancelledEvent(event.type)) { + return; + } + switch (event.type) { + case 'assistant.message_delta': { + const chunk = String(event.data?.deltaContent ?? ''); + if (!chunk) return; + state.currentMessageId = String(event.data?.messageId ?? state.currentMessageId ?? 
randomUUID()); + state.currentText += chunk; + const delta: MessageDelta = { + messageId: state.currentMessageId, + type: 'text', + delta: state.currentText, + role: 'assistant', + }; + for (const cb of this.deltaCallbacks) cb(routeId, delta); + return; + } + case 'assistant.message': { + state.currentMessageId = String(event.data?.messageId ?? state.currentMessageId ?? randomUUID()); + const toolRequests = Array.isArray(event.data?.toolRequests) ? event.data.toolRequests : []; + const content = String(event.data?.content ?? state.currentText ?? ''); + if (content && (!state.currentText || content.length >= state.currentText.length || content.startsWith(state.currentText))) { + state.currentText = content; + } + if (!state.currentText && toolRequests.length === 0) { + state.currentText = content; + } + if (typeof event.data?.outputTokens === 'number') { + state.currentOutputTokens = event.data.outputTokens; + } + if (isNonEmptyString(event.data?.interactionId)) { + state.currentInteractionId = event.data.interactionId; + } + return; + } + case 'assistant.usage': { + if (typeof event.data?.outputTokens === 'number') { + state.currentOutputTokens = event.data.outputTokens; + } + if (isNonEmptyString(event.data?.interactionId)) { + state.currentInteractionId = event.data.interactionId; + } + return; + } + case 'tool.execution_start': { + const tool = toolFromEvent(event); + if (tool) { + const args = event.data?.arguments; + const toolName = String(event.data?.toolName ?? 
'').toLowerCase(); + if ((toolName === 'bash' || toolName === 'shell' || toolName === 'terminal') && this.looksBackgroundTainted(args)) { + this.markBackgroundTainted(state); + } + for (const cb of this.toolCallCallbacks) cb(routeId, tool); + } + return; + } + case 'tool.execution_complete': { + const tool = toolFromEvent(event); + if (tool) { + for (const cb of this.toolCallCallbacks) cb(routeId, tool); + } + return; + } + case 'session.background_tasks_changed': { + this.markBackgroundTainted(state); + return; + } + case 'system.notification': { + const kindType = String(event.data?.kind?.type ?? ''); + if (kindType === 'shell_detached_completed') { + this.markBackgroundTainted(state); + } + return; + } + case 'session.idle': { + state.busy = false; + if (state.cancelRequested && !state.cancelErrorEmitted) { + state.cancelErrorEmitted = true; + this.emitError(routeId, this.makeError(PROVIDER_ERROR_CODES.CANCELLED, 'Copilot turn cancelled', true)); + return; + } + if (!state.completionEmittedForCurrentTurn && state.currentMessageId && state.currentText) { + state.completionEmittedForCurrentTurn = true; + const message: AgentMessage = { + id: state.currentMessageId, + sessionId: routeId, + kind: 'text', + role: 'assistant', + content: state.currentText, + timestamp: Date.now(), + status: 'complete', + metadata: { + ...(state.model ? { model: state.model } : {}), + ...(typeof state.currentOutputTokens === 'number' + ? { usage: { output_tokens: state.currentOutputTokens } } + : {}), + ...(state.currentInteractionId ? { interactionId: state.currentInteractionId } : {}), + resumeId: state.sessionId, + }, + }; + for (const cb of this.completeCallbacks) cb(routeId, message); + } + return; + } + case 'session.error': { + state.busy = false; + const error = this.makeError( + PROVIDER_ERROR_CODES.PROVIDER_ERROR, + String(event.data?.message ?? 
'Copilot session error'), + false, + event, + ); + for (const cb of this.errorCallbacks) cb(routeId, error); + return; + } + default: + return; + } + } + + private async handlePermissionRequest( + routeId: string, + request: Record, + generationOverride?: number, + ): Promise> { + const state = this.getSessionState(routeId); + if (!state) { + return { kind: 'denied-no-approval-rule-and-could-not-request-from-user' }; + } + if (state.cancelRequested) { + return { kind: 'denied-interactively-by-user', feedback: 'Session is cancelling' }; + } + const requestId = randomUUID(); + const generation = generationOverride ?? state.generation; + if (!this.approvalCallbacks.length) { + return { kind: 'denied-no-approval-rule-and-could-not-request-from-user' }; + } + this.emitStatus(routeId, { status: 'permission', label: 'Waiting for approval' }); + return await new Promise>((resolve) => { + const timer = setTimeout(() => { + const pending = state.pendingApprovals.get(requestId); + if (!pending || pending.generation !== generation) return; + state.pendingApprovals.delete(requestId); + pending.resolve({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + this.emitStatus(routeId, { status: null, label: null }); + }, this.approvalTimeoutMs); + state.pendingApprovals.set(requestId, { routeId, requestId, generation, timer, resolve }); + const approvalRequest: ApprovalRequest = { + id: requestId, + description: buildApprovalDescription(request), + ...(isNonEmptyString(request.kind) ? 
{ tool: request.kind } : {}), + }; + for (const cb of this.approvalCallbacks) cb(routeId, approvalRequest); + }); + } + + private async rotatePoisonedSession(state: CopilotSessionState): Promise { + if (state.rotationInProgress || this.poisonedSessionIds.has(state.sessionId)) return; + state.rotationInProgress = true; + const oldSessionId = state.sessionId; + const oldSession = state.session; + this.poisonedSessionIds.add(oldSessionId); + state.unsubscribes.forEach((fn) => fn()); + state.unsubscribes = []; + for (const pending of state.pendingApprovals.values()) { + if (pending.timer) clearTimeout(pending.timer); + pending.resolve({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + } + state.pendingApprovals.clear(); + try { + const freshSession = await this.createSdkSession({ + sessionKey: state.routeId, + cwd: state.cwd, + agentId: state.model, + effort: state.effort, + }, state.model, state.effort); + state.session = freshSession; + state.sessionId = freshSession.sessionId; + state.currentMessageId = null; + state.currentText = ''; + state.completionEmittedForCurrentTurn = false; + state.currentOutputTokens = undefined; + state.currentInteractionId = undefined; + state.busy = false; + state.backgroundTainted = false; + state.cancelRequested = false; + state.cancelErrorEmitted = false; + this.attachSession(state); + this.emitSessionInfo(state.routeId, { + resumeId: state.sessionId, + ...(state.model ? { model: state.model } : {}), + ...(state.effort ? 
{ effort: state.effort } : {}), + }); + } finally { + state.rotationInProgress = false; + } + try { + await oldSession.disconnect?.(); + } catch {} + try { + await this.assertConnected().deleteSession(oldSessionId); + } catch (error) { + this.emitStatus(state.routeId, { + status: 'warning', + label: 'Previous Copilot session could not be deleted', + }); + logger.warn({ err: error, provider: this.id, sessionId: oldSessionId }, 'Failed to delete poisoned Copilot session'); + } + } + + private async replaceSession(state: CopilotSessionState, resumeId: string): Promise { + const oldSessionId = state.sessionId; + const oldSession = state.session; + const resumed = await this.resumeSdkSession(resumeId, { + sessionKey: state.routeId, + cwd: state.cwd, + agentId: state.model, + effort: state.effort, + resumeId, + skipCreate: true, + }, state.model, state.effort); + for (const pending of state.pendingApprovals.values()) { + if (pending.timer) clearTimeout(pending.timer); + pending.resolve({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + } + state.pendingApprovals.clear(); + state.unsubscribes.forEach((fn) => fn()); + state.unsubscribes = []; + state.session = resumed; + state.sessionId = resumed.sessionId; + state.currentMessageId = null; + state.currentText = ''; + state.completionEmittedForCurrentTurn = false; + state.currentOutputTokens = undefined; + state.currentInteractionId = undefined; + state.busy = false; + state.backgroundTainted = false; + state.cancelRequested = false; + state.cancelErrorEmitted = false; + state.rotationInProgress = false; + this.attachSession(state); + try { + await oldSession.disconnect?.(); + } catch {} + if (oldSessionId !== state.sessionId) { + this.poisonedSessionIds.add(oldSessionId); + try { + await this.assertConnected().deleteSession(oldSessionId); + } catch (error) { + this.emitStatus(state.routeId, { + status: 'warning', + label: 'Previous Copilot session could not be deleted', + }); + logger.warn({ err: 
error, provider: this.id, sessionId: oldSessionId }, 'Failed to delete replaced Copilot session'); + } + } + } + + private getSessionState(sessionId: string): CopilotSessionState | undefined { + const direct = this.sessions.get(sessionId); + if (direct) return direct; + for (const state of this.sessions.values()) { + if (state.sessionId === sessionId) return state; + } + return undefined; + } + + private isCurrentGeneration(state: CopilotSessionState, generation: number): boolean { + return state.generation === generation && !this.poisonedSessionIds.has(state.sessionId); + } + + private emitSessionInfo(sessionId: string, info: SessionInfoUpdate): void { + for (const cb of this.sessionInfoCallbacks) cb(sessionId, info); + } + + private emitStatus(sessionId: string, status: ProviderStatusUpdate): void { + const signature = JSON.stringify(status); + const state = this.sessions.get(sessionId); + if (state && state.lastStatusSignature === signature) return; + if (state) state.lastStatusSignature = signature; + for (const cb of this.statusCallbacks) cb(sessionId, status); + } + + private emitError(sessionId: string, error: ProviderError): void { + for (const cb of this.errorCallbacks) cb(sessionId, error); + } + + private resolveBinaryPath(config: ProviderConfig): string | undefined { + if (isNonEmptyString(config.binaryPath)) return config.binaryPath.trim(); + return resolveBinaryWithWindowsFallbacks(COPILOT_BIN, []); + } + + private resolveDefaultModel(): string | undefined { + return this.config && isNonEmptyString(this.config.agentId) ? 
this.config.agentId : undefined; + } + + private resolveApprovalTimeoutMs(config: ProviderConfig): number { + const candidate = config.approvalTimeoutMs; + if (typeof candidate === 'number' && Number.isFinite(candidate) && candidate > 0) { + return Math.floor(candidate); + } + return DEFAULT_APPROVAL_TIMEOUT_MS; + } + + private looksBackgroundTainted(args: unknown): boolean { + if (!args || typeof args !== 'object') return false; + const record = args as Record; + const command = isNonEmptyString(record.command) ? record.command.toLowerCase() : ''; + return record.mode === 'async' + || record.background === true + || record.detached === true + || record.runInBackground === true + || record.isBackground === true + || /(^|\s)nohup(\s|$)/.test(command) + || /(^|\s)disown(\s|$)/.test(command) + || /(^|\s)start\s+\/b(\s|$)/.test(command) + || /(^|\s)start-process(\s|$)/.test(command) + || /(^|[^&])&(\s|$)/.test(command); + } + + private shouldIgnoreCancelledEvent(type: string): boolean { + return type !== 'session.idle' + && type !== 'session.background_tasks_changed' + && type !== 'system.notification' + && type !== 'tool.execution_start'; + } + + private markBackgroundTainted(state: CopilotSessionState): void { + state.backgroundTainted = true; + if (state.cancelRequested && !state.rotationInProgress && !this.poisonedSessionIds.has(state.sessionId)) { + void this.rotatePoisonedSession(state).catch((error) => { + logger.error({ err: error, provider: this.id, sessionId: state.routeId }, 'Failed to rotate poisoned Copilot session'); + this.emitError(state.routeId, this.makeError( + PROVIDER_ERROR_CODES.PROVIDER_ERROR, + 'Failed to rotate poisoned Copilot session after cancel', + false, + error, + )); + }); + } + } + + private assertConnected(): CopilotClientLike { + if (!this.client) { + throw this.makeError(PROVIDER_ERROR_CODES.CONNECTION_LOST, 'Copilot SDK provider not connected', false); + } + return this.client; + } + + private normalizeConnectError(error: unknown): 
/** Loosely-typed JSON object as read from one line of Cursor's stream output. */
type CursorRecord = Record<string, unknown>;

/** Session bootstrap record (`system.init`, or `system` with subtype `init`). */
export interface CursorSessionInitEvent {
  kind: 'session.init';
  raw: CursorRecord;
  sessionId?: string;
  model?: string;
  permissionMode?: string;
}

/** Incremental assistant text. */
export interface CursorAssistantDeltaEvent {
  kind: 'assistant.delta';
  raw: CursorRecord;
  sessionId?: string;
  messageId?: string;
  text: string;
}

/** Complete assistant message text. */
export interface CursorAssistantFinalEvent {
  kind: 'assistant.final';
  raw: CursorRecord;
  sessionId?: string;
  messageId?: string;
  text: string;
}

/** A tool invocation has begun. */
export interface CursorToolStartedEvent {
  kind: 'tool.started';
  raw: CursorRecord;
  sessionId?: string;
  id: string;
  name: string;
  input?: unknown;
}

/** A tool invocation has finished. */
export interface CursorToolCompletedEvent {
  kind: 'tool.completed';
  raw: CursorRecord;
  sessionId?: string;
  id: string;
  name: string;
  input?: unknown;
  output?: unknown;
}

/** Terminal record of a successful run. */
export interface CursorResultSuccessEvent {
  kind: 'result.success';
  raw: CursorRecord;
  sessionId?: string;
  model?: string;
  text?: string;
  usage?: Record<string, unknown>;
}

/** Terminal record of a failed run. */
export interface CursorResultErrorEvent {
  kind: 'result.error';
  raw: CursorRecord;
  sessionId?: string;
  message: string;
}

/** Placeholder for records that match no known shape. */
export interface CursorUnknownEvent {
  kind: 'unknown';
  raw: unknown;
}

export type CursorParsedEvent =
  | CursorSessionInitEvent
  | CursorAssistantDeltaEvent
  | CursorAssistantFinalEvent
  | CursorToolStartedEvent
  | CursorToolCompletedEvent
  | CursorResultSuccessEvent
  | CursorResultErrorEvent
  | CursorUnknownEvent;

// Failure markers shared by isSuccessResult and isErrorResult so the two
// predicates can never both claim the same record.
const FAILURE_STATUS_PATTERN = /error|failed|cancel/i;
const FAILURE_SUBTYPE_PATTERN = /error|failed/i;

/** True for non-null, non-array objects. */
function isRecord(value: unknown): value is CursorRecord {
  return !!value && typeof value === 'object' && !Array.isArray(value);
}

/** Returns the trimmed value of the first listed key that holds a non-blank string. */
function pickString(record: CursorRecord, ...keys: string[]): string | undefined {
  for (const key of keys) {
    const value = record[key];
    if (typeof value === 'string' && value.trim()) return value.trim();
  }
  return undefined;
}

/** Narrows `value` to a record, or yields `undefined`. */
function pickRecord(value: unknown): CursorRecord | undefined {
  return isRecord(value) ? value : undefined;
}

/**
 * Extracts text from a message `content` field.
 *
 * A non-blank string is returned as-is (not trimmed). An array is treated as a
 * list of blocks whose string `text` fields are concatenated in order.
 * Anything else, or an array with no text, yields `undefined`.
 */
function extractTextFromContent(content: unknown): string | undefined {
  if (typeof content === 'string' && content.trim()) return content;
  if (!Array.isArray(content)) return undefined;
  const parts = content
    .map((block) => (isRecord(block) && typeof block.text === 'string' ? block.text : ''))
    .filter(Boolean);
  return parts.length > 0 ? parts.join('') : undefined;
}

/** Reads tool id/name/input/output from any of the key spellings the stream uses. */
function extractToolPayload(record: CursorRecord): { id?: string; name?: string; input?: unknown; output?: unknown } {
  const id = pickString(record, 'id', 'tool_call_id', 'toolCallId', 'toolId');
  const name = pickString(record, 'name', 'tool', 'tool_name', 'toolName');
  const input = record.input ?? record.arguments ?? record.params ?? record.payload;
  const output = record.output ?? record.result ?? record.stdout ?? record.aggregated_output ?? record.aggregatedOutput;
  return { id, name, input, output };
}

function extractMessageId(record: CursorRecord): string | undefined {
  return pickString(record, 'message_id', 'messageId', 'id');
}

function extractSessionId(record: CursorRecord, fallback?: string): string | undefined {
  return pickString(record, 'session_id', 'sessionId') ?? fallback;
}

function extractModel(record: CursorRecord): string | undefined {
  return pickString(record, 'model', 'agent');
}

function extractPermissionMode(record: CursorRecord): string | undefined {
  return pickString(record, 'permissionMode', 'permission_mode');
}

/** True when `status` or `subtype` explicitly names a failure. */
function hasFailureMarker(record: CursorRecord): boolean {
  if (typeof record.status === 'string' && FAILURE_STATUS_PATTERN.test(record.status)) return true;
  return typeof record.subtype === 'string' && FAILURE_SUBTYPE_PATTERN.test(record.subtype);
}

/**
 * Decides whether a `result` record reports success.
 *
 * Explicit success markers win. Otherwise an explicit failure marker vetoes
 * the type-based fallback: without that veto every plain `type: 'result'`
 * record matched the fallback and failures lacking `is_error: true` were
 * reported as successes.
 */
function isSuccessResult(record: CursorRecord): boolean {
  if (record.is_error === true) return false;
  if (typeof record.status === 'string' && /success|completed|done|ok/i.test(record.status)) return true;
  if (typeof record.subtype === 'string' && /success/i.test(record.subtype)) return true;
  if (hasFailureMarker(record)) return false;
  return typeof record.type === 'string' && /result(\.success)?$/i.test(record.type);
}

/** Decides whether a `result` record reports failure. */
function isErrorResult(record: CursorRecord): boolean {
  if (record.is_error === true) return true;
  if (hasFailureMarker(record)) return true;
  return typeof record.type === 'string' && /result\.(error|failed)$/i.test(record.type);
}

/**
 * Maps one decoded stream record onto a typed event.
 *
 * @param record Decoded JSON value; anything that is not a plain object yields `null`.
 * @param fallbackSessionId Session id to use when the record carries none.
 * @returns The parsed event, or `null` for records that are ignored (user
 *   echoes, empty assistant text, tool records without id/name, unknown types).
 */
function parseCursorRecord(record: unknown, fallbackSessionId?: string): CursorParsedEvent | null {
  if (!isRecord(record)) return null;
  const sessionId = extractSessionId(record, fallbackSessionId);
  const model = extractModel(record);
  const permissionMode = extractPermissionMode(record);
  const streamEvent = pickRecord(record.event);

  const type = typeof record.type === 'string' ? record.type : '';
  const subtype = typeof record.subtype === 'string' ? record.subtype : '';

  if (type === 'system.init' || (type === 'system' && subtype === 'init')) {
    return { kind: 'session.init', raw: record, sessionId, model, permissionMode };
  }

  if (type === 'assistant') {
    const message = pickRecord(record.message);
    const text = extractTextFromContent(message?.content ?? record.text ?? record.content);
    if (!text) return null;
    return {
      kind: 'assistant.final',
      raw: record,
      sessionId,
      messageId: extractMessageId(message ?? record),
      text,
    };
  }

  // Echoes of the user's own input carry nothing to surface.
  if (type === 'user') return null;

  if (
    type === 'tool_call.started'
    || type === 'tool.started'
    || (type === 'tool_call' && subtype === 'started')
  ) {
    const tool = extractToolPayload(record);
    if (!tool.id || !tool.name) return null;
    return {
      kind: 'tool.started',
      raw: record,
      sessionId,
      id: tool.id,
      name: tool.name,
      ...(tool.input !== undefined ? { input: tool.input } : {}),
    };
  }

  if (
    type === 'tool_call.completed'
    || type === 'tool.completed'
    || (type === 'tool_call' && subtype === 'completed')
  ) {
    const tool = extractToolPayload(record);
    if (!tool.id || !tool.name) return null;
    return {
      kind: 'tool.completed',
      raw: record,
      sessionId,
      id: tool.id,
      name: tool.name,
      ...(tool.input !== undefined ? { input: tool.input } : {}),
      ...(tool.output !== undefined ? { output: tool.output } : {}),
    };
  }

  if (type === 'assistant.delta') {
    const text = extractTextFromContent(record.delta ?? record.text ?? record.content);
    if (!text) return null;
    return {
      kind: 'assistant.delta',
      raw: record,
      sessionId,
      messageId: extractMessageId(record),
      text,
    };
  }

  if (type === 'assistant.final') {
    const message = pickRecord(record.message);
    const text = extractTextFromContent(record.text ?? record.content ?? message?.content);
    if (!text) return null;
    return {
      kind: 'assistant.final',
      raw: record,
      sessionId,
      messageId: extractMessageId(record) ?? extractMessageId(message ?? {}),
      text,
    };
  }

  if (type === 'result.success' || (type === 'result' && isSuccessResult(record))) {
    const nestedMessage = pickRecord(record.message);
    const resultText =
      extractTextFromContent(record.result)
      ?? extractTextFromContent(record.text)
      ?? extractTextFromContent(nestedMessage?.content)
      ?? (typeof record.result === 'string' ? record.result : undefined);
    const usage = pickRecord(record.usage) ?? pickRecord(nestedMessage?.usage);
    return {
      kind: 'result.success',
      raw: record,
      sessionId,
      model,
      ...(resultText ? { text: resultText } : {}),
      ...(usage ? { usage } : {}),
    };
  }

  if (type === 'result.error' || (type === 'result' && isErrorResult(record))) {
    const errorRecord = pickRecord(record.error);
    // pickString validates the nested message at runtime; a bare cast would
    // let non-string values through into a field typed `string`.
    const message =
      pickString(record, 'message', 'error')
      ?? (errorRecord ? pickString(errorRecord, 'message') : undefined)
      ?? 'Cursor execution failed';
    return { kind: 'result.error', raw: record, sessionId, message };
  }

  if (type === 'stream_event' && streamEvent) {
    if (streamEvent.type === 'content_block_delta') {
      const delta = pickRecord(streamEvent.delta);
      if (delta?.type === 'text_delta' && typeof delta.text === 'string') {
        return { kind: 'assistant.delta', raw: record, sessionId, text: delta.text };
      }
    }

    if (streamEvent.type === 'content_block_start') {
      const contentBlock = pickRecord(streamEvent.content_block);
      if (contentBlock?.type === 'tool_use') {
        const tool = extractToolPayload(contentBlock);
        if (!tool.id || !tool.name) return null;
        return {
          kind: 'tool.started',
          raw: record,
          sessionId,
          id: tool.id,
          name: tool.name,
          ...(tool.input !== undefined ? { input: tool.input } : {}),
        };
      }
    }
  }

  return null;
}

/**
 * Parses one line of Cursor stream output.
 *
 * @param line Raw line of text; surrounding whitespace is ignored.
 * @returns The parsed event, or `null` when the line is blank, is not valid
 *   JSON, or decodes to a record that is ignored.
 */
export function parseCursorStreamLine(line: string): CursorParsedEvent | null {
  const trimmed = line.trim();
  if (!trimmed) return null;
  let parsed: unknown;
  try {
    parsed = JSON.parse(trimmed) as unknown;
  } catch {
    return null;
  }
  return parseCursorRecord(parsed);
}
loadChildProcess: async () => import('node:child_process'), +}; + +interface CursorSessionState { + routeId: string; + resumeId: string; + cwd: string; + model?: string; + child: ChildProcess | null; + currentMessageId: string | null; + currentText: string; + pendingFinalText?: string; + pendingFinalMetadata?: Record; + cancelled: boolean; + completed: boolean; + emittedToolSignatures: Map; + lastStatusSignature: string | null; +} + +function isTruthyString(value: unknown): value is string { + return typeof value === 'string' && value.trim().length > 0; +} + +function extractString(record: Record, ...keys: string[]): string | undefined { + for (const key of keys) { + const value = record[key]; + if (isTruthyString(value)) return value.trim(); + } + return undefined; +} + +function stringifyUnknown(value: unknown): string | undefined { + if (value == null) return undefined; + if (typeof value === 'string') return value; + try { + return JSON.stringify(value); + } catch { + return String(value); + } +} + +function toProcessEnv(value: unknown): NodeJS.ProcessEnv { + if (!value || typeof value !== 'object') return {}; + return value as NodeJS.ProcessEnv; +} + +function extractResultText(event: CursorParsedEvent): string | undefined { + if (event.kind !== 'result.success') return undefined; + return event.text; +} + +export class CursorHeadlessProvider implements TransportProvider { + readonly id = 'cursor-headless'; + readonly connectionMode = CONNECTION_MODES.LOCAL_SDK; + readonly sessionOwnership = SESSION_OWNERSHIP.SHARED; + readonly capabilities: ProviderCapabilities = { + streaming: true, + toolCalling: true, + approval: false, + sessionRestore: true, + multiTurn: true, + attachments: false, + reasoningEffort: false, + contextSupport: 'degraded-message-side-context-mapping', + }; + + private config: ProviderConfig | null = null; + private sessions = new Map(); + private deltaCallbacks: Array<(sessionId: string, delta: MessageDelta) => void> = []; + private 
completeCallbacks: Array<(sessionId: string, message: AgentMessage) => void> = []; + private errorCallbacks: Array<(sessionId: string, error: ProviderError) => void> = []; + private toolCallCallbacks: Array<(sessionId: string, tool: ToolCallEvent) => void> = []; + private sessionInfoCallbacks: Array<(sessionId: string, info: SessionInfoUpdate) => void> = []; + private statusCallbacks: Array<(sessionId: string, status: ProviderStatusUpdate) => void> = []; + + async connect(config: ProviderConfig): Promise { + const resolved = resolveExecutableForSpawn(this.resolveBinaryPath(config)); + let versionOutput = ''; + try { + const versionProbe = await this.runExecFile(resolved.executable, [...resolved.prependArgs, '--version'], { + windowsHide: true, + timeout: CONNECT_PROBE_TIMEOUT_MS, + }); + versionOutput = `${versionProbe.stdout}\n${versionProbe.stderr}`.trim(); + } catch (err) { + throw this.normalizeConnectError(err, 'Cursor binary not found or not executable'); + } + const parsedVersion = this.parseCursorVersion(versionOutput); + if (!parsedVersion) { + throw this.makeError( + PROVIDER_ERROR_CODES.CONFIG_ERROR, + `Unable to parse Cursor version from probe output: ${versionOutput || 'empty output'}`, + false, + { output: versionOutput || undefined }, + ); + } + if (!this.isSupportedCursorVersion(parsedVersion)) { + throw this.makeError( + PROVIDER_ERROR_CODES.CONFIG_ERROR, + `Cursor ${parsedVersion.raw} is below required ${MIN_CURSOR_VERSION.major}.${MIN_CURSOR_VERSION.minor}.${MIN_CURSOR_VERSION.patch}`, + false, + { + actualVersion: parsedVersion.raw, + minimumVersion: `${MIN_CURSOR_VERSION.major}.${MIN_CURSOR_VERSION.minor}.${MIN_CURSOR_VERSION.patch}`, + }, + ); + } + try { + const { stdout, stderr } = await this.runExecFile(resolved.executable, [...resolved.prependArgs, 'status'], { + windowsHide: true, + timeout: CONNECT_PROBE_TIMEOUT_MS, + }); + const statusText = `${stdout}\n${stderr}`.trim(); + if 
(/not\s+logged\s+in|sign\s*in|log\s+in|logged\s+out|unauth/i.test(statusText)) { + throw this.makeError(PROVIDER_ERROR_CODES.AUTH_FAILED, `Cursor is not authenticated: ${statusText || 'status probe reported unauthenticated'}`, false, statusText); + } + if (!/logged\s+in|authenticated|signed\s+in|status:\s*ok/i.test(statusText)) { + throw this.makeError( + PROVIDER_ERROR_CODES.CONFIG_ERROR, + `Unable to determine Cursor authentication state from status probe: ${statusText || 'empty output'}`, + false, + statusText || undefined, + ); + } + } catch (err) { + if (this.isAuthProbeFailure(err)) throw this.normalizeAuthError(err); + throw this.normalizeConnectError(err, 'Cursor status probe failed'); + } + this.config = config; + logger.info({ provider: this.id, resolved: resolved.executable }, 'Cursor headless provider connected'); + } + + async disconnect(): Promise { + for (const state of this.sessions.values()) { + if (state.child && !state.child.killed) { + terminateChildProcess(state.child, CANCEL_ESCALATION_MS); + } + } + this.sessions.clear(); + this.config = null; + } + + async createSession(config: SessionConfig): Promise { + const routeId = config.bindExistingKey ?? config.sessionKey; + const existingEntry = this.findSessionByRouteId(routeId); + if (existingEntry && !config.fresh) { + const [sessionId, state] = existingEntry; + if (isTruthyString(config.agentId)) state.model = config.agentId; + this.emitSessionInfo(sessionId, { + resumeId: state.resumeId, + ...(state.model ? { model: state.model } : {}), + }); + return sessionId; + } + + if (existingEntry && config.fresh) { + await this.endSession(existingEntry[0]).catch(() => {}); + } + + const cwd = normalizeTransportCwd(config.cwd) ?? normalizeTransportCwd(process.cwd())!; + const model = isTruthyString(config.agentId) ? config.agentId : this.resolveDefaultModel(); + const resumeId = + isTruthyString(config.resumeId) + ? config.resumeId + : isTruthyString(config.bindExistingKey) + ? 
config.bindExistingKey + : config.skipCreate + ? routeId + : await this.createRemoteChat(config, model); + + const state: CursorSessionState = { + routeId, + resumeId, + cwd, + model, + child: null, + currentMessageId: null, + currentText: '', + pendingFinalText: undefined, + pendingFinalMetadata: undefined, + cancelled: false, + completed: false, + emittedToolSignatures: new Map(), + lastStatusSignature: null, + }; + this.sessions.set(routeId, state); + this.emitSessionInfo(routeId, { + resumeId, + ...(model ? { model } : {}), + }); + return routeId; + } + + async endSession(sessionId: string): Promise { + const [resolvedId, state] = this.findSessionByAnyId(sessionId) ?? []; + if (!state) return; + if (state.child && !state.child.killed) { + terminateChildProcess(state.child, CANCEL_ESCALATION_MS); + } + this.sessions.delete(resolvedId ?? sessionId); + } + + onDelta(cb: (sessionId: string, delta: MessageDelta) => void): () => void { + this.deltaCallbacks.push(cb); + return () => { + const idx = this.deltaCallbacks.indexOf(cb); + if (idx >= 0) this.deltaCallbacks.splice(idx, 1); + }; + } + + onComplete(cb: (sessionId: string, message: AgentMessage) => void): () => void { + this.completeCallbacks.push(cb); + return () => { + const idx = this.completeCallbacks.indexOf(cb); + if (idx >= 0) this.completeCallbacks.splice(idx, 1); + }; + } + + onError(cb: (sessionId: string, error: ProviderError) => void): () => void { + this.errorCallbacks.push(cb); + return () => { + const idx = this.errorCallbacks.indexOf(cb); + if (idx >= 0) this.errorCallbacks.splice(idx, 1); + }; + } + + onToolCall(cb: (sessionId: string, tool: ToolCallEvent) => void): void { + this.toolCallCallbacks.push(cb); + } + + onSessionInfo(cb: (sessionId: string, info: SessionInfoUpdate) => void): () => void { + this.sessionInfoCallbacks.push(cb); + return () => { + const idx = this.sessionInfoCallbacks.indexOf(cb); + if (idx >= 0) this.sessionInfoCallbacks.splice(idx, 1); + }; + } + + onStatus(cb: 
(sessionId: string, status: ProviderStatusUpdate) => void): () => void { + this.statusCallbacks.push(cb); + return () => { + const idx = this.statusCallbacks.indexOf(cb); + if (idx >= 0) this.statusCallbacks.splice(idx, 1); + }; + } + + setSessionAgentId(sessionId: string, agentId: string): void { + const state = this.getSessionState(sessionId); + if (!state) return; + state.model = agentId; + this.emitSessionInfo(this.findSessionIdForState(state) ?? sessionId, { + resumeId: state.resumeId, + model: agentId, + }); + } + + async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, attachments?: TransportAttachment[], extraSystemPrompt?: string): Promise { + if (!this.config) { + throw this.makeError(PROVIDER_ERROR_CODES.CONNECTION_LOST, 'Cursor headless provider not connected', false); + } + const state = this.getSessionState(sessionId); + if (!state) { + throw this.makeError(PROVIDER_ERROR_CODES.SESSION_NOT_FOUND, `Unknown Cursor session: ${sessionId}`, false); + } + if (state.child && !state.child.killed) { + throw this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, 'Cursor session is already busy', true); + } + + state.cancelled = false; + state.completed = false; + state.currentMessageId = null; + state.currentText = ''; + state.pendingFinalText = undefined; + state.pendingFinalMetadata = undefined; + state.emittedToolSignatures.clear(); + state.lastStatusSignature = null; + + const payload = normalizeProviderPayload(payloadOrMessage, attachments, extraSystemPrompt); + const prompt = this.composePrompt(payload); + const resolved = resolveExecutableForSpawn(this.resolveBinaryPath(this.config)); + const resumeId = await this.ensureResumeId(state, resolved); + const args = [ + ...resolved.prependArgs, + '-p', + ...(this.getTrustFlag() ? ['--trust'] : []), + ...(this.getForceFlag() ? ['--force'] : []), + '--output-format', + 'stream-json', + '--stream-partial-output', + '--resume', + resumeId, + ...(state.model ? 
['--model', state.model] : []), + prompt, + ]; + const { spawn } = await cursorHeadlessRuntimeHooks.loadChildProcess(); + const child = spawn(resolved.executable, args, { + cwd: state.cwd, + env: { + ...process.env, + ...toProcessEnv(this.config.env), + }, + stdio: ['ignore', 'pipe', 'pipe'], + shell: false, + windowsHide: true, + }); + state.child = child; + + let completed = false; + let sawError = false; + let stderrBuf = ''; + + const sessionKey = this.findSessionIdForState(state) ?? sessionId; + const emitError = (error: ProviderError): void => { + if (sawError || completed) return; + sawError = true; + for (const cb of this.errorCallbacks) cb(sessionKey, error); + }; + const emitDelta = (text: string): void => { + const messageId = state.currentMessageId ??= randomUUID(); + state.currentText = text; + const delta: MessageDelta = { + messageId, + type: 'text', + delta: text, + role: 'assistant', + }; + for (const cb of this.deltaCallbacks) cb(sessionKey, delta); + }; + const emitTool = (tool: ToolCallEvent): void => { + const signature = JSON.stringify({ + status: tool.status, + name: tool.name, + input: tool.input ?? null, + output: tool.output ?? null, + }); + if (state.emittedToolSignatures.get(tool.id) === signature) return; + state.emittedToolSignatures.set(tool.id, signature); + for (const cb of this.toolCallCallbacks) cb(sessionKey, tool); + }; + const emitSessionInfoUpdate = (info: SessionInfoUpdate): void => { + this.emitSessionInfo(sessionKey, info); + }; + + const rl = readline.createInterface({ input: child.stdout! }); + rl.on('line', (line) => { + const event = parseCursorStreamLine(line); + if (!event) return; + + if (event.kind === 'session.init') { + if (event.sessionId) { + state.resumeId = event.sessionId; + } + if (event.model) { + state.model = event.model; + } + emitSessionInfoUpdate({ + resumeId: state.resumeId, + ...(state.model ? 
{ model: state.model } : {}), + }); + return; + } + + if (event.kind === 'assistant.delta') { + const chunk = event.text; + if (chunk) { + const nextText = chunk.startsWith(state.currentText) + ? chunk + : state.currentText + chunk; + if (nextText !== state.currentText) { + emitDelta(nextText); + } + } + if (event.messageId) { + state.currentMessageId = event.messageId; + } + return; + } + + if (event.kind === 'assistant.final') { + if (event.messageId) { + state.currentMessageId = event.messageId; + } + state.pendingFinalText = event.text; + return; + } + + if (event.kind === 'tool.started') { + emitTool({ + id: event.id, + name: event.name, + status: 'running', + ...(event.input !== undefined ? { input: event.input } : {}), + detail: { + kind: 'tool_call.started', + summary: event.name, + input: event.input, + raw: event.raw, + }, + }); + return; + } + + if (event.kind === 'tool.completed') { + emitTool({ + id: event.id, + name: event.name, + status: 'complete', + ...(event.input !== undefined ? { input: event.input } : {}), + ...(event.output !== undefined ? { output: stringifyUnknown(event.output) } : {}), + detail: { + kind: 'tool_call.completed', + summary: event.name, + input: event.input, + output: event.output, + raw: event.raw, + }, + }); + return; + } + + if (event.kind === 'result.success') { + const finalText = extractResultText(event) ?? state.pendingFinalText ?? state.currentText; + completed = true; + state.completed = true; + state.child = null; + state.currentMessageId ??= randomUUID(); + const message: AgentMessage = { + id: state.currentMessageId, + sessionId: sessionKey, + kind: 'text', + role: 'assistant', + content: finalText ?? '', + timestamp: Date.now(), + status: 'complete', + metadata: { + ...(event.model ? { model: event.model } : {}), + ...(event.usage ? { usage: event.usage } : {}), + ...(state.resumeId ? 
{ resumeId: state.resumeId } : {}), + }, + }; + for (const cb of this.completeCallbacks) cb(sessionKey, message); + return; + } + + if (event.kind === 'result.error') { + state.completed = true; + completed = false; + state.child = null; + emitError(this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, event.message, false, event.raw)); + } + }); + + child.stderr?.on('data', (chunk: Buffer | string) => { + stderrBuf += chunk.toString(); + logger.debug({ provider: this.id, stderr: chunk.toString().trim() }, 'Cursor headless stderr'); + }); + + child.once('close', (code, signal) => { + rl.close(); + state.child = null; + if (completed || sawError) return; + if (state.cancelled) { + emitError(this.makeError(PROVIDER_ERROR_CODES.CANCELLED, 'Cursor turn cancelled', true, { code, signal })); + return; + } + const text = state.pendingFinalText ?? state.currentText; + if (typeof code === 'number' && code === 0 && text) { + completed = true; + state.completed = true; + const finalMessage: AgentMessage = { + id: state.currentMessageId ?? randomUUID(), + sessionId: sessionKey, + kind: 'text', + role: 'assistant', + content: text, + timestamp: Date.now(), + status: 'complete', + metadata: { + ...(state.resumeId ? { resumeId: state.resumeId } : {}), + ...(state.model ? { model: state.model } : {}), + }, + }; + for (const cb of this.completeCallbacks) cb(sessionKey, finalMessage); + return; + } + emitError(this.makeError( + signal || code === 0 ? PROVIDER_ERROR_CODES.PROVIDER_ERROR : PROVIDER_ERROR_CODES.PROVIDER_ERROR, + stderrBuf.trim() || `Cursor exited with code ${code ?? 'null'}${signal ? ` (${signal})` : ''}`, + false, + { code, signal, stderr: stderrBuf.trim() || undefined }, + )); + }); + + await new Promise((resolve, reject) => { + child.once('spawn', () => resolve()); + child.once('error', (err) => reject(this.normalizeConnectError(err, 'Cursor child process failed to start'))); + }); + child.on('error', (err) => { + const message = err instanceof Error ? 
err.message : String(err); + emitError(this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, message, false, err)); + }); + } + + async restoreSession(sessionId: string): Promise { + return !!this.getSessionState(sessionId); + } + + async cancel(sessionId: string): Promise { + const state = this.getSessionState(sessionId); + if (!state?.child || state.child.killed) return; + state.cancelled = true; + terminateChildProcess(state.child, CANCEL_ESCALATION_MS); + } + + private resolveBinaryPath(config: ProviderConfig | null): string { + const explicit = isTruthyString(config?.binaryPath) ? config.binaryPath.trim() : undefined; + if (explicit) return explicit; + if (process.platform === 'win32') { + const localAppData = process.env.LOCALAPPDATA; + const windowsCandidates = localAppData + ? [ + path.join(localAppData, 'cursor-agent', 'cursor-agent.exe'), + path.join(localAppData, 'cursor-agent', 'agent.exe'), + ] + : []; + return resolveBinaryWithWindowsFallbacks(CURSOR_BIN, windowsCandidates); + } + return CURSOR_BIN; + } + + private resolveDefaultModel(): string | undefined { + return isTruthyString(this.config?.agentId) ? 
this.config!.agentId : undefined; + } + + private parseCursorVersion(output: string): { major: number; minor: number; patch: number; raw: string } | null { + const match = output.match(/(\d+)\.(\d+)\.(\d+)/); + if (!match) return null; + return { + major: Number(match[1]), + minor: Number(match[2]), + patch: Number(match[3]), + raw: `${match[1]}.${match[2]}.${match[3]}`, + }; + } + + private isSupportedCursorVersion(version: { major: number; minor: number; patch: number }): boolean { + if (version.major !== MIN_CURSOR_VERSION.major) return version.major > MIN_CURSOR_VERSION.major; + if (version.minor !== MIN_CURSOR_VERSION.minor) return version.minor > MIN_CURSOR_VERSION.minor; + return version.patch >= MIN_CURSOR_VERSION.patch; + } + + private getTrustFlag(): boolean { + return this.config?.trust !== false; + } + + private getForceFlag(): boolean { + return this.config?.force !== false; + } + + private composePrompt(payload: ProviderContextPayload): string { + const parts = [payload.systemText?.trim(), payload.assembledMessage?.trim()].filter((part): part is string => !!part && part.length > 0); + return parts.join('\n\n'); + } + + private async createRemoteChat(config: SessionConfig, model?: string): Promise { + const resolved = resolveExecutableForSpawn(this.resolveBinaryPath(this.config)); + const { stdout, stderr } = await this.runExecFile(resolved.executable, [...resolved.prependArgs, 'create-chat'], { + windowsHide: true, + timeout: CONNECT_PROBE_TIMEOUT_MS, + env: { + ...process.env, + ...toProcessEnv(this.config?.env), + }, + cwd: normalizeTransportCwd(config.cwd) ?? 
normalizeTransportCwd(process.cwd())!, + }); + const chatId = this.extractChatId(stdout, stderr); + if (!chatId) { + throw this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, 'Cursor create-chat did not return a chat id', false, { stdout, stderr, model }); + } + return chatId; + } + + private extractChatId(stdout: string, stderr: string): string | undefined { + const candidates = [stdout, stderr]; + for (const chunk of candidates) { + if (!chunk) continue; + const trimmed = chunk.trim(); + if (!trimmed) continue; + try { + const parsed = JSON.parse(trimmed) as unknown; + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + const record = parsed as Record; + const sessionId = extractString(record, 'session_id', 'sessionId', 'chat_id', 'chatId', 'id'); + if (sessionId) return sessionId; + if (record.result && typeof record.result === 'object' && !Array.isArray(record.result)) { + const result = record.result as Record; + const nested = extractString(result, 'session_id', 'sessionId', 'chat_id', 'chatId', 'id'); + if (nested) return nested; + } + } + } catch { + // fall back to plain text parsing + } + const match = trimmed.match(/[A-Za-z0-9][A-Za-z0-9._:-]{6,}/); + if (match) return match[0]; + } + return undefined; + } + + private findSessionByRouteId(routeId: string): [string, CursorSessionState] | undefined { + for (const entry of this.sessions.entries()) { + if (entry[1].routeId === routeId) return entry; + } + return undefined; + } + + private findSessionByAnyId(sessionId: string): [string, CursorSessionState] | undefined { + const direct = this.sessions.get(sessionId); + if (direct) return [sessionId, direct]; + const byResumeId = [...this.sessions.entries()].find((entry) => entry[1].resumeId === sessionId); + if (byResumeId) return byResumeId; + return this.findSessionByRouteId(sessionId); + } + + private getSessionState(sessionId: string): CursorSessionState | undefined { + return this.findSessionByAnyId(sessionId)?.[1]; + } + + private 
findSessionIdForState(state: CursorSessionState): string | undefined { + for (const [sessionId, candidate] of this.sessions.entries()) { + if (candidate === state) return sessionId; + } + return undefined; + } + + private async ensureResumeId(state: CursorSessionState, resolved: { executable: string; prependArgs: string[] }): Promise { + if (isTruthyString(state.resumeId)) return state.resumeId; + const { stdout, stderr } = await this.runExecFile(resolved.executable, [...resolved.prependArgs, 'create-chat'], { + windowsHide: true, + timeout: CONNECT_PROBE_TIMEOUT_MS, + env: { + ...process.env, + ...toProcessEnv(this.config?.env), + }, + cwd: state.cwd, + }); + const chatId = this.extractChatId(stdout, stderr); + if (!chatId) { + throw this.makeError(PROVIDER_ERROR_CODES.PROVIDER_ERROR, 'Cursor create-chat did not return a chat id', false, { stdout, stderr }); + } + state.resumeId = chatId; + this.emitSessionInfo(this.findSessionIdForState(state) ?? state.routeId, { + resumeId: chatId, + ...(state.model ? { model: state.model } : {}), + }); + return chatId; + } + + private emitSessionInfo(sessionId: string, info: SessionInfoUpdate): void { + for (const cb of this.sessionInfoCallbacks) cb(sessionId, info); + } + + private async runExecFile( + executable: string, + args: string[], + options: { + windowsHide?: boolean; + timeout?: number; + env?: NodeJS.ProcessEnv; + cwd?: string; + }, + ): Promise<{ stdout: string; stderr: string }> { + const { execFile } = await cursorHeadlessRuntimeHooks.loadChildProcess(); + return await new Promise<{ stdout: string; stderr: string }>((resolve, reject) => { + execFile(executable, args, options, (error, stdout, stderr) => { + if (error) { + reject(error); + return; + } + resolve({ + stdout: typeof stdout === 'string' ? stdout : String(stdout ?? ''), + stderr: typeof stderr === 'string' ? stderr : String(stderr ?? 
''), + }); + }); + }); + } + + private normalizeConnectError(err: unknown, fallbackMessage: string): ProviderError { + const message = err instanceof Error ? err.message : String(err); + if (/ENOENT|not found|spawn .*cursor-agent/i.test(message)) { + return this.makeError(PROVIDER_ERROR_CODES.PROVIDER_NOT_FOUND, `Cursor binary not found: ${message}`, false, err); + } + if (/not\s+logged\s+in|sign\s*in|log\s+in|unauth/i.test(message)) { + return this.makeError(PROVIDER_ERROR_CODES.AUTH_FAILED, `Cursor authentication failed: ${message}`, false, err); + } + return this.makeError(PROVIDER_ERROR_CODES.CONFIG_ERROR, `${fallbackMessage}: ${message}`, false, err); + } + + private normalizeAuthError(err: unknown): ProviderError { + const message = err instanceof Error ? err.message : String(err); + return this.makeError(PROVIDER_ERROR_CODES.AUTH_FAILED, `Cursor authentication failed: ${message}`, false, err); + } + + private isAuthProbeFailure(err: unknown): boolean { + if (err && typeof err === 'object' && 'code' in err) { + const code = (err as { code?: unknown }).code; + if (code === PROVIDER_ERROR_CODES.AUTH_FAILED) return true; + } + const message = err instanceof Error ? err.message : String(err); + return /not\s+logged\s+in|sign\s*in|log\s+in|logged\s+out|unauth/i.test(message); + } + + private makeError(code: string, message: string, recoverable: boolean, details?: unknown): ProviderError { + return { code, message, recoverable, ...(details !== undefined ? 
{ details } : {}) }; + } +} diff --git a/src/agent/providers/openclaw.ts b/src/agent/providers/openclaw.ts index ddd79e1f6..9a39ab88c 100644 --- a/src/agent/providers/openclaw.ts +++ b/src/agent/providers/openclaw.ts @@ -27,6 +27,7 @@ import { } from '../transport-provider.js'; import type { AgentMessage, MessageDelta, ToolCallEvent } from '../../../shared/agent-message.js'; import type { ProviderContextPayload } from '../../../shared/context-types.js'; +import type { TransportAttachment } from '../../../shared/transport-attachments.js'; import logger from '../../util/logger.js'; import { normalizeOpenClawDisplayName } from '../openclaw-display.js'; import { OPENCLAW_THINKING_LEVELS, type TransportEffortLevel } from '../../../shared/effort-levels.js'; @@ -162,7 +163,7 @@ export class OpenClawProvider implements TransportProvider { logger.info({ provider: this.id }, 'Disconnected from OpenClaw gateway'); } - async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, _attachments?: unknown[], extraSystemPrompt?: string): Promise { + async send(sessionId: string, payloadOrMessage: string | ProviderContextPayload, _attachments?: TransportAttachment[], extraSystemPrompt?: string): Promise { const payload = normalizeProviderPayload(payloadOrMessage, _attachments, extraSystemPrompt); const ocKey = unsanitizeKey(sessionId); const thinking = this.sessionThinking.get(sessionId) ?? 
'off'; diff --git a/src/agent/providers/qwen.ts b/src/agent/providers/qwen.ts index 325d12e7c..e7697fa64 100644 --- a/src/agent/providers/qwen.ts +++ b/src/agent/providers/qwen.ts @@ -23,6 +23,7 @@ import { } from '../transport-provider.js'; import type { AgentMessage, MessageDelta } from '../../../shared/agent-message.js'; import type { ProviderContextPayload } from '../../../shared/context-types.js'; +import type { TransportAttachment } from '../../../shared/transport-attachments.js'; import { DEFAULT_TRANSPORT_EFFORT, QWEN_EFFORT_LEVELS, type TransportEffortLevel } from '../../../shared/effort-levels.js'; import logger from '../../util/logger.js'; import { inferContextWindow } from '../../util/model-context.js'; @@ -311,7 +312,7 @@ export class QwenProvider implements TransportProvider { async send( sessionId: string, payloadOrMessage: string | ProviderContextPayload, - _attachments?: unknown[], + _attachments?: TransportAttachment[], extraSystemPrompt?: string, allowResumeFallback = true, ): Promise { diff --git a/src/agent/session-manager.ts b/src/agent/session-manager.ts index 62d919534..9d7a8ac60 100644 --- a/src/agent/session-manager.ts +++ b/src/agent/session-manager.ts @@ -762,6 +762,8 @@ export interface LaunchOpts { geminiSessionId?: string; /** OpenCode session ID for `opencode -s `. */ opencodeSessionId?: string; + /** Provider-side durable resume identifier for shared local-sdk providers. */ + providerResumeId?: string; /** Qwen model ID for `qwen --model `. */ qwenModel?: string; /** Unified requested transport model for launch/restore. */ @@ -857,6 +859,8 @@ export async function relaunchSessionWithSettings( // codexSessionId and therefore use a fresh local route key on relaunch. 
&& targetAgentType !== 'claude-code-sdk' && targetAgentType !== 'codex-sdk' + && targetAgentType !== 'copilot-sdk' + && targetAgentType !== 'cursor-headless' && typeof record.providerSessionId === 'string' && record.providerSessionId.length > 0; @@ -996,6 +1000,10 @@ function wireTransportSessionInfo(runtime: TransportSessionRuntime, sessionName: next.codexSessionId = info.resumeId; changed = true; } + if ((agentType === 'cursor-headless' || agentType === 'copilot-sdk') && next.providerResumeId !== info.resumeId) { + next.providerResumeId = info.resumeId; + changed = true; + } if (agentType === 'qwen' && next.providerSessionId !== info.resumeId) { if (next.providerSessionId) unregisterProviderRoute(next.providerSessionId); next.providerSessionId = info.resumeId; @@ -1119,13 +1127,18 @@ export async function restoreTransportSessions(providerId: string): Promise | undefined; let systemPrompt: string | undefined; let transportSettings: string | Record | undefined; @@ -1155,8 +1168,8 @@ export async function restoreTransportSessions(providerId: string): Promise { let transportSystemPrompt: string | undefined; let transportSettings: string | Record | undefined; const storedRequestedModel = !opts.fresh ? existing?.requestedModel : undefined; + const storedProviderResumeId = !opts.fresh ? existing?.providerResumeId : undefined; let requestedTransportModel = opts.requestedModel ?? storedRequestedModel ?? (agentType === 'qwen' ? (opts.qwenModel ?? existing?.qwenModel) : undefined); // Preserve existing transportConfig (including supervision) when opts doesn't override. 
// Only fall through to `undefined` if nothing is set — never force `{}`, which would @@ -1333,6 +1347,13 @@ export async function launchTransportSession(opts: LaunchOpts): Promise { effectiveSkipCreate = true; } sdkDisplay = await getCodexRuntimeConfig().catch(() => ({})); + } else if (agentType === 'cursor-headless' || agentType === 'copilot-sdk') { + effectiveSessionKey = randomUUID(); + effectiveBindExistingKey = undefined; + transportResumeId = opts.providerResumeId ?? storedProviderResumeId; + if (transportResumeId) { + effectiveSkipCreate = true; + } } // Create session on provider @@ -1378,6 +1399,9 @@ export async function launchTransportSession(opts: LaunchOpts): Promise { runtimeType: RUNTIME_TYPES.TRANSPORT, providerId: provider.id, providerSessionId: runtime.providerSessionId ?? undefined, + ...((agentType === 'copilot-sdk' || agentType === 'cursor-headless') && transportResumeId + ? { providerResumeId: transportResumeId } + : {}), ...(agentType === 'claude-code-sdk' && transportResumeId ? { ccSessionId: transportResumeId } : {}), ...(agentType === 'codex-sdk' && transportResumeId ? 
{ codexSessionId: transportResumeId } : {}), contextNamespace: contextBootstrap.namespace, diff --git a/src/agent/transport-paths.ts b/src/agent/transport-paths.ts index 1f8ce9b28..383bc4673 100644 --- a/src/agent/transport-paths.ts +++ b/src/agent/transport-paths.ts @@ -1,5 +1,6 @@ import path from 'node:path'; import { existsSync, readFileSync } from 'node:fs'; +import type { ChildProcess } from 'node:child_process'; export function normalizeTransportCwd(cwd?: string): string | undefined { if (typeof cwd !== 'string' || !cwd.trim()) return undefined; @@ -49,6 +50,14 @@ export function resolveBinaryOnWindows(name: string): string { return name; } +export function resolveBinaryWithWindowsFallbacks(name: string, windowsCandidates: string[] = []): string { + if (process.platform !== 'win32') return name; + for (const candidate of windowsCandidates) { + if (candidate && existsSync(candidate)) return candidate; + } + return resolveBinaryOnWindows(name); +} + /** Result of resolving a binary that may be an npm .cmd shim. * When the resolved path is a real .exe, just `{ executable }`. * When it's a Windows .cmd shim, returns the underlying node script so @@ -94,6 +103,15 @@ export function resolveExecutableForSpawn(name: string): ResolvedExecutable { return { executable: resolved, prependArgs: [] }; } +export function terminateChildProcess(child: ChildProcess, escalationMs = 1_500): void { + if (child.killed) return; + child.kill('SIGTERM'); + const timer = setTimeout(() => { + if (!child.killed) child.kill('SIGKILL'); + }, escalationMs); + child.once('close', () => clearTimeout(timer)); +} + /** Parse an npm-generated `.cmd` shim and return the absolute path of the * node script it invokes. Returns null if the shim format isn't recognized. 
*/ export function parseNpmCmdShim(cmdPath: string): string | null { diff --git a/src/agent/transport-provider.ts b/src/agent/transport-provider.ts index e8b13e7eb..9b5f58956 100644 --- a/src/agent/transport-provider.ts +++ b/src/agent/transport-provider.ts @@ -14,6 +14,7 @@ import type { AgentMessage, MessageDelta, ToolCallEvent } from '../../shared/age import type { TransportEffortLevel } from '../../shared/effort-levels.js'; import type { SessionContextBootstrapState } from '../../shared/session-context-bootstrap.js'; import type { ProviderQuotaMeta } from '../../shared/provider-quota.js'; +import type { TransportAttachment } from '../../shared/transport-attachments.js'; import type { ProviderContextPayload, ProviderSupportClass, @@ -269,7 +270,7 @@ export interface TransportProvider { * @param message - The user's text message. * @param attachments - Optional file/image attachments (only when capabilities.attachments is true). */ - send(sessionId: string, payload: string | ProviderContextPayload, attachments?: unknown[], extraSystemPrompt?: string): Promise; + send(sessionId: string, payload: string | ProviderContextPayload, attachments?: TransportAttachment[], extraSystemPrompt?: string): Promise; /** * Best-effort cancellation of the current in-flight turn for a session. 
@@ -373,7 +374,7 @@ export interface TransportProvider { export function normalizeProviderPayload( payload: string | ProviderContextPayload, - attachments?: unknown[], + attachments?: TransportAttachment[], extraSystemPrompt?: string, ): ProviderContextPayload { if (typeof payload !== 'string') { diff --git a/src/agent/transport-runtime-assembly.ts b/src/agent/transport-runtime-assembly.ts index 0439d8b8f..5cdaf2767 100644 --- a/src/agent/transport-runtime-assembly.ts +++ b/src/agent/transport-runtime-assembly.ts @@ -1,4 +1,5 @@ import type { TransportProvider } from './transport-provider.js'; +import type { TransportAttachment } from '../../shared/transport-attachments.js'; import { selectRuntimeAuthoredContext } from './authored-context.js'; import { evaluateContextAuthority } from './context-authority.js'; import { buildContextDiagnostics } from './context-diagnostics.js'; @@ -20,7 +21,7 @@ export interface TransportRuntimeAssemblyInput { description?: string; systemPrompt?: string; messagePreamble?: string; - attachments?: unknown[]; + attachments?: TransportAttachment[]; namespace?: ContextNamespace; namespaceDiagnostics?: string[]; remoteProcessedFreshness?: 'fresh' | 'stale' | 'missing'; diff --git a/src/daemon/lifecycle.ts b/src/daemon/lifecycle.ts index 0440c7b51..6ee6751bf 100644 --- a/src/daemon/lifecycle.ts +++ b/src/daemon/lifecycle.ts @@ -733,7 +733,7 @@ async function autoReconnectProviders(): Promise { const { connectProvider, ensureProviderConnected } = await import('../agent/provider-registry.js'); const { restoreTransportSessions } = await import('../agent/session-manager.js'); - for (const providerId of ['qwen', 'claude-code-sdk', 'codex-sdk'] as const) { + for (const providerId of ['qwen', 'claude-code-sdk', 'codex-sdk', 'cursor-headless', 'copilot-sdk'] as const) { if (!listSessions().some((s) => s.runtimeType === 'transport' && s.providerId === providerId)) continue; try { await ensureProviderConnected(providerId, {}); diff --git 
a/src/daemon/transport-relay.ts b/src/daemon/transport-relay.ts index 2552232bf..f36af4eed 100644 --- a/src/daemon/transport-relay.ts +++ b/src/daemon/transport-relay.ts @@ -7,7 +7,7 @@ */ import type { TransportProvider, ProviderError, ProviderStatusUpdate } from '../agent/transport-provider.js'; import type { MessageDelta, AgentMessage, ToolCallEvent } from '../../shared/agent-message.js'; -import { TRANSPORT_MSG } from '../../shared/transport-events.js'; +import { TRANSPORT_EVENT, TRANSPORT_MSG } from '../../shared/transport-events.js'; import { resolveSessionName } from '../agent/session-manager.js'; import { timelineEmitter } from './timeline-emitter.js'; import { appendTransportEvent } from './transport-history.js'; @@ -414,6 +414,24 @@ export function wireProviderToRelay(provider: TransportProvider): void { ...(status.label !== undefined ? { label: status.label } : {}), }, { source: 'daemon', confidence: 'high' }); }); + + provider.onApprovalRequest?.((providerSid: string, request) => { + const sessionName = resolveSessionName(providerSid); + if (!sessionName) { + logger.debug({ providerSid }, 'transport-relay: unresolved route for approval — dropped'); + return; + } + + const payload = { + type: TRANSPORT_EVENT.CHAT_APPROVAL, + sessionId: sessionName, + requestId: request.id, + description: request.description, + ...(request.tool ? { tool: request.tool } : {}), + } as const; + sendToServer?.(payload); + void appendTransportEvent(sessionName, payload); + }); } /** Emit user.message through timeline when user sends to a transport session. */ diff --git a/src/store/session-store.ts b/src/store/session-store.ts index 751591b32..2d091aa03 100644 --- a/src/store/session-store.ts +++ b/src/store/session-store.ts @@ -78,6 +78,8 @@ export interface SessionRecord extends SessionContextBootstrapState { providerId?: string; /** Provider-side session ID/key. For OpenClaw this is the OC session key. 
*/ providerSessionId?: string; + /** Provider-side durable resume/session identifier for shared local-sdk providers. */ + providerResumeId?: string; /** Session description — used for persona/system prompt injection. */ description?: string; /** CC env preset name — persisted so respawn can re-inject the same env vars. */ diff --git a/test/agent/provider-registry.test.ts b/test/agent/provider-registry.test.ts index 6a9a21dca..ba72f9396 100644 --- a/test/agent/provider-registry.test.ts +++ b/test/agent/provider-registry.test.ts @@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; // ── Hoisted mocks ───────────────────────────────────────────────────────────── -const { mockConnect, mockDisconnect, MockOpenClawProvider, MockQwenProvider, MockClaudeCodeSdkProvider, MockCodexSdkProvider } = vi.hoisted(() => { +const { mockConnect, mockDisconnect, MockOpenClawProvider, MockQwenProvider, MockClaudeCodeSdkProvider, MockCodexSdkProvider, MockCursorHeadlessProvider, MockCopilotSdkProvider } = vi.hoisted(() => { const mockConnect = vi.fn().mockResolvedValue(undefined); const mockDisconnect = vi.fn().mockResolvedValue(undefined); const MockOpenClawProvider = vi.fn().mockImplementation(() => ({ @@ -89,7 +89,50 @@ const { mockConnect, mockDisconnect, MockOpenClawProvider, MockQwenProvider, Moc createSession: vi.fn().mockResolvedValue('session-1'), endSession: vi.fn().mockResolvedValue(undefined), })); - return { mockConnect, mockDisconnect, MockOpenClawProvider, MockQwenProvider, MockClaudeCodeSdkProvider, MockCodexSdkProvider }; + const MockCursorHeadlessProvider = vi.fn().mockImplementation(() => ({ + id: 'cursor-headless', + connectionMode: 'local-sdk', + sessionOwnership: 'shared', + capabilities: { + streaming: true, + toolCalling: true, + approval: false, + sessionRestore: true, + multiTurn: true, + attachments: false, + }, + connect: mockConnect, + disconnect: mockDisconnect, + send: vi.fn().mockResolvedValue(undefined), + onDelta: vi.fn(), + 
onComplete: vi.fn(), + onError: vi.fn(), + createSession: vi.fn().mockResolvedValue('route-1'), + endSession: vi.fn().mockResolvedValue(undefined), + })); + const MockCopilotSdkProvider = vi.fn().mockImplementation(() => ({ + id: 'copilot-sdk', + connectionMode: 'local-sdk', + sessionOwnership: 'shared', + capabilities: { + streaming: true, + toolCalling: true, + approval: true, + sessionRestore: true, + multiTurn: true, + attachments: true, + reasoningEffort: true, + }, + connect: mockConnect, + disconnect: mockDisconnect, + send: vi.fn().mockResolvedValue(undefined), + onDelta: vi.fn(), + onComplete: vi.fn(), + onError: vi.fn(), + createSession: vi.fn().mockResolvedValue('route-2'), + endSession: vi.fn().mockResolvedValue(undefined), + })); + return { mockConnect, mockDisconnect, MockOpenClawProvider, MockQwenProvider, MockClaudeCodeSdkProvider, MockCodexSdkProvider, MockCursorHeadlessProvider, MockCopilotSdkProvider }; }); vi.mock('../../src/agent/providers/openclaw.js', () => ({ @@ -108,6 +151,14 @@ vi.mock('../../src/agent/providers/codex-sdk.js', () => ({ CodexSdkProvider: MockCodexSdkProvider, })); +vi.mock('../../src/agent/providers/cursor-headless.js', () => ({ + CursorHeadlessProvider: MockCursorHeadlessProvider, +})); + +vi.mock('../../src/agent/providers/copilot-sdk.js', () => ({ + CopilotSdkProvider: MockCopilotSdkProvider, +})); + vi.mock('../../src/util/logger.js', () => ({ default: { info: vi.fn(), @@ -172,6 +223,20 @@ describe('getProvider', () => { expect(provider!.id).toBe('codex-sdk'); }); + it('returns cursor-headless after connectProvider()', async () => { + await connectProvider('cursor-headless', CONFIG); + const provider = getProvider('cursor-headless'); + expect(provider).toBeDefined(); + expect(provider!.id).toBe('cursor-headless'); + }); + + it('returns copilot-sdk after connectProvider()', async () => { + await connectProvider('copilot-sdk', CONFIG); + const provider = getProvider('copilot-sdk'); + expect(provider).toBeDefined(); + 
expect(provider!.id).toBe('copilot-sdk'); + }); + it('returns undefined for an unknown id', () => { expect(getProvider('minimax')).toBeUndefined(); }); @@ -203,6 +268,18 @@ describe('connectProvider', () => { expect(mockConnect).toHaveBeenCalledWith(CONFIG); }); + it('instantiates CursorHeadlessProvider and calls connect()', async () => { + await connectProvider('cursor-headless', CONFIG); + expect(MockCursorHeadlessProvider).toHaveBeenCalledOnce(); + expect(mockConnect).toHaveBeenCalledWith(CONFIG); + }); + + it('instantiates CopilotSdkProvider and calls connect()', async () => { + await connectProvider('copilot-sdk', CONFIG); + expect(MockCopilotSdkProvider).toHaveBeenCalledOnce(); + expect(mockConnect).toHaveBeenCalledWith(CONFIG); + }); + it('throws for an unknown provider id', async () => { await expect(connectProvider('unknown-provider', CONFIG)).rejects.toThrow( 'Unknown provider: unknown-provider', diff --git a/test/agent/providers/copilot-sdk-harness.ts b/test/agent/providers/copilot-sdk-harness.ts new file mode 100644 index 000000000..64a069725 --- /dev/null +++ b/test/agent/providers/copilot-sdk-harness.ts @@ -0,0 +1,210 @@ +import { EventEmitter } from 'node:events'; + +type SessionConfig = Record & { + onPermissionRequest?: (request: Record, invocation: { sessionId: string }) => Promise | unknown; +}; + +export interface CopilotHarnessState { + clientCalls: { + start: number; + stop: number; + getStatus: number; + getAuthStatus: number; + listModels: number; + deleteSession: string[]; + }; + status: { + version: string; + protocolVersion: number; + }; + auth: { + isAuthenticated: boolean; + statusMessage?: string; + }; + models: Array<{ id: string; displayName?: string }>; + startError: Error | null; + statusError: Error | null; + authError: Error | null; + modelsError: Error | null; + deleteSessionError: Error | null; + keepDeletedSessions: boolean; +} + +export interface CopilotSpawnedSession { + sessionId: string; + config: SessionConfig; + 
sendCalls: Array>; + setModelCalls: Array<{ model: string; options?: Record }>; + abortCalls: number; + disconnectCalls: number; + active: boolean; + emitter: EventEmitter; + emit(event: Record): void; + requestPermission(request: Record): Promise; +} + +export function createCopilotSdkHarness() { + const state: CopilotHarnessState = { + clientCalls: { + start: 0, + stop: 0, + getStatus: 0, + getAuthStatus: 0, + listModels: 0, + deleteSession: [], + }, + status: { version: '1.0.31', protocolVersion: 3 }, + auth: { isAuthenticated: true, statusMessage: 'Logged in' }, + models: [{ id: 'gpt-5.4' }, { id: 'gpt-5.4-mini' }], + startError: null, + statusError: null, + authError: null, + modelsError: null, + deleteSessionError: null, + keepDeletedSessions: true, + }; + + const sessions: CopilotSpawnedSession[] = []; + const clients: FakeCopilotClient[] = []; + + class FakeCopilotSession { + readonly sessionId: string; + readonly config: SessionConfig; + readonly emitter = new EventEmitter(); + sendCalls: Array> = []; + setModelCalls: Array<{ model: string; options?: Record }> = []; + abortCalls = 0; + disconnectCalls = 0; + active = true; + + constructor(sessionId: string, config: SessionConfig) { + this.sessionId = sessionId; + this.config = config; + } + + async send(options: Record): Promise { + this.sendCalls.push(options); + } + + async abort(): Promise { + this.abortCalls += 1; + this.emitter.emit('aborted'); + } + + async setModel(model: string, options?: Record): Promise { + this.setModelCalls.push({ model, options }); + } + + async disconnect(): Promise { + this.disconnectCalls += 1; + this.active = false; + } + + requestPermission(request: Record): Promise { + const handler = this.config.onPermissionRequest; + if (!handler) { + return Promise.resolve({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + } + return Promise.resolve(handler(request, { sessionId: this.sessionId })); + } + + emit(event: Record): void { + this.emitter.emit('event', 
event); + } + + on(handler: (event: Record) => void): () => void { + const wrapped = (event: Record) => handler(event); + this.emitter.addListener('event', wrapped); + return () => { + this.emitter.removeListener('event', wrapped); + }; + } + } + + class FakeCopilotClient { + private sessionCounter = 0; + readonly createdSessions: CopilotSpawnedSession[] = sessions; + + async start(): Promise { + state.clientCalls.start += 1; + if (state.startError) throw state.startError; + } + + async stop(): Promise { + state.clientCalls.stop += 1; + } + + async getStatus(): Promise<{ version: string; protocolVersion: number }> { + state.clientCalls.getStatus += 1; + if (state.statusError) throw state.statusError; + return { ...state.status }; + } + + async getAuthStatus(): Promise<{ isAuthenticated: boolean; statusMessage?: string }> { + state.clientCalls.getAuthStatus += 1; + if (state.authError) throw state.authError; + return { ...state.auth }; + } + + async listModels(): Promise> { + state.clientCalls.listModels += 1; + if (state.modelsError) throw state.modelsError; + return state.models.map((model) => ({ ...model })); + } + + async createSession(config: SessionConfig): Promise { + const sessionId = `copilot-session-${++this.sessionCounter}`; + const session = new FakeCopilotSession(sessionId, config); + sessions.push(session); + clients.push(this); + return session; + } + + async resumeSession(sessionId: string, config: SessionConfig): Promise { + const existing = sessions.find((session) => session.sessionId === sessionId); + if (existing) { + existing.config.onPermissionRequest = config.onPermissionRequest ?? 
existing.config.onPermissionRequest; + return existing as unknown as FakeCopilotSession; + } + const session = new FakeCopilotSession(sessionId, config); + sessions.push(session); + clients.push(this); + return session; + } + + async listSessions(): Promise> { + return sessions.map((session) => ({ + sessionId: session.sessionId, + summary: session.sessionId, + modifiedTime: new Date(1_700_000_000_000 + sessions.indexOf(session)), + })); + } + + async deleteSession(sessionId: string): Promise { + state.clientCalls.deleteSession.push(sessionId); + if (state.deleteSessionError) throw state.deleteSessionError; + if (!state.keepDeletedSessions) { + const idx = sessions.findIndex((session) => session.sessionId === sessionId); + if (idx >= 0) sessions.splice(idx, 1); + } + } + } + + const sdkModule = { CopilotClient: FakeCopilotClient }; + + return { + state, + sessions, + clients, + sdkModule, + lastSession(): CopilotSpawnedSession { + const session = sessions.at(-1); + if (!session) throw new Error('No Copilot session recorded'); + return session; + }, + reset(): void { + sessions.length = 0; + clients.length = 0; + }, + }; +} diff --git a/test/agent/providers/copilot-sdk.test.ts b/test/agent/providers/copilot-sdk.test.ts new file mode 100644 index 000000000..3fc4adf27 --- /dev/null +++ b/test/agent/providers/copilot-sdk.test.ts @@ -0,0 +1,384 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { + CopilotSdkProvider, + copilotSdkRuntimeHooks, +} from '../../../src/agent/providers/copilot-sdk.js'; +import type { TransportAttachment } from '../../../shared/transport-attachments.js'; + +vi.mock('../../../src/util/logger.js', () => ({ + default: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})); + +type FakeSessionConfig = Record & { + onPermissionRequest?: (request: Record) => Promise> | Record; +}; + +function createCopilotHarness(options?: { + version?: string; + protocolVersion?: number; + 
authenticated?: boolean; +}) { + const sessions = new Map(); + const createdConfigs: FakeSessionConfig[] = []; + const resumedConfigs: Array<{ sessionId: string; config: FakeSessionConfig }> = []; + const deletedSessions: string[] = []; + let nextSessionId = 1; + + class FakeSession { + readonly handlers = new Set<(event: Record) => void>(); + readonly send = vi.fn(async () => {}); + readonly abort = vi.fn(async () => {}); + readonly setModel = vi.fn(async () => {}); + readonly disconnect = vi.fn(async () => {}); + constructor(readonly sessionId: string) {} + on(handler: (event: Record) => void): () => void { + this.handlers.add(handler); + return () => this.handlers.delete(handler); + } + emit(event: Record): void { + for (const handler of this.handlers) handler(event); + } + } + + class FakeClient { + start = vi.fn(async () => {}); + stop = vi.fn(async () => {}); + getStatus = vi.fn(async () => ({ + version: options?.version ?? '1.0.31', + protocolVersion: options?.protocolVersion ?? 3, + })); + getAuthStatus = vi.fn(async () => ({ + isAuthenticated: options?.authenticated ?? true, + statusMessage: options?.authenticated === false ? 'sign in required' : 'authenticated', + })); + listModels = vi.fn(async () => [{ id: 'gpt-5.4' }]); + createSession = vi.fn(async (config: FakeSessionConfig) => { + createdConfigs.push(config); + const session = new FakeSession(`session-${nextSessionId++}`); + sessions.set(session.sessionId, session); + return session; + }); + resumeSession = vi.fn(async (sessionId: string, config: FakeSessionConfig) => { + resumedConfigs.push({ sessionId, config }); + const session = sessions.get(sessionId) ?? 
new FakeSession(sessionId); + sessions.set(session.sessionId, session); + return session; + }); + listSessions = vi.fn(async () => [...sessions.values()].map((session) => ({ + sessionId: session.sessionId, + summary: `summary:${session.sessionId}`, + modifiedTime: new Date('2026-01-01T00:00:00Z'), + }))); + deleteSession = vi.fn(async (sessionId: string) => { + deletedSessions.push(sessionId); + sessions.delete(sessionId); + }); + } + + return { + FakeClient, + sessions, + createdConfigs, + resumedConfigs, + deletedSessions, + }; +} + +describe('CopilotSdkProvider', () => { + const originalLoadSdk = copilotSdkRuntimeHooks.loadSdk; + + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(async () => { + copilotSdkRuntimeHooks.loadSdk = originalLoadSdk; + vi.useRealTimers(); + }); + + it('bridges SDK permission requests into approval callbacks and resolves responses', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + const approvals: Array> = []; + provider.onApprovalRequest((_, req) => approvals.push(req as Record)); + await provider.connect({ binaryPath: 'copilot', approvalTimeoutMs: 250 }); + const routeId = await provider.createSession({ sessionKey: 'route-1', cwd: '/tmp/project', agentId: 'gpt-5.4' }); + + const client = (provider as unknown as { client?: InstanceType }).client; + const permissionHandler = harness.createdConfigs[0]?.onPermissionRequest as ((request: Record) => Promise>) | undefined; + expect(permissionHandler).toBeTypeOf('function'); + + const pending = permissionHandler?.({ kind: 'shell', fullCommandText: 'printf hello' }); + await vi.advanceTimersByTimeAsync(0); + expect(approvals).toEqual([ + expect.objectContaining({ + description: 'Allow shell command: printf hello', + tool: 'shell', + }), + ]); + + const approvalRequestId = String(approvals[0]?.id 
?? ''); + await provider.respondApproval(routeId, approvalRequestId, true); + await expect(pending).resolves.toEqual({ kind: 'approved' }); + expect(client?.getStatus).toHaveBeenCalled(); + }); + + it('denies permission requests immediately when no approval callbacks are registered', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + await provider.connect({ binaryPath: 'copilot' }); + await provider.createSession({ sessionKey: 'route-2', cwd: '/tmp/project' }); + + const denied = await (provider as unknown as { + handlePermissionRequest(routeId: string, request: Record): Promise>; + }).handlePermissionRequest('route-2', { kind: 'shell', command: 'rm -rf /' }); + + expect(denied).toEqual({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + }); + + it('fails safe when approval callbacks never answer by timing out and denying the request', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + provider.onApprovalRequest(() => {}); + await provider.connect({ binaryPath: 'copilot', approvalTimeoutMs: 50 }); + await provider.createSession({ sessionKey: 'route-3', cwd: '/tmp/project' }); + + const pending = (provider as unknown as { + handlePermissionRequest(routeId: string, request: Record): Promise>; + }).handlePermissionRequest('route-3', { kind: 'shell', command: 'sleep 1' }); + await vi.advanceTimersByTimeAsync(49); + await Promise.resolve(); + await vi.advanceTimersByTimeAsync(1); + + await expect(pending).resolves.toEqual({ kind: 'denied-no-approval-rule-and-could-not-request-from-user' }); + }); + + it('rotates poisoned sessions after background-tainted abort and suppresses stale 
callbacks', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + await provider.connect({ binaryPath: 'copilot' }); + const routeId = await provider.createSession({ sessionKey: 'route-4', cwd: '/tmp/project', agentId: 'gpt-5.4' }); + + const completeEvents: Array> = []; + const sessionInfos: Array> = []; + provider.onComplete((_, message) => completeEvents.push(message as Record)); + provider.onSessionInfo((_, info) => sessionInfos.push(info as Record)); + + const session = harness.sessions.get('session-1'); + expect(session).toBeTruthy(); + session?.emit({ type: 'session.background_tasks_changed', data: { backgroundTasks: [{ state: 'running' }] } }); + + await provider.cancel(routeId); + await vi.runAllTimersAsync(); + + expect(harness.deletedSessions).toContain('session-1'); + expect(harness.createdConfigs).toHaveLength(2); + expect(sessionInfos.some((info) => info.resumeId === 'session-2')).toBe(true); + + session?.emit({ + type: 'assistant.message', + data: { messageId: 'old-msg', content: 'stale content' }, + }); + expect(completeEvents).toHaveLength(0); + + await expect(provider.restoreSession('session-1')).resolves.toBe(false); + await expect(provider.restoreSession('session-2')).resolves.toBe(true); + const sessions = await provider.listSessions(); + expect(sessions.some((item) => item.key === 'session-1')).toBe(false); + expect(sessions.some((item) => item.key === 'session-2')).toBe(true); + }); + + it('waits for idle before completing a tool-driven turn with an initially empty assistant message', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + await provider.connect({ binaryPath: 
'copilot' }); + const routeId = await provider.createSession({ sessionKey: 'route-5', cwd: '/tmp/project', agentId: 'gpt-5.4' }); + + const completions: string[] = []; + provider.onComplete((sid, message) => { + if (sid === routeId) completions.push(String(message.content ?? '')); + }); + + await provider.send(routeId, 'Read the attachment and answer'); + + const session = Array.from(harness.sessions.values())[0]; + expect(session).toBeTruthy(); + session.emit({ + type: 'assistant.message', + data: { + messageId: 'msg-1', + content: '', + toolRequests: [{ toolCallId: 'tool-1', name: 'view' }], + }, + }); + expect(completions).toEqual([]); + + session.emit({ + type: 'assistant.message', + data: { + messageId: 'msg-2', + content: 'COPILOT_ATTACHMENT_OK', + toolRequests: [], + }, + }); + expect(completions).toEqual([]); + + session.emit({ type: 'session.idle', data: {} }); + expect(completions).toEqual(['COPILOT_ATTACHMENT_OK']); + }); + + it('uses normalized payload attachments instead of the raw legacy attachments argument', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + await provider.connect({ binaryPath: 'copilot' }); + const routeId = await provider.createSession({ sessionKey: 'route-attachments', cwd: '/tmp/project' }); + const normalizedAttachment: TransportAttachment = { + daemonPath: '/tmp/project/attached.txt', + originalName: 'attached.txt', + }; + const rawAttachment: TransportAttachment = { + daemonPath: '/tmp/project/legacy.txt', + originalName: 'legacy.txt', + }; + + await provider.send(routeId, { + userMessage: 'Read the attachment', + assembledMessage: 'Read the attachment', + systemText: undefined, + messagePreamble: undefined, + attachments: [normalizedAttachment], + context: { + systemText: undefined, + messagePreamble: undefined, + requiredAuthoredContext: 
[], + advisoryAuthoredContext: [], + appliedDocumentVersionIds: [], + diagnostics: [], + }, + authority: { + namespace: undefined, + authoritySource: 'none', + freshness: 'missing', + fallbackAllowed: true, + retryScheduled: false, + diagnostics: [], + }, + supportClass: 'degraded-message-side-context-mapping', + diagnostics: [], + }, [rawAttachment]); + + const sendPayload = harness.sessions.get('session-1')?.send.mock.calls[0]?.[0] as Record; + expect(sendPayload.attachments).toEqual([ + { type: 'file', path: '/tmp/project/attached.txt', displayName: 'attached.txt' }, + ]); + }); + + it('rotates even when background taint arrives after cancel', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + await provider.connect({ binaryPath: 'copilot' }); + const routeId = await provider.createSession({ sessionKey: 'route-late-taint', cwd: '/tmp/project', agentId: 'gpt-5.4' }); + + const infos: Array> = []; + provider.onSessionInfo((_, info) => infos.push(info as Record)); + + const session = harness.sessions.get('session-1'); + expect(session).toBeTruthy(); + session!.abort.mockImplementation(async () => { + queueMicrotask(() => { + session!.emit({ type: 'session.background_tasks_changed', data: { backgroundTasks: [{ state: 'running' }] } }); + }); + }); + + await provider.cancel(routeId); + await vi.runAllTimersAsync(); + + expect(harness.deletedSessions).toContain('session-1'); + expect(infos.some((info) => info.resumeId === 'session-2')).toBe(true); + }); + + it('retains output token and interaction metadata when completing on idle', async () => { + const harness = createCopilotHarness(); + const provider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: harness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + + await 
provider.connect({ binaryPath: 'copilot' }); + const routeId = await provider.createSession({ sessionKey: 'route-metadata', cwd: '/tmp/project', agentId: 'gpt-5.4' }); + + const completions: Array> = []; + provider.onComplete((sid, message) => { + if (sid === routeId) completions.push(message as Record); + }); + + await provider.send(routeId, 'reply'); + const session = harness.sessions.get('session-1')!; + session.emit({ type: 'assistant.message_delta', data: { messageId: 'msg-meta', deltaContent: 'Hello there' } }); + session.emit({ type: 'assistant.message', data: { messageId: 'msg-meta', content: 'Hi', interactionId: 'ix-1' } }); + session.emit({ type: 'assistant.usage', data: { outputTokens: 42, interactionId: 'ix-1' } }); + session.emit({ type: 'session.idle', data: {} }); + + expect(completions).toHaveLength(1); + expect(completions[0].content).toBe('Hello there'); + expect(completions[0].metadata).toMatchObject({ + interactionId: 'ix-1', + usage: { output_tokens: 42 }, + resumeId: 'session-1', + model: 'gpt-5.4', + }); + }); + + it('rejects incompatible versions and unauthenticated clients at connect time', async () => { + const incompatibleHarness = createCopilotHarness({ version: '0.9.0' }); + const incompatibleProvider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: incompatibleHarness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + await expect(incompatibleProvider.connect({ binaryPath: 'copilot' })).rejects.toMatchObject({ + code: 'CONFIG_ERROR', + }); + + const authHarness = createCopilotHarness({ authenticated: false }); + const authProvider = new CopilotSdkProvider(); + copilotSdkRuntimeHooks.loadSdk = async () => ({ + CopilotClient: authHarness.FakeClient, + }) as typeof import('@github/copilot-sdk'); + await expect(authProvider.connect({ binaryPath: 'copilot' })).rejects.toMatchObject({ + code: 'AUTH_FAILED', + }); + }); +}); diff --git a/test/agent/providers/cursor-headless-stream.test.ts 
b/test/agent/providers/cursor-headless-stream.test.ts new file mode 100644 index 000000000..a53a9b2e6 --- /dev/null +++ b/test/agent/providers/cursor-headless-stream.test.ts @@ -0,0 +1,135 @@ +import { describe, expect, it } from 'vitest'; +import { parseCursorStreamLine } from '../../../src/agent/providers/cursor-headless-stream.js'; + +describe('parseCursorStreamLine', () => { + it('normalizes system init, streamed deltas, tool events, and completion records', () => { + expect(parseCursorStreamLine(JSON.stringify({ + type: 'system.init', + session_id: 'cursor-chat-1', + model: 'GPT-5.2', + permissionMode: 'default', + }))).toEqual({ + kind: 'session.init', + sessionId: 'cursor-chat-1', + model: 'GPT-5.2', + permissionMode: 'default', + raw: { + type: 'system.init', + session_id: 'cursor-chat-1', + model: 'GPT-5.2', + permissionMode: 'default', + }, + }); + + expect(parseCursorStreamLine(JSON.stringify({ + type: 'stream_event', + session_id: 'cursor-chat-1', + event: { + type: 'content_block_delta', + delta: { + type: 'text_delta', + text: 'Hel', + }, + }, + }))).toEqual({ + kind: 'assistant.delta', + sessionId: 'cursor-chat-1', + text: 'Hel', + raw: { + type: 'stream_event', + session_id: 'cursor-chat-1', + event: { + type: 'content_block_delta', + delta: { + type: 'text_delta', + text: 'Hel', + }, + }, + }, + }); + + expect(parseCursorStreamLine(JSON.stringify({ + type: 'tool_call.started', + id: 'tool-1', + name: 'shell', + input: { command: 'printf hello' }, + }))).toEqual({ + kind: 'tool.started', + sessionId: undefined, + id: 'tool-1', + name: 'shell', + input: { command: 'printf hello' }, + raw: { + type: 'tool_call.started', + id: 'tool-1', + name: 'shell', + input: { command: 'printf hello' }, + }, + }); + + expect(parseCursorStreamLine(JSON.stringify({ + type: 'tool_call.completed', + id: 'tool-1', + name: 'shell', + output: 'hello', + }))).toEqual({ + kind: 'tool.completed', + sessionId: undefined, + id: 'tool-1', + name: 'shell', + output: 'hello', + 
raw: { + type: 'tool_call.completed', + id: 'tool-1', + name: 'shell', + output: 'hello', + }, + }); + + expect(parseCursorStreamLine(JSON.stringify({ + type: 'assistant', + message: { + id: 'msg-1', + content: [{ type: 'text', text: 'Hello' }], + }, + }))).toEqual({ + kind: 'assistant.final', + sessionId: undefined, + messageId: 'msg-1', + text: 'Hello', + raw: { + type: 'assistant', + message: { + id: 'msg-1', + content: [{ type: 'text', text: 'Hello' }], + }, + }, + }); + + expect(parseCursorStreamLine(JSON.stringify({ + type: 'result.success', + session_id: 'cursor-chat-1', + result: 'Hello', + usage: { input_tokens: 3, output_tokens: 2 }, + }))).toEqual({ + kind: 'result.success', + sessionId: 'cursor-chat-1', + model: undefined, + text: 'Hello', + usage: { input_tokens: 3, output_tokens: 2 }, + raw: { + type: 'result.success', + session_id: 'cursor-chat-1', + result: 'Hello', + usage: { input_tokens: 3, output_tokens: 2 }, + }, + }); + }); + + it('ignores invalid or irrelevant records', () => { + expect(parseCursorStreamLine('')).toBeNull(); + expect(parseCursorStreamLine('not-json')).toBeNull(); + expect(parseCursorStreamLine(JSON.stringify({ type: 'user', message: { content: [] } }))).toBeNull(); + }); +}); + diff --git a/test/agent/providers/cursor-headless.test.ts b/test/agent/providers/cursor-headless.test.ts new file mode 100644 index 000000000..bd1b9e3dc --- /dev/null +++ b/test/agent/providers/cursor-headless.test.ts @@ -0,0 +1,207 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { + CursorHeadlessProvider, + cursorHeadlessRuntimeHooks, +} from '../../../src/agent/providers/cursor-headless.js'; +import { createCursorHeadlessHarness } from '../../cursor-headless-fixture.js'; +import type { ProviderContextPayload } from '../../../shared/context-types.js'; + +vi.mock('../../../src/util/logger.js', () => ({ + default: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})); + 
+describe('CursorHeadlessProvider', () => { + const originalLoadChildProcess = cursorHeadlessRuntimeHooks.loadChildProcess; + let harness = createCursorHeadlessHarness(); + + beforeEach(() => { + harness = createCursorHeadlessHarness(); + cursorHeadlessRuntimeHooks.loadChildProcess = async () => ({ + execFile: harness.execFile, + spawn: harness.spawn, + } as typeof import('node:child_process')); + }); + + afterEach(() => { + cursorHeadlessRuntimeHooks.loadChildProcess = originalLoadChildProcess; + }); + + it('connects by probing version and authentication status', async () => { + const provider = new CursorHeadlessProvider(); + await provider.connect({ binaryPath: 'cursor-agent' }); + + expect(harness.execFile.mock.calls.some((call) => Array.isArray(call[1]) && (call[1] as string[]).includes('--version'))).toBe(true); + expect(harness.execFile.mock.calls.some((call) => Array.isArray(call[1]) && (call[1] as string[]).includes('status'))).toBe(true); + }); + + it('rejects when the status probe reports a logged-out account', async () => { + harness.state.statusOutput = 'Not logged in\n'; + const provider = new CursorHeadlessProvider(); + await expect(provider.connect({ binaryPath: 'cursor-agent' })).rejects.toMatchObject({ code: 'AUTH_FAILED' }); + }); + + it('rejects unsupported versions and ambiguous auth probe output', async () => { + harness.state.versionOutput = 'Cursor Agent 0.9.9\n'; + const oldVersionProvider = new CursorHeadlessProvider(); + await expect(oldVersionProvider.connect({ binaryPath: 'cursor-agent' })).rejects.toMatchObject({ code: 'CONFIG_ERROR' }); + + harness = createCursorHeadlessHarness({ + versionOutput: 'Cursor Agent 1.0.0\n', + statusOutput: 'status probe returned something unexpected\n', + }); + cursorHeadlessRuntimeHooks.loadChildProcess = async () => ({ + execFile: harness.execFile, + spawn: harness.spawn, + } as typeof import('node:child_process')); + + const ambiguousAuthProvider = new CursorHeadlessProvider(); + await 
expect(ambiguousAuthProvider.connect({ binaryPath: 'cursor-agent' })).rejects.toMatchObject({ code: 'CONFIG_ERROR' }); + }); + + it('maps version probe failures to provider-not-found and status failures to config errors', async () => { + harness.state.versionError = new Error('cursor-agent not found'); + const missingBinaryProvider = new CursorHeadlessProvider(); + await expect(missingBinaryProvider.connect({ binaryPath: 'cursor-agent' })).rejects.toMatchObject({ + code: 'PROVIDER_NOT_FOUND', + }); + + harness.state.versionError = null; + harness.state.statusError = new Error('status probe failed unexpectedly'); + const statusFailureProvider = new CursorHeadlessProvider(); + await expect(statusFailureProvider.connect({ binaryPath: 'cursor-agent' })).rejects.toMatchObject({ + code: 'CONFIG_ERROR', + }); + }); + + it('creates a route id, emits durable session info, and restores by either route or resume id', async () => { + harness.state.createChatOutput = 'cursor-chat-9\n'; + const provider = new CursorHeadlessProvider(); + await provider.connect({ binaryPath: 'cursor-agent' }); + + const sessionInfo: Array> = []; + provider.onSessionInfo((_, info) => sessionInfo.push(info as Record)); + + const routeId = await provider.createSession({ + sessionKey: 'route-1', + cwd: '/tmp/project', + agentId: 'gpt-5.2', + }); + + expect(routeId).toBe('route-1'); + expect(sessionInfo).toContainEqual({ resumeId: 'cursor-chat-9', model: 'gpt-5.2' }); + expect(provider.capabilities).toMatchObject({ + streaming: true, + toolCalling: true, + approval: false, + sessionRestore: true, + multiTurn: true, + attachments: false, + }); + expect(provider.connectionMode).toBe('local-sdk'); + expect((provider as { listSessions?: unknown }).listSessions).toBeUndefined(); + await expect(provider.restoreSession(routeId)).resolves.toBe(true); + await expect(provider.restoreSession('cursor-chat-9')).resolves.toBe(true); + await expect(provider.restoreSession('missing-session')).resolves.toBe(false); + 
}); + + it('streams cumulative deltas, tool events, and completion from stream-json output', async () => { + harness.state.createChatOutput = 'cursor-chat-2\n'; + const provider = new CursorHeadlessProvider(); + await provider.connect({ binaryPath: 'cursor-agent' }); + const sessionId = await provider.createSession({ + sessionKey: 'route-2', + cwd: '/tmp/project', + agentId: 'gpt-5.2', + }); + + const deltas: string[] = []; + const completed: string[] = []; + const tools: Array<{ name: string; status: string; output?: string }> = []; + const infos: Array> = []; + provider.onDelta((_sid, delta) => deltas.push(delta.delta)); + provider.onComplete((_sid, msg) => completed.push(String(msg.content))); + provider.onToolCall((_sid, tool) => tools.push({ name: tool.name, status: tool.status, output: tool.output })); + provider.onSessionInfo((_, info) => infos.push(info as Record)); + + await provider.send(sessionId, { + userMessage: 'ship it', + assembledMessage: 'Relevant context\n\nship it', + systemText: 'Normalized system text', + messagePreamble: 'Relevant context', + attachments: [], + context: { + systemText: 'Normalized system text', + messagePreamble: 'Relevant context', + requiredAuthoredContext: [], + advisoryAuthoredContext: [], + appliedDocumentVersionIds: [], + diagnostics: [], + }, + authority: { + namespace: { scope: 'personal', projectId: 'route-2' }, + authoritySource: 'none', + freshness: 'missing', + fallbackAllowed: true, + retryScheduled: false, + diagnostics: [], + }, + supportClass: 'degraded-message-side-context-mapping', + diagnostics: [], + } satisfies ProviderContextPayload); + + const spawned = harness.lastSpawn(); + expect(spawned.file).toBe('cursor-agent'); + expect(spawned.args).toContain('-p'); + expect(spawned.args).toContain('--trust'); + expect(spawned.args).toContain('--force'); + expect(spawned.args).toContain('--output-format'); + expect(spawned.args).toContain('stream-json'); + 
expect(spawned.args).toContain('--stream-partial-output'); + expect(spawned.args).toContain('--resume'); + expect(spawned.args).toContain('cursor-chat-2'); + expect(spawned.args).toContain('--model'); + expect(spawned.args).toContain('gpt-5.2'); + expect(spawned.args.at(-1)).toBe('Normalized system text\n\nRelevant context\n\nship it'); + + spawned.child.stdout.write(`${JSON.stringify({ type: 'system.init', session_id: 'cursor-chat-2', model: 'gpt-5.2', permissionMode: 'default' })}\n`); + spawned.child.stdout.write(`${JSON.stringify({ type: 'stream_event', session_id: 'cursor-chat-2', event: { type: 'content_block_delta', delta: { type: 'text_delta', text: 'Hel' } } })}\n`); + spawned.child.stdout.write(`${JSON.stringify({ type: 'stream_event', session_id: 'cursor-chat-2', event: { type: 'content_block_delta', delta: { type: 'text_delta', text: 'lo' } } })}\n`); + spawned.child.stdout.write(`${JSON.stringify({ type: 'tool_call.started', session_id: 'cursor-chat-2', id: 'tool-1', name: 'shell', input: { command: 'printf hello' } })}\n`); + spawned.child.stdout.write(`${JSON.stringify({ type: 'tool_call.completed', session_id: 'cursor-chat-2', id: 'tool-1', name: 'shell', output: 'hello' })}\n`); + spawned.child.stdout.write(`${JSON.stringify({ type: 'assistant', session_id: 'cursor-chat-2', message: { id: 'msg-1', content: [{ type: 'text', text: 'Hello' }] } })}\n`); + spawned.child.stdout.write(`${JSON.stringify({ type: 'result.success', session_id: 'cursor-chat-2', result: 'Hello', usage: { input_tokens: 3, output_tokens: 2 } })}\n`); + spawned.child.emit('close', 0, null); + await harness.flush(); + + expect(deltas).toEqual(['Hel', 'Hello']); + expect(completed).toEqual(['Hello']); + expect(tools).toEqual([ + { name: 'shell', status: 'running', output: undefined }, + { name: 'shell', status: 'complete', output: 'hello' }, + ]); + expect(infos).toContainEqual({ resumeId: 'cursor-chat-2', model: 'gpt-5.2' }); + }); + + it('cancels the active child process and 
emits a recoverable cancelled error', async () => { + const provider = new CursorHeadlessProvider(); + await provider.connect({ binaryPath: 'cursor-agent' }); + const sessionId = await provider.createSession({ sessionKey: 'route-cancel', cwd: '/tmp/project' }); + + const errors: Array> = []; + provider.onError((_sid, error) => errors.push(error as Record)); + + const sendPromise = provider.send(sessionId, 'reply with nothing'); + await harness.flush(); + await provider.cancel(sessionId); + await sendPromise; + await harness.flush(); + + expect(harness.lastSpawn().child.killed).toBe(true); + expect(errors.some((error) => error.code === 'CANCELLED')).toBe(true); + }); +}); diff --git a/test/cursor-headless-fixture.ts b/test/cursor-headless-fixture.ts new file mode 100644 index 000000000..fab4f3181 --- /dev/null +++ b/test/cursor-headless-fixture.ts @@ -0,0 +1,102 @@ +import { EventEmitter } from 'node:events'; +import { PassThrough } from 'node:stream'; +import { vi } from 'vitest'; + +export interface CursorHarnessOptions { + versionOutput?: string; + statusOutput?: string; + createChatOutput?: string; + versionError?: Error | null; + statusError?: Error | null; + createChatError?: Error | null; +} + +export interface CursorSpawnRecord { + file: string; + args: string[]; + cwd?: string; + env?: NodeJS.ProcessEnv; + child: EventEmitter & { + stdout: PassThrough; + stderr: PassThrough; + stdin: PassThrough; + kill: ReturnType; + killed: boolean; + }; +} + +export function createCursorHeadlessHarness(options: CursorHarnessOptions = {}) { + const state = { + versionOutput: options.versionOutput ?? 'Cursor Agent 1.0.0\n', + statusOutput: options.statusOutput ?? 'Logged in\n', + createChatOutput: options.createChatOutput ?? 'cursor-chat-1\n', + versionError: options.versionError ?? null, + statusError: options.statusError ?? null, + createChatError: options.createChatError ?? 
null, + }; + + const spawned: CursorSpawnRecord[] = []; + + const execFile = vi.fn((file: string, args: string[], optsOrCb?: unknown, maybeCb?: unknown) => { + const cb = typeof optsOrCb === 'function' + ? optsOrCb as (err: Error | null, stdout: string, stderr: string) => void + : maybeCb as ((err: Error | null, stdout: string, stderr: string) => void) | undefined; + if (args.includes('--version')) { + if (state.versionError) cb?.(state.versionError, '', ''); + else cb?.(null, state.versionOutput, ''); + return {} as never; + } + if (args[0] === 'status') { + if (state.statusError) { + cb?.(state.statusError, '', ''); + } else { + cb?.(null, state.statusOutput, ''); + } + return {} as never; + } + if (args[0] === 'create-chat') { + if (state.createChatError) { + cb?.(state.createChatError, '', ''); + } else { + cb?.(null, state.createChatOutput, ''); + } + return {} as never; + } + cb?.(null, '', ''); + return {} as never; + }); + + const spawn = vi.fn((file: string, args: string[], opts: { cwd?: string; env?: NodeJS.ProcessEnv }) => { + const stdout = new PassThrough(); + const stderr = new PassThrough(); + const stdin = new PassThrough(); + const child = new EventEmitter() as CursorSpawnRecord['child']; + child.stdout = stdout; + child.stderr = stderr; + child.stdin = stdin; + child.killed = false; + child.kill = vi.fn((signal?: string) => { + child.killed = true; + queueMicrotask(() => child.emit('close', 0, signal ?? 
'SIGTERM')); + return true; + }); + spawned.push({ file, args, cwd: opts.cwd, env: opts.env, child }); + queueMicrotask(() => child.emit('spawn')); + return child as never; + }); + + return { + state, + spawned, + execFile, + spawn, + lastSpawn(): CursorSpawnRecord { + const entry = spawned.at(-1); + if (!entry) throw new Error('No Cursor spawn recorded'); + return entry; + }, + async flush(): Promise { + await new Promise((resolve) => setTimeout(resolve, 0)); + }, + }; +} diff --git a/test/daemon/command-handler-transport-queue.test.ts b/test/daemon/command-handler-transport-queue.test.ts index b437a6b55..5779d966f 100644 --- a/test/daemon/command-handler-transport-queue.test.ts +++ b/test/daemon/command-handler-transport-queue.test.ts @@ -1,5 +1,6 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import { DAEMON_COMMAND_TYPES } from '../../shared/daemon-command-types.js'; +import { TRANSPORT_MSG } from '../../shared/transport-events.js'; const { getSessionMock, @@ -888,4 +889,36 @@ describe('handleWebCommand transport queue behavior', () => { expect(resizeSessionMock).not.toHaveBeenCalled(); }); + + it('forwards transport approval responses to the live runtime and rebroadcasts them', async () => { + const respondApproval = vi.fn().mockResolvedValue(undefined); + getSessionMock.mockReturnValue({ + name: 'deck_transport_brain', + projectName: 'transport', + role: 'brain', + agentType: 'copilot-sdk', + runtimeType: 'transport', + state: 'running', + }); + getTransportRuntimeMock.mockReturnValue({ + providerSessionId: 'provider-route-1', + respondApproval, + }); + + await handleWebCommand({ + type: TRANSPORT_MSG.APPROVAL_RESPONSE, + sessionId: 'deck_transport_brain', + requestId: 'approval-1', + approved: true, + }, serverLink as any); + await flushAsync(); + + expect(respondApproval).toHaveBeenCalledWith('approval-1', true); + expect(serverLink.send).toHaveBeenCalledWith(expect.objectContaining({ + type: TRANSPORT_MSG.APPROVAL_RESPONSE, + sessionId: 
'deck_transport_brain', + requestId: 'approval-1', + approved: true, + })); + }); }); diff --git a/test/daemon/copilot-sdk-runtime.test.ts b/test/daemon/copilot-sdk-runtime.test.ts new file mode 100644 index 000000000..0d959512b --- /dev/null +++ b/test/daemon/copilot-sdk-runtime.test.ts @@ -0,0 +1,83 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { + CopilotSdkProvider, + copilotSdkRuntimeHooks, +} from "../../src/agent/providers/copilot-sdk.js"; +import { TransportSessionRuntime } from "../../src/agent/transport-session-runtime.js"; +import { createCopilotSdkHarness } from "../agent/providers/copilot-sdk-harness.js"; + +vi.mock("../../src/util/logger.js", () => ({ + default: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() }, +})); + +describe("CopilotSdkProvider + TransportSessionRuntime", () => { + const originalLoadSdk = copilotSdkRuntimeHooks.loadSdk; + let harness = createCopilotSdkHarness(); + + beforeEach(() => { + harness = createCopilotSdkHarness(); + copilotSdkRuntimeHooks.loadSdk = async () => harness.sdkModule as never; + }); + + afterEach(() => { + copilotSdkRuntimeHooks.loadSdk = originalLoadSdk; + }); + + it("does not let stale poisoned-session callbacks resolve a later runtime turn", async () => { + const provider = new CopilotSdkProvider(); + await provider.connect({ binaryPath: "copilot" }); + + const runtime = new TransportSessionRuntime( + provider, + "deck_copilot_runtime_brain", + ); + const statuses: string[] = []; + runtime.onStatusChange = (status) => { + statuses.push(status); + }; + await runtime.initialize({ + sessionKey: "deck_copilot_runtime_brain", + cwd: "/tmp/project", + }); + + runtime.send("first turn"); + const oldSession = harness.lastSession(); + oldSession.emit({ + type: "tool.execution_start", + data: { + toolCallId: "tool-1", + toolName: "shell", + arguments: { mode: "async", command: "sleep 30" }, + }, + }); + + await runtime.cancel(); + const rotatedSession = 
harness.lastSession(); + expect(rotatedSession.sessionId).toBe("copilot-session-2"); + + runtime.send("second turn"); + oldSession.emit({ + type: "assistant.message_delta", + data: { messageId: "stale-msg", deltaContent: "STALE" }, + }); + oldSession.emit({ + type: "assistant.message", + data: { messageId: "stale-msg", content: "STALE" }, + }); + rotatedSession.emit({ + type: "assistant.message", + data: { messageId: "fresh-msg", content: "FRESH" }, + }); + rotatedSession.emit({ type: "session.idle", data: {} }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + const history = runtime.getHistory(); + expect(history.at(-1)?.content).toBe("FRESH"); + expect(history.some((entry) => String(entry.content) === "STALE")).toBe( + false, + ); + expect(runtime.getStatus()).toBe("idle"); + expect(statuses.includes("error")).toBe(false); + }); +}); diff --git a/test/daemon/cursor-copilot-transport-restore.test.ts b/test/daemon/cursor-copilot-transport-restore.test.ts new file mode 100644 index 000000000..77019cc55 --- /dev/null +++ b/test/daemon/cursor-copilot-transport-restore.test.ts @@ -0,0 +1,379 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { EventEmitter } from "node:events"; +import { PassThrough, Writable } from "node:stream"; + +const mocks = vi.hoisted(() => { + const store = new Map>(); + const cursorSpawns: Array<{ + file: string; + args: string[]; + child: EventEmitter & { + stdout: PassThrough; + stderr: PassThrough; + stdin: Writable; + killed: boolean; + kill: ReturnType; + }; + }> = []; + const copilotRuns: Array<{ + sessionId: string; + prompt: string; + attachments?: Array>; + }> = []; + return { store, cursorSpawns, copilotRuns }; +}); + +vi.mock("node:child_process", async (importOriginal) => { + const actual = await importOriginal(); + const execFile = vi.fn( + (file: string, args: string[], optsOrCb?: unknown, maybeCb?: unknown) => { + const cb = (typeof optsOrCb === "function" ? 
optsOrCb : maybeCb) as + | ((err: Error | null, stdout: string, stderr: string) => void) + | undefined; + if (args.includes("--version")) { + cb?.(null, "Cursor Agent 1.0.0\n", ""); + return {} as never; + } + if (args[0] === "status") { + cb?.(null, "Logged in\n", ""); + return {} as never; + } + if (args[0] === "create-chat") { + cb?.(null, "cursor-chat-restored\n", ""); + return {} as never; + } + cb?.(null, "ok\n", ""); + return {} as never; + }, + ); + const spawn = vi.fn((file: string, args: string[]) => { + const stdout = new PassThrough(); + const stderr = new PassThrough(); + const stdin = new Writable({ + write(_chunk, _enc, cb) { + cb(); + }, + }); + const child = new EventEmitter() as EventEmitter & { + stdout: PassThrough; + stderr: PassThrough; + stdin: Writable; + killed: boolean; + kill: ReturnType; + }; + child.stdout = stdout; + child.stderr = stderr; + child.stdin = stdin; + child.killed = false; + child.kill = vi.fn((signal?: string) => { + child.killed = true; + queueMicrotask(() => child.emit("close", 0, signal ?? "SIGTERM")); + return true; + }); + mocks.cursorSpawns.push({ file, args, child }); + queueMicrotask(() => child.emit("spawn")); + return child as never; + }); + return { ...actual, execFile, spawn }; +}); + +vi.mock("@github/copilot-sdk", () => { + class FakeSession { + sessionId: string; + handlers = new Set<(event: Record) => void>(); + constructor(sessionId: string) { + this.sessionId = sessionId; + } + async send(options: Record): Promise { + mocks.copilotRuns.push({ + sessionId: this.sessionId, + prompt: String(options.prompt ?? 
""), + attachments: options.attachments as + | Array> + | undefined, + }); + for (const handler of this.handlers) { + handler({ + type: "assistant.message", + data: { messageId: "msg-1", content: "ACK" }, + }); + handler({ type: "session.idle", data: {} }); + } + } + async abort(): Promise {} + async setModel( + _model: string, + _options?: Record, + ): Promise {} + on(handler: (event: Record) => void): () => void { + this.handlers.add(handler); + return () => { + this.handlers.delete(handler); + }; + } + async disconnect(): Promise {} + } + class CopilotClient { + async start(): Promise {} + async stop(): Promise {} + async getStatus(): Promise<{ version: string; protocolVersion: number }> { + return { version: "1.0.31", protocolVersion: 3 }; + } + async getAuthStatus(): Promise<{ + isAuthenticated: boolean; + statusMessage?: string; + }> { + return { isAuthenticated: true, statusMessage: "Logged in" }; + } + async listModels(): Promise> { + return [{ id: "gpt-5.4" }]; + } + async createSession(): Promise { + return new FakeSession("copilot-created"); + } + async resumeSession(sessionId: string): Promise { + return new FakeSession(sessionId); + } + async listSessions(): Promise< + Array<{ sessionId: string; summary?: string }> + > { + return [{ sessionId: "copilot-session-restore", summary: "restored" }]; + } + async deleteSession(_sessionId: string): Promise {} + } + return { CopilotClient }; +}); + +vi.mock("../../src/store/session-store.js", () => ({ + listSessions: vi.fn(() => [...mocks.store.values()]), + getSession: vi.fn((name: string) => mocks.store.get(name) ?? 
null), + upsertSession: vi.fn((record: Record) => { + if (record.name) mocks.store.set(record.name, record); + }), + removeSession: vi.fn((name: string) => { + mocks.store.delete(name); + }), + updateSessionState: vi.fn((name: string, state: string) => { + const existing = mocks.store.get(name); + if (existing) mocks.store.set(name, { ...existing, state }); + }), +})); + +vi.mock("../../src/daemon/transport-relay.js", () => ({ + wireProviderToRelay: vi.fn(), + broadcastProviderStatus: vi.fn(), +})); +vi.mock("../../src/util/logger.js", () => ({ + default: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() }, +})); +vi.mock("../../src/daemon/timeline-emitter.js", () => ({ + timelineEmitter: { + emit: vi.fn(), + on: vi.fn(() => () => {}), + epoch: 0, + replay: vi.fn(() => ({ events: [], truncated: false })), + }, +})); +vi.mock("../../src/agent/tmux.js", () => ({ + listSessions: vi.fn().mockResolvedValue([]), + newSession: vi.fn().mockResolvedValue(undefined), + killSession: vi.fn().mockResolvedValue(undefined), + sessionExists: vi.fn(), + isPaneAlive: vi.fn(), + respawnPane: vi.fn(), + sendKeys: vi.fn(), + sendKey: vi.fn(), + capturePane: vi.fn(), + showBuffer: vi.fn(), + getPaneId: vi.fn().mockResolvedValue(undefined), + getPaneCwd: vi.fn().mockResolvedValue("/tmp"), + getPaneStartCommand: vi.fn().mockResolvedValue(""), + cleanupOrphanFifos: vi.fn(), + BACKEND: "tmux", +})); +vi.mock("../../src/daemon/jsonl-watcher.js", () => ({ + startWatching: vi.fn().mockResolvedValue(undefined), + startWatchingFile: vi.fn().mockResolvedValue(undefined), + stopWatching: vi.fn(), + isWatching: vi.fn(() => false), + findJsonlPathBySessionId: vi.fn(() => "/tmp/mock.jsonl"), +})); +vi.mock("../../src/daemon/codex-watcher.js", () => ({ + startWatching: vi.fn().mockResolvedValue(undefined), + startWatchingSpecificFile: vi.fn().mockResolvedValue(undefined), + startWatchingById: vi.fn().mockResolvedValue(undefined), + stopWatching: vi.fn(), + isWatching: vi.fn(() => false), 
+ findRolloutPathByUuid: vi.fn(async () => null), +})); +vi.mock("../../src/daemon/gemini-watcher.js", () => ({ + startWatching: vi.fn().mockResolvedValue(undefined), + startWatchingLatest: vi.fn().mockResolvedValue(undefined), + stopWatching: vi.fn(), + isWatching: vi.fn(() => false), +})); +vi.mock("../../src/daemon/opencode-watcher.js", () => ({ + startWatching: vi.fn().mockResolvedValue(undefined), + stopWatching: vi.fn(), + isWatching: vi.fn(() => false), +})); +vi.mock("../../src/agent/structured-session-bootstrap.js", () => ({ + resolveStructuredSessionBootstrap: vi.fn(async (x) => x), +})); +vi.mock("../../src/agent/qwen-runtime-config.js", () => ({ + getQwenRuntimeConfig: vi.fn(async () => null), +})); +vi.mock("../../src/agent/sdk-runtime-config.js", () => ({ + getClaudeSdkRuntimeConfig: vi.fn(async () => ({})), +})); +vi.mock("../../src/agent/codex-runtime-config.js", () => ({ + getCodexRuntimeConfig: vi.fn(async () => ({})), +})); +vi.mock("../../src/agent/provider-display.js", () => ({ + getQwenDisplayMetadata: vi.fn(() => ({})), +})); +vi.mock("../../src/agent/provider-quota.js", () => ({ + getQwenOAuthQuotaUsageLabel: vi.fn(() => ""), +})); +vi.mock("../../src/agent/agent-version.js", () => ({ + getAgentVersion: vi.fn(async () => "test"), +})); +vi.mock("../../src/agent/signal.js", () => ({ + setupCCStopHook: vi.fn(async () => {}), +})); +vi.mock("../../src/agent/notify-setup.js", () => ({ + setupCodexNotify: vi.fn(async () => {}), + setupOpenCodePlugin: vi.fn(async () => {}), +})); +vi.mock("../../src/repo/cache.js", () => ({ + repoCache: { invalidate: vi.fn() }, +})); +vi.mock("../../src/agent/brain-dispatcher.js", () => ({ + BrainDispatcher: vi + .fn() + .mockImplementation(() => ({ start: vi.fn(), stop: vi.fn() })), +})); + +import { + connectProvider, + disconnectAll, +} from "../../src/agent/provider-registry.js"; +import { + getTransportRuntime, + restoreTransportSessions, +} from "../../src/agent/session-manager.js"; + +const flush = async () 
=> { + for (let i = 0; i < 4; i++) + await new Promise((resolve) => setTimeout(resolve, 0)); +}; + +describe("cursor/copilot transport restore", () => { + beforeEach(() => { + mocks.store.clear(); + mocks.cursorSpawns.length = 0; + mocks.copilotRuns.length = 0; + }); + + afterEach(async () => { + await disconnectAll(); + }); + + it("restores cursor-headless sessions with persisted provider resume ids", async () => { + mocks.store.set("deck_cursor_restore_brain", { + name: "deck_cursor_restore_brain", + projectName: "cursorrestore", + role: "brain", + agentType: "cursor-headless", + projectDir: "/tmp/cursor-restore", + state: "idle", + restarts: 0, + restartTimestamps: [], + createdAt: Date.now(), + updatedAt: Date.now(), + runtimeType: "transport", + providerId: "cursor-headless", + providerSessionId: "route-cursor-restore", + providerResumeId: "cursor-chat-restore", + requestedModel: "gpt-5.2", + activeModel: "gpt-5.2", + }); + + await connectProvider("cursor-headless", {}); + await restoreTransportSessions("cursor-headless"); + + const runtime = getTransportRuntime("deck_cursor_restore_brain"); + expect(runtime?.providerSessionId).toBe("route-cursor-restore"); + + runtime!.send("Verify cursor restore"); + await flush(); + const spawned = mocks.cursorSpawns.at(-1); + expect(spawned?.args).toContain("--resume"); + expect(spawned?.args).toContain("cursor-chat-restore"); + }); + + it("restores copilot-sdk sessions with persisted provider resume ids and sends on resumed continuity", async () => { + mocks.store.set("deck_copilot_restore_brain", { + name: "deck_copilot_restore_brain", + projectName: "copilotrestore", + role: "brain", + agentType: "copilot-sdk", + projectDir: "/tmp/copilot-restore", + state: "idle", + restarts: 0, + restartTimestamps: [], + createdAt: Date.now(), + updatedAt: Date.now(), + runtimeType: "transport", + providerId: "copilot-sdk", + providerSessionId: "route-copilot-restore", + providerResumeId: "copilot-session-restore", + requestedModel: 
"gpt-5.4", + activeModel: "gpt-5.4", + effort: "high", + }); + + await connectProvider("copilot-sdk", {}); + await restoreTransportSessions("copilot-sdk"); + + const runtime = getTransportRuntime("deck_copilot_restore_brain"); + expect(runtime?.providerSessionId).toBe("route-copilot-restore"); + + runtime!.send("Verify copilot restore"); + await flush(); + + expect(mocks.copilotRuns).toContainEqual( + expect.objectContaining({ + sessionId: "copilot-session-restore", + prompt: "Verify copilot restore", + }), + ); + }); + + it("skips unavailable provider restores without throwing and leaves the persisted session inspectable", async () => { + mocks.store.set("deck_missing_provider_brain", { + name: "deck_missing_provider_brain", + projectName: "missingprovider", + role: "brain", + agentType: "copilot-sdk", + projectDir: "/tmp/missing-provider", + state: "idle", + restarts: 0, + restartTimestamps: [], + createdAt: Date.now(), + updatedAt: Date.now(), + runtimeType: "transport", + providerId: "copilot-sdk", + providerSessionId: "route-missing-provider", + providerResumeId: "copilot-session-missing", + }); + + await expect( + restoreTransportSessions("copilot-sdk"), + ).resolves.toBeUndefined(); + expect(getTransportRuntime("deck_missing_provider_brain")).toBeUndefined(); + expect( + mocks.store.get("deck_missing_provider_brain")?.providerResumeId, + ).toBe("copilot-session-missing"); + }); +}); diff --git a/test/daemon/transport-relay.test.ts b/test/daemon/transport-relay.test.ts index 6eaf883e9..91f248787 100644 --- a/test/daemon/transport-relay.test.ts +++ b/test/daemon/transport-relay.test.ts @@ -40,7 +40,7 @@ import { appendTransportEvent } from '../../src/daemon/transport-history.js'; import type { TransportProvider } from '../../src/agent/transport-provider.js'; import type { AgentMessage, MessageDelta, ToolCallEvent } from '../../shared/agent-message.js'; -import { TRANSPORT_MSG } from '../../shared/transport-events.js'; +import { TRANSPORT_EVENT, TRANSPORT_MSG } 
from '../../shared/transport-events.js'; // ── Mock provider factory ──────────────────────────────────────────────────── @@ -49,6 +49,7 @@ type CompleteCb = (sessionId: string, message: AgentMessage) => void; type ErrorCb = (sessionId: string, error: { code: string; message: string; recoverable: boolean }) => void; type ToolCb = (sessionId: string, tool: ToolCallEvent) => void; type StatusCb = (sessionId: string, status: { status: string | null; label?: string | null }) => void; +type ApprovalCb = (sessionId: string, request: { id: string; description: string; tool?: string }) => void; function makeMockProvider() { let deltaCb: DeltaCb | undefined; @@ -56,6 +57,7 @@ function makeMockProvider() { let errorCb: ErrorCb | undefined; let toolCb: ToolCb | undefined; let statusCb: StatusCb | undefined; + let approvalCb: ApprovalCb | undefined; return { provider: { @@ -64,12 +66,14 @@ function makeMockProvider() { onError: (cb: ErrorCb) => { errorCb = cb; return () => { errorCb = undefined; }; }, onToolCall: (cb: ToolCb) => { toolCb = cb; }, onStatus: (cb: StatusCb) => { statusCb = cb; return () => { statusCb = undefined; }; }, + onApprovalRequest: (cb: ApprovalCb) => { approvalCb = cb; }, } as unknown as TransportProvider, fireDelta: (sid: string, delta: MessageDelta) => deltaCb?.(sid, delta), fireComplete: (sid: string, msg: AgentMessage) => completeCb?.(sid, msg), fireError: (sid: string, err: { code: string; message: string; recoverable: boolean }) => errorCb?.(sid, err), fireTool: (sid: string, tool: ToolCallEvent) => toolCb?.(sid, tool), fireStatus: (sid: string, status: { status: string | null; label?: string | null }) => statusCb?.(sid, status), + fireApproval: (sid: string, request: { id: string; description: string; tool?: string }) => approvalCb?.(sid, request), }; } @@ -800,6 +804,32 @@ describe('transport-relay (timeline-emitter based)', () => { ); }); }); + + describe('onApprovalRequest', () => { + it('broadcasts approval requests to transport subscribers 
and caches them', async () => { + const { provider, fireApproval } = makeMockProvider(); + wireProviderToRelay(provider); + + fireApproval('sess-approval', { + id: 'approval-1', + description: 'Allow file write', + tool: 'shell', + }); + await Promise.resolve(); + + expect(send).toHaveBeenCalledWith(expect.objectContaining({ + type: TRANSPORT_EVENT.CHAT_APPROVAL, + sessionId: 'sess-approval', + requestId: 'approval-1', + description: 'Allow file write', + tool: 'shell', + })); + expect(appendMock).toHaveBeenCalledWith('sess-approval', expect.objectContaining({ + type: TRANSPORT_EVENT.CHAT_APPROVAL, + requestId: 'approval-1', + })); + }); + }); }); // ── useTimeline same-ID replacement (logic extracted for unit testing) ─────── diff --git a/test/daemon/transport-session-runtime.test.ts b/test/daemon/transport-session-runtime.test.ts index 6ff99f180..d4f913d19 100644 --- a/test/daemon/transport-session-runtime.test.ts +++ b/test/daemon/transport-session-runtime.test.ts @@ -26,6 +26,7 @@ function makeMockProvider() { let deltaCb: ((sid: string, d: MessageDelta) => void) | null = null; let completeCb: ((sid: string, m: AgentMessage) => void) | null = null; let errorCb: ((sid: string, e: ProviderError) => void) | null = null; + let approvalCb: ((sid: string, req: { id: string; description: string; tool?: string }) => void) | null = null; const fireDelta = (sid: string) => deltaCb?.(sid, { messageId: 'msg', type: 'text', delta: 'x', role: 'assistant' }); @@ -33,6 +34,8 @@ function makeMockProvider() { completeCb?.(sid, { id: 'msg-1', sessionId: sid, kind: 'text', role: 'assistant', content: 'done', timestamp: Date.now(), status: 'complete' }); const fireError = (sid: string, err?: ProviderError) => errorCb?.(sid, err ?? 
{ code: 'PROVIDER_ERROR', message: 'err', recoverable: false }); + const fireApproval = (sid: string, req: { id: string; description: string; tool?: string }) => + approvalCb?.(sid, req); return { provider: { @@ -43,8 +46,10 @@ function makeMockProvider() { onDelta: (cb: (sid: string, d: MessageDelta) => void) => { deltaCb = cb; return () => { deltaCb = null; }; }, onComplete: (cb: (sid: string, m: AgentMessage) => void) => { completeCb = cb; return () => { completeCb = null; }; }, onError: (cb: (sid: string, e: ProviderError) => void) => { errorCb = cb; return () => { errorCb = null; }; }, + onApprovalRequest: (cb: (sid: string, req: { id: string; description: string; tool?: string }) => void) => { approvalCb = cb; }, + respondApproval: vi.fn().mockResolvedValue(undefined), } as unknown as TransportProvider, - fireDelta, fireComplete, fireError, + fireDelta, fireComplete, fireError, fireApproval, }; } @@ -267,6 +272,34 @@ describe('TransportSessionRuntime', () => { }); }); + it('forwards approval requests through runtime callbacks', async () => { + const approvalMock = makeMockProvider(); + const runtimeWithApproval = new TransportSessionRuntime(approvalMock.provider, 'deck_test_brain'); + const approvalEvents: Array> = []; + runtimeWithApproval.onApprovalRequest = (request) => approvalEvents.push(request as Record); + await runtimeWithApproval.initialize(defaultConfig); + + approvalMock.fireApproval('sess-1', { + id: 'approval-1', + description: 'Allow file write', + tool: 'shell', + }); + + expect(approvalEvents).toEqual([ + { id: 'approval-1', description: 'Allow file write', tool: 'shell' }, + ]); + }); + + it('forwards approval responses to the provider', async () => { + const approvalMock = makeMockProvider(); + const runtimeWithApproval = new TransportSessionRuntime(approvalMock.provider, 'deck_test_brain'); + await runtimeWithApproval.initialize(defaultConfig); + + await runtimeWithApproval.respondApproval('approval-2', true); + + 
expect((approvalMock.provider as any).respondApproval).toHaveBeenCalledWith('sess-1', 'approval-2', true); + }); + it('refreshes shared-context bootstrap on each dispatch turn instead of freezing launch-time namespace state', async () => { const localMock = makeMockProvider(); const r = new TransportSessionRuntime(localMock.provider, 'x'); @@ -395,7 +428,8 @@ describe('TransportSessionRuntime', () => { query: expect.stringContaining('Please recall recent transport memory'), namespace: { scope: 'personal', projectId: 'repo-1' }, repo: 'repo-1', - limit: 5, + currentEnterpriseId: undefined, + limit: 10, })); expect(localMock.provider.send).toHaveBeenCalledWith('sess-1', expect.objectContaining({ memoryRecall: expect.objectContaining({ diff --git a/test/daemon/transport-types.test.ts b/test/daemon/transport-types.test.ts index 3a9857e99..fb97f4cea 100644 --- a/test/daemon/transport-types.test.ts +++ b/test/daemon/transport-types.test.ts @@ -4,7 +4,7 @@ * Verifies that all constant objects and runtime validation sets from the * shared transport modules contain the expected values. 
*/ -import { describe, it, expect } from 'vitest'; +import { describe, it, expect } from "vitest"; import { AGENT_MESSAGE_KINDS, @@ -12,7 +12,7 @@ import { AGENT_MESSAGE_STATUSES, MESSAGE_DELTA_TYPES, AGENT_MESSAGE_TERMINAL_STATUSES, -} from '../../shared/agent-message.js'; +} from "../../shared/agent-message.js"; import { TRANSPORT_EVENT, @@ -20,216 +20,235 @@ import { TRANSPORT_AGENT_STATUSES, TRANSPORT_ACTIVE_STATUSES, TRANSPORT_RELAY_TYPES, -} from '../../shared/transport-events.js'; +} from "../../shared/transport-events.js"; import { CONNECTION_MODES, SESSION_OWNERSHIP, PROVIDER_ERROR_CODES, -} from '../../src/agent/transport-provider.js'; +} from "../../src/agent/transport-provider.js"; -import { RUNTIME_TYPES } from '../../src/agent/session-runtime.js'; +import { RUNTIME_TYPES } from "../../src/agent/session-runtime.js"; import { isTransportAgent, isProcessAgent, TRANSPORT_AGENTS, PROCESS_AGENTS, -} from '../../src/agent/detect.js'; +} from "../../src/agent/detect.js"; // ── shared/agent-message.ts ────────────────────────────────────────────────── -describe('shared/agent-message', () => { - it('AGENT_MESSAGE_KINDS contains all 5 kinds', () => { - const expected = ['text', 'tool_use', 'tool_result', 'system', 'approval']; +describe("shared/agent-message", () => { + it("AGENT_MESSAGE_KINDS contains all 5 kinds", () => { + const expected = ["text", "tool_use", "tool_result", "system", "approval"]; expect(AGENT_MESSAGE_KINDS.size).toBe(5); for (const kind of expected) { expect(AGENT_MESSAGE_KINDS.has(kind as any)).toBe(true); } }); - it('AGENT_MESSAGE_ROLES contains user, assistant, system', () => { - const expected = ['user', 'assistant', 'system']; + it("AGENT_MESSAGE_ROLES contains user, assistant, system", () => { + const expected = ["user", "assistant", "system"]; expect(AGENT_MESSAGE_ROLES.size).toBe(3); for (const role of expected) { expect(AGENT_MESSAGE_ROLES.has(role as any)).toBe(true); } }); - it('AGENT_MESSAGE_STATUSES contains streaming, complete, 
error', () => { - const expected = ['streaming', 'complete', 'error']; + it("AGENT_MESSAGE_STATUSES contains streaming, complete, error", () => { + const expected = ["streaming", "complete", "error"]; expect(AGENT_MESSAGE_STATUSES.size).toBe(3); for (const status of expected) { expect(AGENT_MESSAGE_STATUSES.has(status as any)).toBe(true); } }); - it('MESSAGE_DELTA_TYPES contains text, tool_use, tool_result', () => { - const expected = ['text', 'tool_use', 'tool_result']; + it("MESSAGE_DELTA_TYPES contains text, tool_use, tool_result", () => { + const expected = ["text", "tool_use", "tool_result"]; expect(MESSAGE_DELTA_TYPES.size).toBe(3); for (const type of expected) { expect(MESSAGE_DELTA_TYPES.has(type as any)).toBe(true); } }); - it('AGENT_MESSAGE_TERMINAL_STATUSES contains complete and error but NOT streaming', () => { - expect(AGENT_MESSAGE_TERMINAL_STATUSES.has('complete')).toBe(true); - expect(AGENT_MESSAGE_TERMINAL_STATUSES.has('error')).toBe(true); - expect(AGENT_MESSAGE_TERMINAL_STATUSES.has('streaming')).toBe(false); + it("AGENT_MESSAGE_TERMINAL_STATUSES contains complete and error but NOT streaming", () => { + expect(AGENT_MESSAGE_TERMINAL_STATUSES.has("complete")).toBe(true); + expect(AGENT_MESSAGE_TERMINAL_STATUSES.has("error")).toBe(true); + expect(AGENT_MESSAGE_TERMINAL_STATUSES.has("streaming")).toBe(false); expect(AGENT_MESSAGE_TERMINAL_STATUSES.size).toBe(2); }); }); // ── shared/transport-events.ts ─────────────────────────────────────────────── -describe('shared/transport-events', () => { - it('TRANSPORT_EVENT has correct values for all 6 event types', () => { - expect(TRANSPORT_EVENT.CHAT_DELTA).toBe('chat.delta'); - expect(TRANSPORT_EVENT.CHAT_COMPLETE).toBe('chat.complete'); - expect(TRANSPORT_EVENT.CHAT_ERROR).toBe('chat.error'); - expect(TRANSPORT_EVENT.CHAT_STATUS).toBe('chat.status'); - expect(TRANSPORT_EVENT.CHAT_TOOL).toBe('chat.tool'); - expect(TRANSPORT_EVENT.CHAT_APPROVAL).toBe('chat.approval'); +describe("shared/transport-events", 
() => { + it("TRANSPORT_EVENT has correct values for all 6 event types", () => { + expect(TRANSPORT_EVENT.CHAT_DELTA).toBe("chat.delta"); + expect(TRANSPORT_EVENT.CHAT_COMPLETE).toBe("chat.complete"); + expect(TRANSPORT_EVENT.CHAT_ERROR).toBe("chat.error"); + expect(TRANSPORT_EVENT.CHAT_STATUS).toBe("chat.status"); + expect(TRANSPORT_EVENT.CHAT_TOOL).toBe("chat.tool"); + expect(TRANSPORT_EVENT.CHAT_APPROVAL).toBe("chat.approval"); expect(Object.keys(TRANSPORT_EVENT)).toHaveLength(6); }); - it('TRANSPORT_MSG has correct values for all 5 message types', () => { - expect(TRANSPORT_MSG.CHAT_SUBSCRIBE).toBe('chat.subscribe'); - expect(TRANSPORT_MSG.CHAT_UNSUBSCRIBE).toBe('chat.unsubscribe'); - expect(TRANSPORT_MSG.PROVIDER_STATUS).toBe('provider.status'); - expect(TRANSPORT_MSG.LIST_SESSIONS).toBe('provider.list_sessions'); - expect(TRANSPORT_MSG.SESSIONS_RESPONSE).toBe('provider.sessions_response'); - expect(Object.keys(TRANSPORT_MSG)).toHaveLength(5); + it("TRANSPORT_MSG has correct values for all 6 message types", () => { + expect(TRANSPORT_MSG.CHAT_SUBSCRIBE).toBe("chat.subscribe"); + expect(TRANSPORT_MSG.CHAT_UNSUBSCRIBE).toBe("chat.unsubscribe"); + expect(TRANSPORT_MSG.PROVIDER_STATUS).toBe("provider.status"); + expect(TRANSPORT_MSG.LIST_SESSIONS).toBe("provider.list_sessions"); + expect(TRANSPORT_MSG.SESSIONS_RESPONSE).toBe("provider.sessions_response"); + expect(TRANSPORT_MSG.APPROVAL_RESPONSE).toBe("chat.approval_response"); + expect(Object.keys(TRANSPORT_MSG)).toHaveLength(7); }); - it('TRANSPORT_AGENT_STATUSES contains all 7 statuses', () => { - const expected = ['idle', 'streaming', 'thinking', 'tool_running', 'permission', 'error', 'unknown']; + it("TRANSPORT_AGENT_STATUSES contains all 7 statuses", () => { + const expected = [ + "idle", + "streaming", + "thinking", + "tool_running", + "permission", + "error", + "unknown", + ]; expect(TRANSPORT_AGENT_STATUSES.size).toBe(7); for (const status of expected) { expect(TRANSPORT_AGENT_STATUSES.has(status as 
any)).toBe(true); } }); - it('TRANSPORT_ACTIVE_STATUSES contains streaming, thinking, tool_running and NOT idle/permission/unknown', () => { - expect(TRANSPORT_ACTIVE_STATUSES.has('streaming')).toBe(true); - expect(TRANSPORT_ACTIVE_STATUSES.has('thinking')).toBe(true); - expect(TRANSPORT_ACTIVE_STATUSES.has('tool_running')).toBe(true); - expect(TRANSPORT_ACTIVE_STATUSES.has('idle')).toBe(false); - expect(TRANSPORT_ACTIVE_STATUSES.has('permission')).toBe(false); - expect(TRANSPORT_ACTIVE_STATUSES.has('unknown')).toBe(false); + it("TRANSPORT_ACTIVE_STATUSES contains streaming, thinking, tool_running and NOT idle/permission/unknown", () => { + expect(TRANSPORT_ACTIVE_STATUSES.has("streaming")).toBe(true); + expect(TRANSPORT_ACTIVE_STATUSES.has("thinking")).toBe(true); + expect(TRANSPORT_ACTIVE_STATUSES.has("tool_running")).toBe(true); + expect(TRANSPORT_ACTIVE_STATUSES.has("idle")).toBe(false); + expect(TRANSPORT_ACTIVE_STATUSES.has("permission")).toBe(false); + expect(TRANSPORT_ACTIVE_STATUSES.has("unknown")).toBe(false); expect(TRANSPORT_ACTIVE_STATUSES.size).toBe(3); }); - it('TRANSPORT_RELAY_TYPES contains all event types plus PROVIDER_STATUS', () => { + it("TRANSPORT_RELAY_TYPES contains all event types plus PROVIDER_STATUS", () => { // All 6 TRANSPORT_EVENT values - for (const key of Object.keys(TRANSPORT_EVENT) as (keyof typeof TRANSPORT_EVENT)[]) { + for (const key of Object.keys( + TRANSPORT_EVENT, + ) as (keyof typeof TRANSPORT_EVENT)[]) { expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_EVENT[key])).toBe(true); } + expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.APPROVAL_RESPONSE)).toBe(true); // Plus PROVIDER_STATUS from TRANSPORT_MSG expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.PROVIDER_STATUS)).toBe(true); - // Total: 6 events + 1 provider.status = 7 - expect(TRANSPORT_RELAY_TYPES.size).toBe(7); + // Total: 6 events + approval_response + provider.status = 8 + expect(TRANSPORT_RELAY_TYPES.size).toBe(8); }); }); // ── src/agent/transport-provider.ts 
────────────────────────────────────────── -describe('src/agent/transport-provider', () => { - it('CONNECTION_MODES has persistent, per-request, local-sdk', () => { - expect(CONNECTION_MODES.PERSISTENT).toBe('persistent'); - expect(CONNECTION_MODES.PER_REQUEST).toBe('per-request'); - expect(CONNECTION_MODES.LOCAL_SDK).toBe('local-sdk'); +describe("src/agent/transport-provider", () => { + it("CONNECTION_MODES has persistent, per-request, local-sdk", () => { + expect(CONNECTION_MODES.PERSISTENT).toBe("persistent"); + expect(CONNECTION_MODES.PER_REQUEST).toBe("per-request"); + expect(CONNECTION_MODES.LOCAL_SDK).toBe("local-sdk"); expect(Object.keys(CONNECTION_MODES)).toHaveLength(3); }); - it('SESSION_OWNERSHIP has provider, local, shared', () => { - expect(SESSION_OWNERSHIP.PROVIDER).toBe('provider'); - expect(SESSION_OWNERSHIP.LOCAL).toBe('local'); - expect(SESSION_OWNERSHIP.SHARED).toBe('shared'); + it("SESSION_OWNERSHIP has provider, local, shared", () => { + expect(SESSION_OWNERSHIP.PROVIDER).toBe("provider"); + expect(SESSION_OWNERSHIP.LOCAL).toBe("local"); + expect(SESSION_OWNERSHIP.SHARED).toBe("shared"); expect(Object.keys(SESSION_OWNERSHIP)).toHaveLength(3); }); - it('PROVIDER_ERROR_CODES has all 9 codes', () => { - expect(PROVIDER_ERROR_CODES.AUTH_FAILED).toBe('AUTH_FAILED'); - expect(PROVIDER_ERROR_CODES.CONFIG_ERROR).toBe('CONFIG_ERROR'); - expect(PROVIDER_ERROR_CODES.CONNECTION_LOST).toBe('CONNECTION_LOST'); - expect(PROVIDER_ERROR_CODES.SESSION_NOT_FOUND).toBe('SESSION_NOT_FOUND'); - expect(PROVIDER_ERROR_CODES.RATE_LIMITED).toBe('RATE_LIMITED'); - expect(PROVIDER_ERROR_CODES.PROVIDER_ERROR).toBe('PROVIDER_ERROR'); - expect(PROVIDER_ERROR_CODES.CANCELLED).toBe('CANCELLED'); - expect(PROVIDER_ERROR_CODES.PARSE_ERROR).toBe('PARSE_ERROR'); - expect(PROVIDER_ERROR_CODES.PROVIDER_NOT_FOUND).toBe('PROVIDER_NOT_FOUND'); + it("PROVIDER_ERROR_CODES has all 9 codes", () => { + expect(PROVIDER_ERROR_CODES.AUTH_FAILED).toBe("AUTH_FAILED"); + 
expect(PROVIDER_ERROR_CODES.CONFIG_ERROR).toBe("CONFIG_ERROR"); + expect(PROVIDER_ERROR_CODES.CONNECTION_LOST).toBe("CONNECTION_LOST"); + expect(PROVIDER_ERROR_CODES.SESSION_NOT_FOUND).toBe("SESSION_NOT_FOUND"); + expect(PROVIDER_ERROR_CODES.RATE_LIMITED).toBe("RATE_LIMITED"); + expect(PROVIDER_ERROR_CODES.PROVIDER_ERROR).toBe("PROVIDER_ERROR"); + expect(PROVIDER_ERROR_CODES.CANCELLED).toBe("CANCELLED"); + expect(PROVIDER_ERROR_CODES.PARSE_ERROR).toBe("PARSE_ERROR"); + expect(PROVIDER_ERROR_CODES.PROVIDER_NOT_FOUND).toBe("PROVIDER_NOT_FOUND"); expect(Object.keys(PROVIDER_ERROR_CODES)).toHaveLength(9); }); }); // ── src/agent/session-runtime.ts ───────────────────────────────────────────── -describe('src/agent/session-runtime', () => { - it('RUNTIME_TYPES has process and transport', () => { - expect(RUNTIME_TYPES.PROCESS).toBe('process'); - expect(RUNTIME_TYPES.TRANSPORT).toBe('transport'); +describe("src/agent/session-runtime", () => { + it("RUNTIME_TYPES has process and transport", () => { + expect(RUNTIME_TYPES.PROCESS).toBe("process"); + expect(RUNTIME_TYPES.TRANSPORT).toBe("transport"); expect(Object.keys(RUNTIME_TYPES)).toHaveLength(2); }); }); // ── src/agent/detect.ts ────────────────────────────────────────────────────── -describe('src/agent/detect — transport/process classification', () => { - it('isTransportAgent returns true for openclaw', () => { - expect(isTransportAgent('openclaw')).toBe(true); +describe("src/agent/detect — transport/process classification", () => { + it("isTransportAgent returns true for openclaw", () => { + expect(isTransportAgent("openclaw")).toBe(true); }); - it('isTransportAgent returns true for qwen', () => { - expect(isTransportAgent('qwen')).toBe(true); + it("isTransportAgent returns true for qwen", () => { + expect(isTransportAgent("qwen")).toBe(true); }); - it('isTransportAgent returns false for claude-code', () => { - expect(isTransportAgent('claude-code')).toBe(false); + it("isTransportAgent returns false for claude-code", 
() => { + expect(isTransportAgent("claude-code")).toBe(false); }); - it('isTransportAgent returns true for claude-code-sdk', () => { - expect(isTransportAgent('claude-code-sdk')).toBe(true); + it("isTransportAgent returns true for claude-code-sdk", () => { + expect(isTransportAgent("claude-code-sdk")).toBe(true); }); - it('isTransportAgent returns true for codex-sdk', () => { - expect(isTransportAgent('codex-sdk')).toBe(true); + it("isTransportAgent returns true for codex-sdk", () => { + expect(isTransportAgent("codex-sdk")).toBe(true); }); - it('isProcessAgent returns true for claude-code', () => { - expect(isProcessAgent('claude-code')).toBe(true); + it("isProcessAgent returns true for claude-code", () => { + expect(isProcessAgent("claude-code")).toBe(true); }); - it('isProcessAgent returns false for openclaw', () => { - expect(isProcessAgent('openclaw')).toBe(false); + it("isProcessAgent returns false for openclaw", () => { + expect(isProcessAgent("openclaw")).toBe(false); }); - it('isProcessAgent returns false for qwen', () => { - expect(isProcessAgent('qwen')).toBe(false); + it("isProcessAgent returns false for qwen", () => { + expect(isProcessAgent("qwen")).toBe(false); }); - it('TRANSPORT_AGENTS contains openclaw', () => { - expect(TRANSPORT_AGENTS.has('openclaw')).toBe(true); + it("TRANSPORT_AGENTS contains openclaw", () => { + expect(TRANSPORT_AGENTS.has("openclaw")).toBe(true); }); - it('TRANSPORT_AGENTS contains qwen', () => { - expect(TRANSPORT_AGENTS.has('qwen')).toBe(true); + it("TRANSPORT_AGENTS contains qwen", () => { + expect(TRANSPORT_AGENTS.has("qwen")).toBe(true); }); - it('TRANSPORT_AGENTS contains claude-code-sdk and codex-sdk', () => { - expect(TRANSPORT_AGENTS.has('claude-code-sdk')).toBe(true); - expect(TRANSPORT_AGENTS.has('codex-sdk')).toBe(true); + it("TRANSPORT_AGENTS contains claude-code-sdk and codex-sdk", () => { + expect(TRANSPORT_AGENTS.has("claude-code-sdk")).toBe(true); + expect(TRANSPORT_AGENTS.has("codex-sdk")).toBe(true); }); 
- it('PROCESS_AGENTS contains all process agent types', () => { - const expected = ['claude-code', 'codex', 'opencode', 'shell', 'script', 'gemini']; + it("PROCESS_AGENTS contains all process agent types", () => { + const expected = [ + "claude-code", + "codex", + "opencode", + "shell", + "script", + "gemini", + ]; expect(PROCESS_AGENTS.size).toBe(6); for (const agent of expected) { expect(PROCESS_AGENTS.has(agent as any)).toBe(true); } }); - it('TRANSPORT_AGENTS and PROCESS_AGENTS are disjoint', () => { + it("TRANSPORT_AGENTS and PROCESS_AGENTS are disjoint", () => { for (const agent of TRANSPORT_AGENTS) { expect(PROCESS_AGENTS.has(agent as any)).toBe(false); } diff --git a/test/e2e/copilot-sdk-live.test.ts b/test/e2e/copilot-sdk-live.test.ts new file mode 100644 index 000000000..4a56ab1fa --- /dev/null +++ b/test/e2e/copilot-sdk-live.test.ts @@ -0,0 +1,192 @@ +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { mkdtemp, writeFile } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { CopilotSdkProvider } from "../../src/agent/providers/copilot-sdk.js"; +import type { + ApprovalRequest, + ProviderError, + SessionInfoUpdate, +} from "../../src/agent/transport-provider.js"; + +const RUN = process.env.RUN_COPILOT_LIVE === "1"; +const TIMEOUT_MS = 90_000; + +function waitForCompletion( + provider: CopilotSdkProvider, + sessionId: string, +): Promise { + return new Promise((resolve, reject) => { + const offComplete = provider.onComplete((sid, message) => { + if (sid !== sessionId) return; + offComplete(); + offError(); + resolve(String(message.content ?? 
"")); + }); + const offError = provider.onError((sid, error) => { + if (sid !== sessionId) return; + offComplete(); + offError(); + reject(Object.assign(new Error(error.message), { code: error.code })); + }); + }); +} + +function waitForInfo( + provider: CopilotSdkProvider, + sessionId: string, + predicate: (info: SessionInfoUpdate) => boolean, +): Promise { + return new Promise((resolve, reject) => { + const off = provider.onSessionInfo((sid, info) => { + if (sid !== sessionId || !predicate(info)) return; + off(); + resolve(info); + }); + setTimeout(() => { + off(); + reject(new Error("Timed out waiting for Copilot session info update")); + }, 20_000); + }); +} + +function waitForCancel( + provider: CopilotSdkProvider, + sessionId: string, +): Promise { + return new Promise((resolve, reject) => { + const offError = provider.onError((sid, error) => { + if (sid !== sessionId || error.code !== "CANCELLED") return; + offError(); + resolve(error); + }); + setTimeout(() => { + offError(); + reject(new Error("Timed out waiting for Copilot cancellation")); + }, 20_000); + }); +} + +function waitForToolStart( + provider: CopilotSdkProvider, + sessionId: string, + predicate: (toolName: string, input: unknown) => boolean, +): Promise { + return new Promise((resolve, reject) => { + let settled = false; + provider.onToolCall((sid, tool) => { + if (settled) return; + if (sid !== sessionId || tool.status !== "running") return; + if (!predicate(String(tool.name ?? 
""), tool.input)) return; + settled = true; + resolve(); + }); + setTimeout(() => { + if (settled) return; + settled = true; + reject(new Error("Timed out waiting for Copilot tool start")); + }, 30_000); + }); +} + +describe.skipIf(!RUN)("copilot-sdk live transport", () => { + let provider: CopilotSdkProvider; + let sessionId: string; + let latestResumeId = ""; + let tempDir = ""; + + beforeEach(async () => { + provider = new CopilotSdkProvider(); + provider.onApprovalRequest((sid, req: ApprovalRequest) => { + void provider.respondApproval(sid, req.id, true); + }); + provider.onSessionInfo((sid, info) => { + if (sid === sessionId && info.resumeId) latestResumeId = info.resumeId; + }); + await provider.connect({ + binaryPath: process.env.COPILOT_BIN_PATH, + approvalTimeoutMs: 20_000, + }); + sessionId = await provider.createSession({ + sessionKey: `copilot-live-${Date.now()}`, + cwd: process.cwd(), + agentId: process.env.COPILOT_LIVE_MODEL || "gpt-5.4", + effort: "high", + }); + tempDir = await mkdtemp(join(tmpdir(), "copilot-live-")); + }, TIMEOUT_MS); + + afterEach(async () => { + await provider.disconnect(); + }); + + it( + "supports attachments and multi-turn resume", + async () => { + const attachmentPath = join(tempDir, "transport-live.txt"); + await writeFile(attachmentPath, "COPILOT_ATTACHMENT_OK\n", "utf8"); + + const first = waitForCompletion(provider, sessionId); + await provider.send( + sessionId, + "Read the attached file and reply with exactly COPILOT_ATTACHMENT_OK and nothing else.", + [ + { + id: "att-1", + daemonPath: attachmentPath, + originalName: "transport-live.txt", + type: "file", + }, + ], + ); + await expect(first).resolves.toContain("COPILOT_ATTACHMENT_OK"); + + const second = waitForCompletion(provider, sessionId); + await provider.send( + sessionId, + "Without explanation, reply exactly COPILOT_LIVE_RESUME_OK if the previous final answer in this conversation was COPILOT_ATTACHMENT_OK, otherwise reply COPILOT_LIVE_RESUME_NO.", + ); + await 
expect(second).resolves.toContain("COPILOT_LIVE_RESUME_OK"); + }, + TIMEOUT_MS, + ); + + it( + "rotates away from background-tainted aborts before the next turn", + async () => { + const originalResume = latestResumeId; + const toolStarted = waitForToolStart( + provider, + sessionId, + (toolName, input) => + toolName.toLowerCase() === "bash" + && typeof input === "object" + && input !== null + && String((input as Record).command ?? "").includes("COPILOT_BG_STARTED"), + ); + await provider.send( + sessionId, + 'Use shell immediately to run: nohup sh -c "sleep 30" >/tmp/copilot-bg.log 2>&1 & echo COPILOT_BG_STARTED. After starting the background process, do not wait for it; just say COPILOT_BG_STARTED.', + ); + await toolStarted; + const cancelled = waitForCancel(provider, sessionId); + const rotatedInfo = waitForInfo( + provider, + sessionId, + (info) => !!info.resumeId && info.resumeId !== originalResume, + ); + await provider.cancel(sessionId); + await expect(cancelled).resolves.toMatchObject({ code: "CANCELLED" }); + const info = await rotatedInfo; + expect(info.resumeId).not.toBe(originalResume); + + const followup = waitForCompletion(provider, sessionId); + await provider.send( + sessionId, + "Reply with exactly COPILOT_POST_ABORT_OK and nothing else.", + ); + await expect(followup).resolves.toContain("COPILOT_POST_ABORT_OK"); + }, + TIMEOUT_MS, + ); +}); diff --git a/test/e2e/cursor-headless-live.test.ts b/test/e2e/cursor-headless-live.test.ts new file mode 100644 index 000000000..c2a9e417c --- /dev/null +++ b/test/e2e/cursor-headless-live.test.ts @@ -0,0 +1,104 @@ +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { CursorHeadlessProvider } from "../../src/agent/providers/cursor-headless.js"; +import type { + ProviderError, + SessionInfoUpdate, +} from "../../src/agent/transport-provider.js"; + +const RUN = process.env.RUN_CURSOR_LIVE === "1"; +const TIMEOUT_MS = 60_000; + +function waitForCompletion( + provider: 
CursorHeadlessProvider, + sessionId: string, +): Promise { + return new Promise((resolve, reject) => { + const offComplete = provider.onComplete((sid, message) => { + if (sid !== sessionId) return; + offComplete(); + offError(); + resolve(String(message.content ?? "")); + }); + const offError = provider.onError((sid, error) => { + if (sid !== sessionId) return; + offComplete(); + offError(); + reject(Object.assign(new Error(error.message), { code: error.code })); + }); + }); +} + +function waitForCancel( + provider: CursorHeadlessProvider, + sessionId: string, +): Promise { + return new Promise((resolve, reject) => { + const offError = provider.onError((sid, error) => { + if (sid !== sessionId || error.code !== "CANCELLED") return; + offError(); + resolve(error); + }); + setTimeout(() => { + offError(); + reject(new Error("Timed out waiting for Cursor cancellation")); + }, 10_000); + }); +} + +describe.skipIf(!RUN)("cursor-headless live transport", () => { + let provider: CursorHeadlessProvider; + let sessionId: string; + + beforeEach(async () => { + provider = new CursorHeadlessProvider(); + await provider.connect({ + binaryPath: process.env.CURSOR_BIN_PATH, + force: true, + trust: true, + }); + sessionId = await provider.createSession({ + sessionKey: `cursor-live-${Date.now()}`, + cwd: process.cwd(), + agentId: process.env.CURSOR_LIVE_MODEL || "gpt-5.2", + }); + }, TIMEOUT_MS); + + afterEach(async () => { + await provider.disconnect(); + }); + + it( + "supports multi-turn resume and explicit tool-mediated answers", + async () => { + const first = waitForCompletion(provider, sessionId); + await provider.send( + sessionId, + "Use shell if needed, then reply with exactly CURSOR_LIVE_OK and nothing else.", + ); + await expect(first).resolves.toContain("CURSOR_LIVE_OK"); + + const second = waitForCompletion(provider, sessionId); + await provider.send( + sessionId, + "Without explanation, reply exactly CURSOR_LIVE_RESUME_OK if your previous final answer in this 
conversation was CURSOR_LIVE_OK, otherwise reply CURSOR_LIVE_RESUME_NO.", + ); + await expect(second).resolves.toContain("CURSOR_LIVE_RESUME_OK"); + }, + TIMEOUT_MS, + ); + + it( + "supports deterministic process-kill cancellation", + async () => { + await provider.send( + sessionId, + "Run a long task and do not finish quickly.", + ); + const cancelled = waitForCancel(provider, sessionId); + await new Promise((resolve) => setTimeout(resolve, 1500)); + await provider.cancel(sessionId); + await expect(cancelled).resolves.toMatchObject({ code: "CANCELLED" }); + }, + TIMEOUT_MS, + ); +}); diff --git a/test/e2e/cursor-headless-transport.test.ts b/test/e2e/cursor-headless-transport.test.ts new file mode 100644 index 000000000..eda700879 --- /dev/null +++ b/test/e2e/cursor-headless-transport.test.ts @@ -0,0 +1,200 @@ +import { EventEmitter } from 'node:events'; +import { PassThrough } from 'node:stream'; +import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest'; + +const cursorHarness = vi.hoisted(() => { + const state = { + versionOutput: 'Cursor Agent 1.0.0\n', + statusOutput: 'Logged in\n', + createChatOutput: 'cursor-e2e-chat-1\n', + statusError: null as Error | null, + createChatError: null as Error | null, + }; + const spawned: Array<{ + file: string; + args: string[]; + cwd?: string; + env?: NodeJS.ProcessEnv; + child: EventEmitter & { + stdout: PassThrough; + stderr: PassThrough; + stdin: PassThrough; + kill: ReturnType; + killed: boolean; + }; + }> = []; + const execFile = vi.fn((file: string, args: string[], optsOrCb?: unknown, maybeCb?: unknown) => { + const cb = typeof optsOrCb === 'function' + ? 
optsOrCb as (err: Error | null, stdout: string, stderr: string) => void + : maybeCb as ((err: Error | null, stdout: string, stderr: string) => void) | undefined; + if (args.includes('--version')) { + cb?.(null, state.versionOutput, ''); + return {} as never; + } + if (args[0] === 'status') { + if (state.statusError) cb?.(state.statusError, '', ''); + else cb?.(null, state.statusOutput, ''); + return {} as never; + } + if (args[0] === 'create-chat') { + if (state.createChatError) cb?.(state.createChatError, '', ''); + else cb?.(null, state.createChatOutput, ''); + return {} as never; + } + cb?.(null, '', ''); + return {} as never; + }); + const spawn = vi.fn((file: string, args: string[], opts: { cwd?: string; env?: NodeJS.ProcessEnv }) => { + const stdout = new PassThrough(); + const stderr = new PassThrough(); + const stdin = new PassThrough(); + const child = new EventEmitter() as EventEmitter & { + stdout: PassThrough; + stderr: PassThrough; + stdin: PassThrough; + kill: ReturnType; + killed: boolean; + }; + child.stdout = stdout; + child.stderr = stderr; + child.stdin = stdin; + child.killed = false; + child.kill = vi.fn((signal?: string) => { + child.killed = true; + queueMicrotask(() => child.emit('close', 0, signal ?? 
'SIGTERM')); + return true; + }); + spawned.push({ file, args, cwd: opts.cwd, env: opts.env, child }); + queueMicrotask(() => child.emit('spawn')); + return child as never; + }); + return { + state, + spawned, + execFile, + spawn, + lastSpawn(): (typeof spawned)[number] { + const entry = spawned.at(-1); + if (!entry) throw new Error('No Cursor spawn recorded'); + return entry; + }, + async flush(): Promise { + await new Promise((resolve) => setTimeout(resolve, 0)); + }, + }; +}); + +vi.mock('../../src/util/logger.js', () => ({ + default: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})); + +import { + CursorHeadlessProvider, + cursorHeadlessRuntimeHooks, +} from '../../src/agent/providers/cursor-headless.js'; +import type { ProviderContextPayload } from '../../shared/context-types.js'; + +describe('Cursor headless transport (e2e)', () => { + const originalLoadChildProcess = cursorHeadlessRuntimeHooks.loadChildProcess; + + beforeEach(() => { + cursorHeadlessRuntimeHooks.loadChildProcess = async () => ({ + execFile: cursorHarness.execFile, + spawn: cursorHarness.spawn, + } as typeof import('node:child_process')); + cursorHarness.spawn.mockClear(); + cursorHarness.execFile.mockClear(); + cursorHarness.spawned.length = 0; + cursorHarness.state.versionOutput = 'Cursor Agent 1.0.0\n'; + cursorHarness.state.statusOutput = 'Logged in\n'; + cursorHarness.state.createChatOutput = 'cursor-e2e-chat-1\n'; + cursorHarness.state.statusError = null; + cursorHarness.state.createChatError = null; + }); + + afterEach(() => { + cursorHeadlessRuntimeHooks.loadChildProcess = originalLoadChildProcess; + }); + + it('creates a session, streams a turn, cancels cleanly, and preserves restoreability for the known session id', async () => { + const provider = new CursorHeadlessProvider(); + await provider.connect({ binaryPath: 'cursor-agent' }); + + const sessionId = await provider.createSession({ + sessionKey: 'cursor-e2e-route', + cwd: '/tmp/project', + 
agentId: 'gpt-5.2', + }); + + const deltas: string[] = []; + const completed: string[] = []; + const errors: Array> = []; + const tools: Array<{ status: string }> = []; + provider.onDelta((_sid, delta) => deltas.push(delta.delta)); + provider.onComplete((_sid, msg) => completed.push(String(msg.content))); + provider.onError((_sid, error) => errors.push(error as Record)); + provider.onToolCall((_sid, tool) => tools.push({ status: tool.status })); + + await provider.send(sessionId, { + userMessage: 'run the probe', + assembledMessage: 'Context block\n\nrun the probe', + systemText: 'Probe the repo and then respond with PROBE_OK', + messagePreamble: 'Context block', + attachments: [], + context: { + systemText: 'Probe the repo and then respond with PROBE_OK', + messagePreamble: 'Context block', + requiredAuthoredContext: [], + advisoryAuthoredContext: [], + appliedDocumentVersionIds: [], + diagnostics: [], + }, + authority: { + namespace: { scope: 'personal', projectId: 'cursor-e2e-route' }, + authoritySource: 'none', + freshness: 'missing', + fallbackAllowed: true, + retryScheduled: false, + diagnostics: [], + }, + supportClass: 'degraded-message-side-context-mapping', + diagnostics: [], + } satisfies ProviderContextPayload); + + const firstSpawn = cursorHarness.lastSpawn(); + expect(firstSpawn.args).toContain('--resume'); + expect(firstSpawn.args).toContain('cursor-e2e-chat-1'); + expect(firstSpawn.args.at(-1)).toContain('run the probe'); + expect(sessionId).toBe('cursor-e2e-route'); + + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'system.init', session_id: 'cursor-e2e-chat-1', model: 'gpt-5.2', permissionMode: 'default' })}\n`); + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'stream_event', session_id: 'cursor-e2e-chat-1', event: { type: 'content_block_delta', delta: { type: 'text_delta', text: 'PRO' } } })}\n`); + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'stream_event', session_id: 'cursor-e2e-chat-1', event: { type: 
'content_block_delta', delta: { type: 'text_delta', text: 'PROBE_' } } })}\n`); + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'tool_call.started', session_id: 'cursor-e2e-chat-1', id: 'tool-e2e-1', name: 'shell', input: { command: 'echo PROBE_OK' } })}\n`); + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'tool_call.completed', session_id: 'cursor-e2e-chat-1', id: 'tool-e2e-1', name: 'shell', output: 'PROBE_OK' })}\n`); + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'assistant', session_id: 'cursor-e2e-chat-1', message: { id: 'msg-e2e-1', content: [{ type: 'text', text: 'PROBE_OK' }] } })}\n`); + firstSpawn.child.stdout.write(`${JSON.stringify({ type: 'result.success', session_id: 'cursor-e2e-chat-1', result: 'PROBE_OK', usage: { input_tokens: 9, output_tokens: 4 } })}\n`); + firstSpawn.child.emit('close', 0, null); + await cursorHarness.flush(); + + expect(deltas).toEqual(['PRO', 'PROBE_']); + expect(completed).toEqual(['PROBE_OK']); + expect(tools).toEqual([{ status: 'running' }, { status: 'complete' }]); + expect(errors).toEqual([]); + await expect(provider.restoreSession(sessionId)).resolves.toBe(true); + + const cancelTurn = provider.send(sessionId, 'stop this turn'); + await cursorHarness.flush(); + await provider.cancel(sessionId); + await cancelTurn; + await cursorHarness.flush(); + + expect(cursorHarness.lastSpawn().child.killed).toBe(true); + expect(errors.some((error) => error.code === 'CANCELLED')).toBe(true); + }); +}); diff --git a/test/shared/transport-types-contract.test.ts b/test/shared/transport-types-contract.test.ts index 25cb84d34..a7a72e1c3 100644 --- a/test/shared/transport-types-contract.test.ts +++ b/test/shared/transport-types-contract.test.ts @@ -1,127 +1,141 @@ -import { describe, it, expect, expectTypeOf } from 'vitest'; +import { describe, it, expect, expectTypeOf } from "vitest"; import { AGENT_MESSAGE_KINDS, AGENT_MESSAGE_STATUSES, AGENT_MESSAGE_TERMINAL_STATUSES, type AgentMessageKind, type 
AgentMessageStatus, -} from '../../shared/agent-message.js'; +} from "../../shared/agent-message.js"; import { TRANSPORT_EVENT, TRANSPORT_MSG, TRANSPORT_RELAY_TYPES, -} from '../../shared/transport-events.js'; +} from "../../shared/transport-events.js"; // ── TRANSPORT_EVENT ──────────────────────────────────────────────────────────── -describe('TRANSPORT_EVENT constant', () => { - it('has all expected keys', () => { +describe("TRANSPORT_EVENT constant", () => { + it("has all expected keys", () => { const expectedKeys = [ - 'CHAT_DELTA', - 'CHAT_COMPLETE', - 'CHAT_ERROR', - 'CHAT_STATUS', - 'CHAT_TOOL', - 'CHAT_APPROVAL', + "CHAT_DELTA", + "CHAT_COMPLETE", + "CHAT_ERROR", + "CHAT_STATUS", + "CHAT_TOOL", + "CHAT_APPROVAL", ]; for (const key of expectedKeys) { expect(TRANSPORT_EVENT).toHaveProperty(key); } }); - it('has exactly the expected number of keys', () => { + it("has exactly the expected number of keys", () => { expect(Object.keys(TRANSPORT_EVENT)).toHaveLength(6); }); - it('has no duplicate values', () => { + it("has no duplicate values", () => { const values = Object.values(TRANSPORT_EVENT); const unique = new Set(values); expect(unique.size).toBe(values.length); }); - it('values are correctly mapped', () => { - expect(TRANSPORT_EVENT.CHAT_DELTA).toBe('chat.delta'); - expect(TRANSPORT_EVENT.CHAT_COMPLETE).toBe('chat.complete'); - expect(TRANSPORT_EVENT.CHAT_ERROR).toBe('chat.error'); - expect(TRANSPORT_EVENT.CHAT_STATUS).toBe('chat.status'); - expect(TRANSPORT_EVENT.CHAT_TOOL).toBe('chat.tool'); - expect(TRANSPORT_EVENT.CHAT_APPROVAL).toBe('chat.approval'); + it("values are correctly mapped", () => { + expect(TRANSPORT_EVENT.CHAT_DELTA).toBe("chat.delta"); + expect(TRANSPORT_EVENT.CHAT_COMPLETE).toBe("chat.complete"); + expect(TRANSPORT_EVENT.CHAT_ERROR).toBe("chat.error"); + expect(TRANSPORT_EVENT.CHAT_STATUS).toBe("chat.status"); + expect(TRANSPORT_EVENT.CHAT_TOOL).toBe("chat.tool"); + expect(TRANSPORT_EVENT.CHAT_APPROVAL).toBe("chat.approval"); }); }); 
// ── TRANSPORT_MSG ────────────────────────────────────────────────────────────── -describe('TRANSPORT_MSG constant', () => { - it('has all expected keys', () => { - const expectedKeys = ['CHAT_SUBSCRIBE', 'CHAT_UNSUBSCRIBE', 'PROVIDER_STATUS', 'LIST_SESSIONS', 'SESSIONS_RESPONSE']; +describe("TRANSPORT_MSG constant", () => { + it("has all expected keys", () => { + const expectedKeys = [ + "CHAT_SUBSCRIBE", + "CHAT_UNSUBSCRIBE", + "APPROVAL_RESPONSE", + "PROVIDER_STATUS", + "LIST_SESSIONS", + "SESSIONS_RESPONSE", + ]; for (const key of expectedKeys) { expect(TRANSPORT_MSG).toHaveProperty(key); } }); - it('has exactly the expected number of keys', () => { - expect(Object.keys(TRANSPORT_MSG)).toHaveLength(5); + it("has exactly the expected number of keys", () => { + expect(Object.keys(TRANSPORT_MSG)).toHaveLength(7); }); - it('has no duplicate values', () => { + it("has no duplicate values", () => { const values = Object.values(TRANSPORT_MSG); const unique = new Set(values); expect(unique.size).toBe(values.length); }); - it('values are correctly mapped', () => { - expect(TRANSPORT_MSG.CHAT_SUBSCRIBE).toBe('chat.subscribe'); - expect(TRANSPORT_MSG.CHAT_UNSUBSCRIBE).toBe('chat.unsubscribe'); - expect(TRANSPORT_MSG.PROVIDER_STATUS).toBe('provider.status'); - expect(TRANSPORT_MSG.LIST_SESSIONS).toBe('provider.list_sessions'); - expect(TRANSPORT_MSG.SESSIONS_RESPONSE).toBe('provider.sessions_response'); + it("values are correctly mapped", () => { + expect(TRANSPORT_MSG.CHAT_SUBSCRIBE).toBe("chat.subscribe"); + expect(TRANSPORT_MSG.CHAT_UNSUBSCRIBE).toBe("chat.unsubscribe"); + expect(TRANSPORT_MSG.APPROVAL_RESPONSE).toBe("chat.approval_response"); + expect(TRANSPORT_MSG.PROVIDER_STATUS).toBe("provider.status"); + expect(TRANSPORT_MSG.LIST_SESSIONS).toBe("provider.list_sessions"); + expect(TRANSPORT_MSG.SESSIONS_RESPONSE).toBe("provider.sessions_response"); }); }); // ── TRANSPORT_RELAY_TYPES ────────────────────────────────────────────────────── 
-describe('TRANSPORT_RELAY_TYPES set', () => { - it('contains all TRANSPORT_EVENT values', () => { +describe("TRANSPORT_RELAY_TYPES set", () => { + it("contains all TRANSPORT_EVENT values", () => { for (const value of Object.values(TRANSPORT_EVENT)) { expect(TRANSPORT_RELAY_TYPES.has(value)).toBe(true); } }); - it('contains PROVIDER_STATUS from TRANSPORT_MSG', () => { + it("contains PROVIDER_STATUS from TRANSPORT_MSG", () => { expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.PROVIDER_STATUS)).toBe(true); }); - it('does not contain CHAT_SUBSCRIBE or CHAT_UNSUBSCRIBE (browser-only control msgs)', () => { + it("contains APPROVAL_RESPONSE from TRANSPORT_MSG", () => { + expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.APPROVAL_RESPONSE)).toBe(true); + }); + + it("does not contain CHAT_SUBSCRIBE or CHAT_UNSUBSCRIBE (browser-only control msgs)", () => { expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.CHAT_SUBSCRIBE)).toBe(false); - expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.CHAT_UNSUBSCRIBE)).toBe(false); + expect(TRANSPORT_RELAY_TYPES.has(TRANSPORT_MSG.CHAT_UNSUBSCRIBE)).toBe( + false, + ); }); - it('contains exactly 7 entries (6 events + PROVIDER_STATUS)', () => { - expect(TRANSPORT_RELAY_TYPES.size).toBe(7); + it("contains exactly 8 entries (6 events + approval response + PROVIDER_STATUS)", () => { + expect(TRANSPORT_RELAY_TYPES.size).toBe(8); }); }); // ── AGENT_MESSAGE_KINDS ──────────────────────────────────────────────────────── -describe('AGENT_MESSAGE_KINDS set', () => { - it('contains all expected kinds', () => { +describe("AGENT_MESSAGE_KINDS set", () => { + it("contains all expected kinds", () => { const expectedKinds: AgentMessageKind[] = [ - 'text', - 'tool_use', - 'tool_result', - 'system', - 'approval', + "text", + "tool_use", + "tool_result", + "system", + "approval", ]; for (const kind of expectedKinds) { expect(AGENT_MESSAGE_KINDS.has(kind)).toBe(true); } }); - it('has exactly 5 entries', () => { + it("has exactly 5 entries", () => { 
expect(AGENT_MESSAGE_KINDS.size).toBe(5); }); - it('has no duplicates (Set invariant holds)', () => { + it("has no duplicates (Set invariant holds)", () => { // A Set by definition cannot contain duplicates; verify via array round-trip const arr = Array.from(AGENT_MESSAGE_KINDS); expect(new Set(arr).size).toBe(arr.length); @@ -130,19 +144,23 @@ describe('AGENT_MESSAGE_KINDS set', () => { // ── AGENT_MESSAGE_STATUSES ───────────────────────────────────────────────────── -describe('AGENT_MESSAGE_STATUSES set', () => { - it('contains all expected statuses', () => { - const expectedStatuses: AgentMessageStatus[] = ['streaming', 'complete', 'error']; +describe("AGENT_MESSAGE_STATUSES set", () => { + it("contains all expected statuses", () => { + const expectedStatuses: AgentMessageStatus[] = [ + "streaming", + "complete", + "error", + ]; for (const status of expectedStatuses) { expect(AGENT_MESSAGE_STATUSES.has(status)).toBe(true); } }); - it('has exactly 3 entries', () => { + it("has exactly 3 entries", () => { expect(AGENT_MESSAGE_STATUSES.size).toBe(3); }); - it('has no duplicate values (Set invariant)', () => { + it("has no duplicate values (Set invariant)", () => { const arr = Array.from(AGENT_MESSAGE_STATUSES); expect(new Set(arr).size).toBe(arr.length); }); @@ -150,26 +168,28 @@ describe('AGENT_MESSAGE_STATUSES set', () => { // ── AGENT_MESSAGE_TERMINAL_STATUSES ─────────────────────────────────────────── -describe('AGENT_MESSAGE_TERMINAL_STATUSES set', () => { - it('contains complete and error', () => { - expect(AGENT_MESSAGE_TERMINAL_STATUSES.has('complete')).toBe(true); - expect(AGENT_MESSAGE_TERMINAL_STATUSES.has('error')).toBe(true); +describe("AGENT_MESSAGE_TERMINAL_STATUSES set", () => { + it("contains complete and error", () => { + expect(AGENT_MESSAGE_TERMINAL_STATUSES.has("complete")).toBe(true); + expect(AGENT_MESSAGE_TERMINAL_STATUSES.has("error")).toBe(true); }); - it('does not contain streaming', () => { - 
expect(AGENT_MESSAGE_TERMINAL_STATUSES.has('streaming')).toBe(false); + it("does not contain streaming", () => { + expect(AGENT_MESSAGE_TERMINAL_STATUSES.has("streaming")).toBe(false); }); - it('is a strict subset of AGENT_MESSAGE_STATUSES', () => { + it("is a strict subset of AGENT_MESSAGE_STATUSES", () => { for (const status of AGENT_MESSAGE_TERMINAL_STATUSES) { expect(AGENT_MESSAGE_STATUSES.has(status)).toBe(true); } - expect(AGENT_MESSAGE_TERMINAL_STATUSES.size).toBeLessThan(AGENT_MESSAGE_STATUSES.size); + expect(AGENT_MESSAGE_TERMINAL_STATUSES.size).toBeLessThan( + AGENT_MESSAGE_STATUSES.size, + ); }); - it('type-level: AgentMessageStatus is assignable to the terminal status union', () => { + it("type-level: AgentMessageStatus is assignable to the terminal status union", () => { // 'complete' and 'error' are valid AgentMessageStatus values - expectTypeOf<'complete'>().toMatchTypeOf(); - expectTypeOf<'error'>().toMatchTypeOf(); + expectTypeOf<"complete">().toMatchTypeOf(); + expectTypeOf<"error">().toMatchTypeOf(); }); }); diff --git a/web/src/components/NewSessionDialog.tsx b/web/src/components/NewSessionDialog.tsx index 3449af7bd..425523980 100644 --- a/web/src/components/NewSessionDialog.tsx +++ b/web/src/components/NewSessionDialog.tsx @@ -1,12 +1,19 @@ -import { useState, useEffect } from 'preact/hooks'; -import { useTranslation } from 'react-i18next'; -import type { WsClient } from '../ws-client.js'; -import { FileBrowser } from './file-browser-lazy.js'; -import { getUserPref, saveUserPref } from '../api.js'; -import { sanitizeProjectName } from '@shared/sanitize-project-name.js'; -import { CLAUDE_SDK_EFFORT_LEVELS, CODEX_SDK_EFFORT_LEVELS, OPENCLAW_THINKING_LEVELS, QWEN_EFFORT_LEVELS, type TransportEffortLevel } from '@shared/effort-levels.js'; +import { useState, useEffect } from "preact/hooks"; +import { useTranslation } from "react-i18next"; +import type { WsClient } from "../ws-client.js"; +import { FileBrowser } from "./file-browser-lazy.js"; 
+import { getUserPref, saveUserPref } from "../api.js"; +import { sanitizeProjectName } from "@shared/sanitize-project-name.js"; +import { + CLAUDE_SDK_EFFORT_LEVELS, + CODEX_SDK_EFFORT_LEVELS, + COPILOT_SDK_EFFORT_LEVELS, + OPENCLAW_THINKING_LEVELS, + QWEN_EFFORT_LEVELS, + type TransportEffortLevel, +} from "@shared/effort-levels.js"; -const DEFAULT_SHELL_KEY = 'default_shell'; +const DEFAULT_SHELL_KEY = "default_shell"; interface Props { ws: WsClient | null; @@ -15,62 +22,96 @@ interface Props { isProviderConnected: (id: string) => boolean; } -type AgentType = 'claude-code' | 'claude-code-sdk' | 'codex' | 'codex-sdk' | 'opencode' | 'gemini' | 'openclaw' | 'qwen'; -type OpenClawMode = 'new' | 'bind'; +type AgentType = + | "claude-code" + | "claude-code-sdk" + | "codex" + | "codex-sdk" + | "copilot-sdk" + | "cursor-headless" + | "opencode" + | "gemini" + | "openclaw" + | "qwen"; +type OpenClawMode = "new" | "bind"; interface RemoteSession { id: string; label: string; } -export function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConnected: _isProviderConnected }: Props) { +export function NewSessionDialog({ + ws, + onClose, + onSessionStarted, + isProviderConnected: _isProviderConnected, +}: Props) { const { t } = useTranslation(); - const [project, setProject] = useState(''); - const [dir, setDir] = useState('~/'); - const [agentType, setAgentType] = useState('claude-code-sdk'); - const [error, setError] = useState(''); + const [project, setProject] = useState(""); + const [dir, setDir] = useState("~/"); + const [agentType, setAgentType] = useState("claude-code-sdk"); + const [error, setError] = useState(""); const [starting, setStarting] = useState(false); const [showDirBrowser, setShowDirBrowser] = useState(false); - const [thinking, setThinking] = useState('high'); + const [thinking, setThinking] = useState("high"); const [shells, setShells] = useState([]); - const [shellBin, setShellBin] = useState(''); + const [shellBin, setShellBin] = 
useState(""); // CC env presets - const [ccPresets, setCcPresets] = useState; contextWindow?: number; initMessage?: string }>>([]); - const [ccPreset, setCcPreset] = useState(''); - const [ccInitPrompt, setCcInitPrompt] = useState(''); + const [ccPresets, setCcPresets] = useState< + Array<{ + name: string; + env: Record; + contextWindow?: number; + initMessage?: string; + }> + >([]); + const [ccPreset, setCcPreset] = useState(""); + const [ccInitPrompt, setCcInitPrompt] = useState(""); const [showPresetEditor, setShowPresetEditor] = useState(false); // New preset form - const [newPresetName, setNewPresetName] = useState(''); - const [newPresetBaseUrl, setNewPresetBaseUrl] = useState(''); - const [newPresetToken, setNewPresetToken] = useState(''); - const [newPresetModel, setNewPresetModel] = useState(''); - const [newPresetCtx, setNewPresetCtx] = useState('1000000'); - const [newPresetCustomEnv, setNewPresetCustomEnv] = useState>([]); - const DEFAULT_INIT_MSG = 'For web searches, use: curl -s "https://html.duckduckgo.com/html/?q=QUERY" | head -200. Replace QUERY with URL-encoded search terms.'; + const [newPresetName, setNewPresetName] = useState(""); + const [newPresetBaseUrl, setNewPresetBaseUrl] = useState(""); + const [newPresetToken, setNewPresetToken] = useState(""); + const [newPresetModel, setNewPresetModel] = useState(""); + const [newPresetCtx, setNewPresetCtx] = useState("1000000"); + const [newPresetCustomEnv, setNewPresetCustomEnv] = useState< + Array<{ key: string; value: string }> + >([]); + const DEFAULT_INIT_MSG = + 'For web searches, use: curl -s "https://html.duckduckgo.com/html/?q=QUERY" | head -200. 
Replace QUERY with URL-encoded search terms.'; const [newPresetInit, setNewPresetInit] = useState(DEFAULT_INIT_MSG); - const fmtCtx = (v: string) => { const n = parseInt(v, 10); if (!n) return ''; if (n >= 1000000) return `${(n/1000000).toFixed(n%1000000===0?0:1)}M`; if (n >= 1000) return `${(n/1000).toFixed(0)}K`; return String(n); }; + const fmtCtx = (v: string) => { + const n = parseInt(v, 10); + if (!n) return ""; + if (n >= 1000000) + return `${(n / 1000000).toFixed(n % 1000000 === 0 ? 0 : 1)}M`; + if (n >= 1000) return `${(n / 1000).toFixed(0)}K`; + return String(n); + }; // OpenClaw-specific state - const [ocMode, setOcMode] = useState('new'); - const [ocSessionKey, setOcSessionKey] = useState(''); - const [ocDescription, setOcDescription] = useState(''); + const [ocMode, setOcMode] = useState("new"); + const [ocSessionKey, setOcSessionKey] = useState(""); + const [ocDescription, setOcDescription] = useState(""); const [ocRemoteSessions, setOcRemoteSessions] = useState([]); const [ocLoadingSessions, setOcLoadingSessions] = useState(false); - const [ocSelectedSession, setOcSelectedSession] = useState(''); + const [ocSelectedSession, setOcSelectedSession] = useState(""); // Load saved shell preference — will be validated against daemon's detected list later const [savedShellPref, setSavedShellPref] = useState(null); useEffect(() => { - void getUserPref(DEFAULT_SHELL_KEY).then((saved) => { - if (typeof saved === 'string' && saved) setSavedShellPref(saved); - }).catch(() => {}); + void getUserPref(DEFAULT_SHELL_KEY) + .then((saved) => { + if (typeof saved === "string" && saved) setSavedShellPref(saved); + }) + .catch(() => {}); }, []); useEffect(() => { if (!ws) return; const unsub = ws.onMessage((msg) => { - if (msg.type === 'subsession.shells') { + if (msg.type === "subsession.shells") { const list = msg.shells as string[]; setShells(list); // Use saved preference only if daemon actually has that shell; otherwise pick first detected @@ -78,38 +119,42 @@ export 
function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConn if (preferred && list.includes(preferred)) { setShellBin(preferred); } else { - setShellBin(list[0] ?? ''); + setShellBin(list[0] ?? ""); } } // Listen for CC presets response - if (msg.type === 'cc.presets.list_response') { + if (msg.type === "cc.presets.list_response") { setCcPresets((msg as any).presets ?? []); } // Listen for openclaw remote session list response const raw = msg as unknown as Record; - if (raw['type'] === 'openclaw.sessions_response') { - const sessions = raw['sessions'] as RemoteSession[] | undefined; + if (raw["type"] === "openclaw.sessions_response") { + const sessions = raw["sessions"] as RemoteSession[] | undefined; setOcRemoteSessions(sessions ?? []); setOcLoadingSessions(false); } }); ws.subSessionDetectShells?.(); - try { ws.send({ type: 'cc.presets.list' }); } catch { /* ws may not support send in test */ } + try { + ws.send({ type: "cc.presets.list" }); + } catch { + /* ws may not support send in test */ + } return unsub; - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [ws]); // Fetch remote sessions when bind mode is selected useEffect(() => { - if (agentType !== 'openclaw' || ocMode !== 'bind' || !ws) return; + if (agentType !== "openclaw" || ocMode !== "bind" || !ws) return; setOcLoadingSessions(true); setOcRemoteSessions([]); - ws.send({ type: 'openclaw.list_sessions' }); + ws.send({ type: "openclaw.list_sessions" }); }, [agentType, ocMode, ws]); // Auto-generate a session key when switching to openclaw new mode useEffect(() => { - if (agentType === 'openclaw' && ocMode === 'new' && !ocSessionKey) { + if (agentType === "openclaw" && ocMode === "new" && !ocSessionKey) { setOcSessionKey(`oc-${Math.random().toString(36).slice(2, 10)}`); } }, [agentType, ocMode, ocSessionKey]); @@ -120,22 +165,25 @@ export function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConn useEffect(() 
=> { if (!ws || !starting) return; const unsub = ws.onMessage((msg) => { - if (msg.type === 'session.event') { - const name = msg.session ?? ''; + if (msg.type === "session.event") { + const name = msg.session ?? ""; const slug = sanitizeProjectName(project); - if (msg.event === 'started' && name.startsWith(`deck_${slug}_`)) { + if (msg.event === "started" && name.startsWith(`deck_${slug}_`)) { unsub(); onSessionStarted(name); onClose(); - } else if (msg.event === 'error' && name.startsWith(`deck_${slug}_`)) { + } else if (msg.event === "error" && name.startsWith(`deck_${slug}_`)) { unsub(); setError(`Session failed to start: ${msg.state}`); setStarting(false); } } - if (msg.type === 'session.error') { + if (msg.type === "session.error") { unsub(); - setError((msg as unknown as { message: string }).message || 'Failed to start session'); + setError( + (msg as unknown as { message: string }).message || + "Failed to start session", + ); setStarting(false); } }); @@ -143,86 +191,146 @@ export function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConn // Timeout after 15s const timeout = setTimeout(() => { unsub(); - setError(t('new_session.timeout')); + setError(t("new_session.timeout")); setStarting(false); }, 15_000); - return () => { unsub(); clearTimeout(timeout); }; + return () => { + unsub(); + clearTimeout(timeout); + }; }, [starting, ws, project]); const handleStart = () => { - if (!project.trim()) { setError(t('new_session.project_required')); return; } - if (!dir.trim()) { setError(t('new_session.dir_required')); return; } - if (!ws) { setError(t('new_session.not_connected')); return; } - if (!ws.connected) { setError(t('new_session.daemon_offline')); return; } + if (!project.trim()) { + setError(t("new_session.project_required")); + return; + } + if (!dir.trim()) { + setError(t("new_session.dir_required")); + return; + } + if (!ws) { + setError(t("new_session.not_connected")); + return; + } + if (!ws.connected) { + 
setError(t("new_session.daemon_offline")); + return; + } - setError(''); + setError(""); setStarting(true); - if (shellBin) void saveUserPref(DEFAULT_SHELL_KEY, shellBin).catch(() => {}); + if (shellBin) + void saveUserPref(DEFAULT_SHELL_KEY, shellBin).catch(() => {}); - if (agentType === 'openclaw') { + if (agentType === "openclaw") { const extra = - ocMode === 'bind' - ? { ocMode: 'bind', ocSessionId: ocSelectedSession } - : { ocMode: 'new', ocSessionKey: ocSessionKey.trim(), ocDescription: ocDescription.trim() }; - ws.sendSessionCommand('start', { project: project.trim(), dir: dir.trim(), agentType, ...extra, thinking }); + ocMode === "bind" + ? { ocMode: "bind", ocSessionId: ocSelectedSession } + : { + ocMode: "new", + ocSessionKey: ocSessionKey.trim(), + ocDescription: ocDescription.trim(), + }; + ws.sendSessionCommand("start", { + project: project.trim(), + dir: dir.trim(), + agentType, + ...extra, + thinking, + }); } else { const extra: Record = {}; - if (ccPreset && (agentType === 'claude-code' || agentType === 'qwen')) extra.ccPreset = ccPreset; - if (ccInitPrompt.trim() && agentType === 'claude-code') extra.ccInitPrompt = ccInitPrompt.trim(); - ws.sendSessionCommand('start', { - project: project.trim(), dir: dir.trim(), agentType, + if (ccPreset && (agentType === "claude-code" || agentType === "qwen")) + extra.ccPreset = ccPreset; + if (ccInitPrompt.trim() && agentType === "claude-code") + extra.ccInitPrompt = ccInitPrompt.trim(); + ws.sendSessionCommand("start", { + project: project.trim(), + dir: dir.trim(), + agentType, ...extra, - ...((agentType === 'claude-code-sdk' || agentType === 'codex-sdk' || agentType === 'qwen') ? { thinking } : {}), + ...(agentType === "claude-code-sdk" || + agentType === "codex-sdk" || + agentType === "copilot-sdk" || + agentType === "qwen" + ? { thinking } + : {}), }); } }; - const agentFlavor = ( - agentType === 'claude-code' - || agentType === 'codex' - ) ? 
'cli' : ( - agentType === 'claude-code-sdk' - || agentType === 'codex-sdk' - ) ? 'sdk' : null; - const thinkingLevels = agentType === 'claude-code-sdk' - ? CLAUDE_SDK_EFFORT_LEVELS - : agentType === 'codex-sdk' - ? CODEX_SDK_EFFORT_LEVELS - : agentType === 'qwen' - ? QWEN_EFFORT_LEVELS - : agentType === 'openclaw' - ? OPENCLAW_THINKING_LEVELS - : []; - const supportsCcPreset = agentType === 'claude-code' || agentType === 'qwen'; + const agentFlavor = + agentType === "claude-code" || agentType === "codex" + ? "cli" + : agentType === "claude-code-sdk" || agentType === "codex-sdk" + ? "sdk" + : null; + const thinkingLevels = + agentType === "claude-code-sdk" + ? CLAUDE_SDK_EFFORT_LEVELS + : agentType === "codex-sdk" + ? CODEX_SDK_EFFORT_LEVELS + : agentType === "copilot-sdk" + ? COPILOT_SDK_EFFORT_LEVELS + : agentType === "qwen" + ? QWEN_EFFORT_LEVELS + : agentType === "openclaw" + ? OPENCLAW_THINKING_LEVELS + : []; + const supportsCcPreset = agentType === "claude-code" || agentType === "qwen"; useEffect(() => { - setThinking('high'); + setThinking("high"); }, [agentType]); const handleKey = (e: KeyboardEvent) => { - if (e.key === 'Escape' && !starting) onClose(); - if (e.key === 'Enter' && !starting) handleStart(); + if (e.key === "Escape" && !starting) onClose(); + if (e.key === "Enter" && !starting) handleStart(); }; return (
{ if (e.target === e.currentTarget && !starting) onClose(); }} + style={{ + position: "fixed", + inset: 0, + background: "#00000080", + display: "flex", + alignItems: "center", + justifyContent: "center", + zIndex: 9999, + }} + onClick={(e) => { + if (e.target === e.currentTarget && !starting) onClose(); + }} onKeyDown={handleKey} role="dialog" > -
-

{t('new_session.title')}

+
+

+ {t("new_session.title")} +

- + { setProject((e.target as HTMLInputElement).value); setError(''); }} + onInput={(e) => { + setProject((e.target as HTMLInputElement).value); + setError(""); + }} autoFocus autoComplete="off" autoCorrect="off" @@ -234,7 +342,7 @@ export function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConn
- +
{ws && ( - + )}
@@ -260,47 +376,95 @@ export function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConn ws={ws} mode="dir-only" layout="modal" - initialPath={dir || '~'} - onConfirm={(paths) => { setDir(paths[0] ?? ''); setShowDirBrowser(false); }} + initialPath={dir || "~"} + onConfirm={(paths) => { + setDir(paths[0] ?? ""); + setShowDirBrowser(false); + }} onClose={() => setShowDirBrowser(false)} /> )}
- + {agentFlavor && ( -
- {agentFlavor === 'cli' ? t('new_session.agent_flavor_cli') : t('new_session.agent_flavor_sdk')} +
+ {agentFlavor === "cli" + ? t("new_session.agent_flavor_cli") + : t("new_session.agent_flavor_sdk")}
)}
{thinkingLevels.length > 0 && (
- +
@@ -310,134 +474,471 @@ export function NewSessionDialog({ ws, onClose, onSessionStarted, isProviderConn {supportsCcPreset && ( <>
-
{/* Inline preset editor */} {showPresetEditor && ( -
-
Add / Edit Preset
-
Stored locally on daemon (~/.imcodes/cc-presets.json)
+
+
+ Add / Edit Preset +
+
+ Stored locally on daemon (~/.imcodes/cc-presets.json) +
{[ - { label: 'Preset Name', envKey: '', ph: 'e.g. MiniMax', val: newPresetName, set: setNewPresetName }, - { label: 'API Base URL', envKey: 'ANTHROPIC_BASE_URL', ph: 'https://api.minimax.io/anthropic', val: newPresetBaseUrl, set: setNewPresetBaseUrl }, - { label: 'API Key', envKey: 'ANTHROPIC_AUTH_TOKEN', ph: 'your-api-key', val: newPresetToken, set: setNewPresetToken, type: 'password' as const }, - { label: 'Model', envKey: 'ANTHROPIC_MODEL', ph: 'e.g. MiniMax-M2.7', val: newPresetModel, set: setNewPresetModel }, + { + label: "Preset Name", + envKey: "", + ph: "e.g. MiniMax", + val: newPresetName, + set: setNewPresetName, + }, + { + label: "API Base URL", + envKey: "ANTHROPIC_BASE_URL", + ph: "https://api.minimax.io/anthropic", + val: newPresetBaseUrl, + set: setNewPresetBaseUrl, + }, + { + label: "API Key", + envKey: "ANTHROPIC_AUTH_TOKEN", + ph: "your-api-key", + val: newPresetToken, + set: setNewPresetToken, + type: "password" as const, + }, + { + label: "Model", + envKey: "ANTHROPIC_MODEL", + ph: "e.g. MiniMax-M2.7", + val: newPresetModel, + set: setNewPresetModel, + }, ].map(({ label, envKey, ph, val, set, type }) => (
-
{label}{envKey && {envKey}}
- + {label} + {envKey && ( + + {envKey} + + )} +
+ set((e.target as HTMLInputElement).value)} - style={{ width: '100%', background: '#1e293b', border: '1px solid #334155', color: '#e2e8f0', padding: '5px 8px', borderRadius: 4, fontSize: 12, boxSizing: 'border-box' }} + style={{ + width: "100%", + background: "#1e293b", + border: "1px solid #334155", + color: "#e2e8f0", + padding: "5px 8px", + borderRadius: 4, + fontSize: 12, + boxSizing: "border-box", + }} />
))}
-
Context Window{newPresetCtx && {fmtCtx(newPresetCtx)}}
- setNewPresetCtx((e.target as HTMLInputElement).value)} - style={{ width: '100%', background: '#1e293b', border: '1px solid #334155', color: '#e2e8f0', padding: '5px 8px', borderRadius: 4, fontSize: 12, boxSizing: 'border-box' }} +
+ Context Window + {newPresetCtx && ( + + {fmtCtx(newPresetCtx)} + + )} +
+ + setNewPresetCtx((e.target as HTMLInputElement).value) + } + style={{ + width: "100%", + background: "#1e293b", + border: "1px solid #334155", + color: "#e2e8f0", + padding: "5px 8px", + borderRadius: 4, + fontSize: 12, + boxSizing: "border-box", + }} />
{/* Custom env vars */}
-
- Custom ENV Vars - +
+ + Custom ENV Vars + +
{newPresetCustomEnv.map((item, i) => ( -
- { const u = [...newPresetCustomEnv]; u[i] = { ...u[i], key: (e.target as HTMLInputElement).value }; setNewPresetCustomEnv(u); }} - style={{ flex: 1, background: '#1e293b', border: '1px solid #334155', color: '#e2e8f0', padding: '4px 6px', borderRadius: 4, fontSize: 11, fontFamily: 'monospace', boxSizing: 'border-box' }} +
+ { + const u = [...newPresetCustomEnv]; + u[i] = { + ...u[i], + key: (e.target as HTMLInputElement).value, + }; + setNewPresetCustomEnv(u); + }} + style={{ + flex: 1, + background: "#1e293b", + border: "1px solid #334155", + color: "#e2e8f0", + padding: "4px 6px", + borderRadius: 4, + fontSize: 11, + fontFamily: "monospace", + boxSizing: "border-box", + }} /> - { const u = [...newPresetCustomEnv]; u[i] = { ...u[i], value: (e.target as HTMLInputElement).value }; setNewPresetCustomEnv(u); }} - style={{ flex: 2, background: '#1e293b', border: '1px solid #334155', color: '#e2e8f0', padding: '4px 6px', borderRadius: 4, fontSize: 11, boxSizing: 'border-box' }} + { + const u = [...newPresetCustomEnv]; + u[i] = { + ...u[i], + value: (e.target as HTMLInputElement).value, + }; + setNewPresetCustomEnv(u); + }} + style={{ + flex: 2, + background: "#1e293b", + border: "1px solid #334155", + color: "#e2e8f0", + padding: "4px 6px", + borderRadius: 4, + fontSize: 11, + boxSizing: "border-box", + }} /> - +
))}
-
Init Message (sent after session starts)
-