From 0c384a37cb79f59a1aa0116d607f191abc52bb59 Mon Sep 17 00:00:00 2001 From: uchouT Date: Tue, 28 Apr 2026 21:15:58 +0800 Subject: [PATCH 1/3] feat(session): AbortSignal support in LLMCompleteOptions and Session.send/stream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threads an optional AbortSignal from Session.send / Session.stream (and MainSession variants) through to LLMAdapter.complete / LLMAdapter.stream. Built-in OpenAI and Anthropic adapters forward the signal to their SDK request options. Aborting cancels the in-flight LLM call with a standard DOMException('aborted', 'AbortError') and skips L3 persistence entirely (both user and assistant records are dropped — atomic with non-stream behavior, no partial messages polluting context). Bumps @stello-ai/session 0.6.0 → 0.7.0 (additive; existing callers unaffected). Closes part A of #60. --- packages/session/package.json | 2 +- packages/session/src/__tests__/abort.test.ts | 157 ++++++++++++++++++ .../src/__tests__/openai-compatible.test.ts | 47 ++++-- packages/session/src/adapters/anthropic.ts | 46 ++--- .../session/src/adapters/openai-compatible.ts | 26 +-- packages/session/src/create-main-session.ts | 17 +- packages/session/src/create-session.ts | 21 ++- packages/session/src/index.ts | 1 + packages/session/src/types/llm.ts | 5 + .../session/src/types/main-session-api.ts | 6 +- packages/session/src/types/session-api.ts | 15 +- 11 files changed, 281 insertions(+), 62 deletions(-) create mode 100644 packages/session/src/__tests__/abort.test.ts diff --git a/packages/session/package.json b/packages/session/package.json index 3550194..86cf75b 100644 --- a/packages/session/package.json +++ b/packages/session/package.json @@ -1,6 +1,6 @@ { "name": "@stello-ai/session", - "version": "0.6.0", + "version": "0.7.0", "description": "Session layer for Stello — conversation topology engine", "license": "Apache-2.0", "author": "Stello Contributors", diff --git a/packages/session/src/__tests__/abort.test.ts b/packages/session/src/__tests__/abort.test.ts new file mode 100644 index 0000000..bf03a53 --- /dev/null +++ b/packages/session/src/__tests__/abort.test.ts @@ -0,0 +1,157 @@ +import { describe, it, expect } from 'vitest' +import { makeSession } from './helpers.js' +import type { LLMAdapter, LLMChunk, LLMCompleteOptions, LLMResult, Message } from '../types/llm.js' + +/** 让 fetch-style adapter 监听 signal 的最小 LLMAdapter */ +function createSignalAwareLLM(behavior: { + /** complete() 等多久 resolve(毫秒),默认 50ms */ + delayMs?: number + result?: LLMResult + chunks?: LLMChunk[] + /** chunk 之间的间隔(毫秒),默认 20ms */ + streamGapMs?: number +} = {}): LLMAdapter & { calls: { signal?: AbortSignal }[] } { + const calls: { signal?: AbortSignal }[] = [] + const result = behavior.result ?? { content: 'ok' } + const chunks = behavior.chunks ?? [{ delta: 'partial' }] + const delayMs = behavior.delayMs ?? 50 + const streamGapMs = behavior.streamGapMs ?? 20 + + return { + calls, + maxContextTokens: 1_000_000, + async complete(_messages: Message[], options?: LLMCompleteOptions): Promise { + calls.push({ signal: options?.signal }) + await new Promise((resolve, reject) => { + if (options?.signal?.aborted) { + reject(new DOMException('aborted', 'AbortError')) + return + } + const timer = setTimeout(() => { + options?.signal?.removeEventListener('abort', onAbort) + resolve() + }, delayMs) + const onAbort = () => { + clearTimeout(timer) + reject(new DOMException('aborted', 'AbortError')) + } + options?.signal?.addEventListener('abort', onAbort, { once: true }) + }) + return result + }, + async *stream(_messages: Message[], options?: LLMCompleteOptions): AsyncIterable { + calls.push({ signal: options?.signal }) + for (const chunk of chunks) { + if (options?.signal?.aborted) { + throw new DOMException('aborted', 'AbortError') + } + await new Promise((resolve, reject) => { + const timer = setTimeout(() => { + options?.signal?.removeEventListener('abort', onAbort) + resolve() + }, streamGapMs) + const onAbort = () => { + clearTimeout(timer) + reject(new DOMException('aborted', 'AbortError')) + } + options?.signal?.addEventListener('abort', onAbort, { once: true }) + }) + yield chunk + } + }, + } +} + +describe('Session.send() AbortSignal', () => { + it('signal abort 触发后 send() reject 为 AbortError,且不写入 L3', async () => { + const llm = createSignalAwareLLM({ delayMs: 100 }) + const { session } = await makeSession({ llm }) + + const controller = new AbortController() + const promise = session.send('hello', { signal: controller.signal }) + setTimeout(() => controller.abort(), 10) + + await expect(promise).rejects.toMatchObject({ name: 'AbortError' }) + + const messages = await session.messages() + expect(messages).toEqual([]) + }) + + it('已 abort 的 signal 立即抛出,不调用 LLM', async () => { + const llm = createSignalAwareLLM() + const { session } = await makeSession({ llm }) + const controller = new AbortController() + controller.abort() + + await expect(session.send('hello', { signal: controller.signal })).rejects.toMatchObject({ + name: 'AbortError', + }) + expect(llm.calls).toHaveLength(0) + }) + + it('signal 被透传到 LLMAdapter.complete', async () => { + const llm = createSignalAwareLLM() + const { session } = await makeSession({ llm }) + + const controller = new AbortController() + await session.send('hello', { signal: controller.signal }) + + expect(llm.calls[0]!.signal).toBe(controller.signal) + }) +}) + +describe('Session.stream() AbortSignal', () => { + it('流式中段 abort 后迭代器停止,result reject 为 AbortError,L3 不写', async () => { + const llm = createSignalAwareLLM({ + chunks: [{ delta: 'a' }, { delta: 'b' }, { delta: 'c' }], + streamGapMs: 30, + }) + const { session } = await makeSession({ llm }) + + const controller = new AbortController() + const stream = session.stream('hello', { signal: controller.signal }) + + const collected: string[] = [] + const iteratorPromise = (async () => { + for await (const chunk of stream) { + collected.push(chunk) + if (collected.length === 1) { + controller.abort() + } + } + })() + + await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' }) + // 等待 iterator 完成(abort 后会停止) + await iteratorPromise.catch(() => {}) + + const messages = await session.messages() + expect(messages).toEqual([]) + }) + + it('stream() 已 abort 的 signal 立即让 result reject', async () => { + const llm = createSignalAwareLLM() + const { session } = await makeSession({ llm }) + const controller = new AbortController() + controller.abort() + + const stream = session.stream('hello', { signal: controller.signal }) + await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' }) + }) + + it('signal 被透传到 LLMAdapter.stream', async () => { + const llm = createSignalAwareLLM({ + chunks: [{ delta: 'x' }], + }) + const { session } = await makeSession({ llm }) + + const controller = new AbortController() + const stream = session.stream('hello', { signal: controller.signal }) + for await (const _ of stream) { + // drain + } + await stream.result + + expect(llm.calls[0]!.signal).toBe(controller.signal) + }) +}) diff --git a/packages/session/src/__tests__/openai-compatible.test.ts b/packages/session/src/__tests__/openai-compatible.test.ts index 718db30..fd87b30 100644 --- a/packages/session/src/__tests__/openai-compatible.test.ts +++ b/packages/session/src/__tests__/openai-compatible.test.ts @@ -40,14 +40,17 @@ describe('createOpenAICompatibleAdapter', () => { await adapter.complete(messages) expect(createCompletion).toHaveBeenCalledTimes(1) - expect(createCompletion).toHaveBeenCalledWith(expect.objectContaining({ - messages: [ - { role: 'system', content: 'system prompt\n\nsynthesis' }, - { role: 'user', content: 'hello' }, - ], - max_tokens: 4096, - stream: false, - })) + expect(createCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + messages: [ + { role: 'system', content: 'system prompt\n\nsynthesis' }, + { role: 'user', content: 'hello' }, + ], + max_tokens: 4096, + stream: false, + }), + undefined, + ) }) it('显式传入 maxTokens 时优先使用调用方配置', async () => { @@ -60,9 +63,29 @@ describe('createOpenAICompatibleAdapter', () => { await adapter.complete([{ role: 'user', content: 'hello' }], { maxTokens: 2048 }) - expect(createCompletion).toHaveBeenCalledWith(expect.objectContaining({ - max_tokens: 2048, - stream: false, - })) + expect(createCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + max_tokens: 2048, + stream: false, + }), + undefined, + ) + }) + + it('signal 透传到 SDK request options', async () => { + const adapter = createOpenAICompatibleAdapter({ + apiKey: 'test-key', + baseURL: 'https://api.example.com/v1', + model: 'test-model', + maxContextTokens: 128_000, + }) + + const controller = new AbortController() + await adapter.complete([{ role: 'user', content: 'hello' }], { signal: controller.signal }) + + expect(createCompletion).toHaveBeenCalledWith( + expect.objectContaining({ stream: false }), + { signal: controller.signal }, + ) }) }) diff --git a/packages/session/src/adapters/anthropic.ts b/packages/session/src/adapters/anthropic.ts index 05925e7..78af07f 100644 --- a/packages/session/src/adapters/anthropic.ts +++ b/packages/session/src/adapters/anthropic.ts @@ -135,16 +135,19 @@ export function createAnthropicAdapter(options: AnthropicAdapterOptions): LLMAda ? systemMessages.map((m) => m.content).join('\n\n') : undefined - const response = await client.messages.create({ - model: options.model, - max_tokens: completeOptions?.maxTokens ?? 4096, - ...(completeOptions?.temperature !== undefined && { temperature: completeOptions.temperature }), - ...(system && { system }), - ...(completeOptions?.tools && completeOptions.tools.length > 0 - ? { tools: toAnthropicTools(completeOptions.tools) } - : {}), - messages: toAnthropicMessages(nonSystemMessages), - }) + const response = await client.messages.create( + { + model: options.model, + max_tokens: completeOptions?.maxTokens ?? 4096, + ...(completeOptions?.temperature !== undefined && { temperature: completeOptions.temperature }), + ...(system && { system }), + ...(completeOptions?.tools && completeOptions.tools.length > 0 + ? { tools: toAnthropicTools(completeOptions.tools) } + : {}), + messages: toAnthropicMessages(nonSystemMessages), + }, + completeOptions?.signal ? { signal: completeOptions.signal } : undefined, + ) const toolCalls = extractToolCalls(response.content) @@ -166,16 +169,19 @@ export function createAnthropicAdapter(options: AnthropicAdapterOptions): LLMAda ? systemMessages.map((m) => m.content).join('\n\n') : undefined - const stream = client.messages.stream({ - model: options.model, - max_tokens: completeOptions?.maxTokens ?? 4096, - ...(completeOptions?.temperature !== undefined && { temperature: completeOptions.temperature }), - ...(system && { system }), - ...(completeOptions?.tools && completeOptions.tools.length > 0 - ? { tools: toAnthropicTools(completeOptions.tools) } - : {}), - messages: toAnthropicMessages(nonSystemMessages), - }) + const stream = client.messages.stream( + { + model: options.model, + max_tokens: completeOptions?.maxTokens ?? 4096, + ...(completeOptions?.temperature !== undefined && { temperature: completeOptions.temperature }), + ...(system && { system }), + ...(completeOptions?.tools && completeOptions.tools.length > 0 + ? { tools: toAnthropicTools(completeOptions.tools) } + : {}), + messages: toAnthropicMessages(nonSystemMessages), + }, + completeOptions?.signal ? { signal: completeOptions.signal } : undefined, + ) for await (const event of stream) { if (event.type === 'content_block_delta') { diff --git a/packages/session/src/adapters/openai-compatible.ts b/packages/session/src/adapters/openai-compatible.ts index a4deab2..5a51f73 100644 --- a/packages/session/src/adapters/openai-compatible.ts +++ b/packages/session/src/adapters/openai-compatible.ts @@ -83,11 +83,14 @@ export function createOpenAICompatibleAdapter(options: OpenAICompatibleOptions): return { maxContextTokens: options.maxContextTokens, async complete(messages: Message[], completeOptions?: LLMCompleteOptions): Promise { - const response = await client.chat.completions.create({ - ...buildParams(messages, completeOptions), - ...(options.extraBody ?? {}), - stream: false, - } as Parameters[0]) as ChatCompletion + const response = await client.chat.completions.create( + { + ...buildParams(messages, completeOptions), + ...(options.extraBody ?? {}), + stream: false, + } as Parameters[0], + completeOptions?.signal ? { signal: completeOptions.signal } : undefined, + ) as ChatCompletion const choice = response.choices[0] @@ -110,11 +113,14 @@ export function createOpenAICompatibleAdapter(options: OpenAICompatibleOptions): } }, async *stream(messages: Message[], completeOptions?: LLMCompleteOptions) { - const stream = await client.chat.completions.create({ - ...buildParams(messages, completeOptions), - ...(options.extraBody ?? {}), - stream: true, - } as Parameters[0]) as Stream + const stream = await client.chat.completions.create( + { + ...buildParams(messages, completeOptions), + ...(options.extraBody ?? {}), + stream: true, + } as Parameters[0], + completeOptions?.signal ? { signal: completeOptions.signal } : undefined, + ) as Stream for await (const chunk of stream) { const delta = chunk.choices[0]?.delta?.content ?? '' diff --git a/packages/session/src/create-main-session.ts b/packages/session/src/create-main-session.ts index 91dedfc..10c058c 100644 --- a/packages/session/src/create-main-session.ts +++ b/packages/session/src/create-main-session.ts @@ -1,6 +1,6 @@ import { randomUUID } from 'node:crypto' import type { MainSession } from './types/main-session-api.js' -import type { Session, MessageQueryOptions } from './types/session-api.js' +import type { Session, MessageQueryOptions, SessionSendOptions } from './types/session-api.js' import { SessionArchivedError } from './types/session-api.js' import type { SessionMeta, SessionMetaUpdate, ForkOptions } from './types/session.js' import type { Message } from './types/llm.js' @@ -138,13 +138,14 @@ function buildMainSession( return currentMeta }, - async send(content: string): Promise { + async send(content: string, sendOptions?: SessionSendOptions): Promise { if (currentMeta.status === 'archived') { throw new SessionArchivedError(currentMeta.id) } if (!options.llm) { throw new Error('LLMAdapter is required for send()') } + sendOptions?.signal?.throwIfAborted() // 组装上下文(自动压缩) const assembled = await assembleMainSessionContext( @@ -172,8 +173,8 @@ function buildMainSession( recordsToPersist = promptMessages.slice(replayContext.length) } - // 调 LLM - const result = await options.llm.complete(promptMessages, { tools }) + // 调 LLM — abort 时直接向上传播;下方 L3 写入分支整体跳过 + const result = await options.llm.complete(promptMessages, { tools, signal: sendOptions?.signal }) // 更新 promptTokens 基线 if (result.usage?.promptTokens) { @@ -197,7 +198,7 @@ function buildMainSession( } }, - stream(content: string): StreamResult { + stream(content: string, sendOptions?: SessionSendOptions): StreamResult { if (currentMeta.status === 'archived') { throw new SessionArchivedError(currentMeta.id) } @@ -206,6 +207,8 @@ function buildMainSession( } return createStreamResult(async (push) => { + sendOptions?.signal?.throwIfAborted() + // 组装上下文(自动压缩) const assembled = await assembleMainSessionContext( currentMeta.id, storage, content, @@ -240,7 +243,7 @@ function buildMainSession( if (options.llm.stream) { let accumulated = '' const toolCallsByIndex = new Map() - for await (const chunk of options.llm.stream(promptMessages, { tools })) { + for await (const chunk of options.llm.stream(promptMessages, { tools, signal: sendOptions?.signal })) { accumulated += chunk.delta push(chunk.delta) for (const delta of chunk.toolCallDeltas ?? []) { @@ -258,7 +261,7 @@ function buildMainSession( })) result = { content: accumulated, toolCalls } } else { - result = await options.llm.complete(promptMessages, { tools }) + result = await options.llm.complete(promptMessages, { tools, signal: sendOptions?.signal }) if (result.content) { push(result.content) } diff --git a/packages/session/src/create-session.ts b/packages/session/src/create-session.ts index 30706ce..4611b10 100644 --- a/packages/session/src/create-session.ts +++ b/packages/session/src/create-session.ts @@ -1,5 +1,5 @@ import { randomUUID } from 'node:crypto' -import type { Session, MessageQueryOptions } from './types/session-api.js' +import type { Session, MessageQueryOptions, SessionSendOptions } from './types/session-api.js' import { SessionArchivedError } from './types/session-api.js' import type { SessionMeta, SessionMetaUpdate, ForkOptions } from './types/session.js' import type { Message } from './types/llm.js' @@ -184,13 +184,15 @@ function buildSession( return currentMeta }, - async send(content: string): Promise { + async send(content: string, sendOptions?: SessionSendOptions): Promise { if (currentMeta.status === 'archived') { throw new SessionArchivedError(currentMeta.id) } if (!options.llm) { throw new Error('LLMAdapter is required for send()') } + // pre-flight:已 abort 的 signal 立即抛出,不发起任何 LLM 请求 + sendOptions?.signal?.throwIfAborted() // 组装上下文(自动压缩) const assembled = await assembleSessionContext( @@ -227,8 +229,8 @@ function buildSession( } } - // 调 LLM - const result = await options.llm.complete(promptMessages, { tools }) + // 调 LLM — adapter 抛 AbortError 时直接向上传播,下方 L3 写入分支整体跳过 + const result = await options.llm.complete(promptMessages, { tools, signal: sendOptions?.signal }) // 更新 promptTokens 基线 if (result.usage?.promptTokens) { @@ -252,7 +254,7 @@ function buildSession( } }, - stream(content: string): StreamResult { + stream(content: string, sendOptions?: SessionSendOptions): StreamResult { if (currentMeta.status === 'archived') { throw new SessionArchivedError(currentMeta.id) } @@ -261,6 +263,9 @@ function buildSession( } return createStreamResult(async (push) => { + // pre-flight:已 abort 的 signal 立即让 result reject,processor 不进入下游 + sendOptions?.signal?.throwIfAborted() + // 组装上下文(自动压缩) const assembled = await assembleSessionContext( currentMeta.id, storage, content, @@ -304,7 +309,9 @@ function buildSession( if (options.llm.stream) { let accumulated = '' const toolCallsByIndex = new Map() - for await (const chunk of options.llm.stream(promptMessages, { tools })) { + // adapter 在 abort 时抛 AbortError,这里直接向上传播给 result promise; + // 下方 L3 写入分支不会执行(policy: drop entirely),与非流式 send() 对称。 + for await (const chunk of options.llm.stream(promptMessages, { tools, signal: sendOptions?.signal })) { accumulated += chunk.delta push(chunk.delta) for (const delta of chunk.toolCallDeltas ?? []) { @@ -322,7 +329,7 @@ function buildSession( })) result = { content: accumulated, toolCalls } } else { - result = await options.llm.complete(promptMessages, { tools }) + result = await options.llm.complete(promptMessages, { tools, signal: sendOptions?.signal }) if (result.content) { push(result.content) } diff --git a/packages/session/src/index.ts b/packages/session/src/index.ts index d8040d4..8d36204 100644 --- a/packages/session/src/index.ts +++ b/packages/session/src/index.ts @@ -7,6 +7,7 @@ export type { export type { Session, MessageQueryOptions, + SessionSendOptions, } from './types/session-api.js' export { SessionArchivedError, diff --git a/packages/session/src/types/llm.ts b/packages/session/src/types/llm.ts index e1130d2..a2d9f87 100644 --- a/packages/session/src/types/llm.ts +++ b/packages/session/src/types/llm.ts @@ -25,6 +25,11 @@ export interface LLMCompleteOptions { temperature?: number /** 可用工具列表的 schema(JSON Schema 格式) */ tools?: Array<{ name: string; description: string; inputSchema: Record }> + /** + * AbortSignal — adapter 应在 abort 时中断 LLM 调用并以 AbortError reject。 + * 不支持取消的 adapter 可忽略此字段(best-effort 语义)。 + */ + signal?: AbortSignal } /** LLM 完成后的返回结果 */ diff --git a/packages/session/src/types/main-session-api.ts b/packages/session/src/types/main-session-api.ts index 8497ed9..2c7ef38 100644 --- a/packages/session/src/types/main-session-api.ts +++ b/packages/session/src/types/main-session-api.ts @@ -1,7 +1,7 @@ import type { SessionMeta, SessionMetaUpdate, ForkOptions } from './session.js' import type { Message, LLMAdapter, LLMCompleteOptions } from './llm.js' import type { SendResult, StreamResult, IntegrateResult } from './functions.js' -import type { MessageQueryOptions, Session } from './session-api.js' +import type { MessageQueryOptions, Session, SessionSendOptions } from './session-api.js' /** * MainSession — 全局意识层对话单元 @@ -16,10 +16,10 @@ export interface MainSession { readonly meta: Readonly /** 发送消息:组装上下文(system prompt + synthesis + L3 + msg)→ 调 LLM → 存 L3 */ - send(content: string): Promise + send(content: string, options?: SessionSendOptions): Promise /** 流式发送:同 send 但逐 chunk 输出 */ - stream(content: string): StreamResult + stream(content: string, options?: SessionSendOptions): StreamResult /** 读取 L3 对话记录 */ messages(options?: MessageQueryOptions): Promise diff --git a/packages/session/src/types/session-api.ts b/packages/session/src/types/session-api.ts index 44edf90..0b5c5e9 100644 --- a/packages/session/src/types/session-api.ts +++ b/packages/session/src/types/session-api.ts @@ -10,6 +10,17 @@ export interface MessageQueryOptions { role?: Message['role'] } +/** + * Session.send / Session.stream 的运行时选项 + * + * 通过 signal 取消正在进行的 LLM 调用:abort 后 send() reject 为 AbortError, + * stream() 的 result 同样 reject。被取消的调用不写入 L3(user msg 也不持久化)。 + */ +export interface SessionSendOptions { + /** AbortSignal — abort 后中断 LLM 调用并 reject 为 AbortError */ + signal?: AbortSignal +} + /** Session 错误:操作归档中的 Session */ export class SessionArchivedError extends Error { constructor(sessionId: string) { @@ -35,10 +46,10 @@ export interface Session { readonly meta: Readonly /** 发送一条消息:组装上下文 → 调 LLM → 存 L3(用户消息 + LLM 响应)→ 返回结果 */ - send(content: string): Promise + send(content: string, options?: SessionSendOptions): Promise /** 流式发送:同 send() 但逐 chunk 输出,流结束后自动存 L3 */ - stream(content: string): StreamResult + stream(content: string, options?: SessionSendOptions): StreamResult /** 读取 L3 对话记录 */ messages(options?: MessageQueryOptions): Promise From 2cae3aca1b51428acc17e6cfede35b6d1756468f Mon Sep 17 00:00:00 2001 From: uchouT Date: Tue, 28 Apr 2026 21:16:14 +0800 Subject: [PATCH 2/3] feat(core): AbortSignal support in TurnRunner and engine session/tool runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds signal threading across the entire turn lifecycle: - TurnRunnerOptions.signal — tool loop checks throwIfAborted() at every round boundary, before each tool call, and after each tool result is collected (suppresses phantom onToolResult after cancel). - TurnRunnerSession.send/stream and EngineRuntimeSession.send/stream accept { signal? }; runStream short-circuits when the signal is already aborted, otherwise forwards to session.stream and re-throws AbortError from the iterator instead of silently closing. - TurnRunnerToolExecutor.executeTool gains an options object carrying the signal; ToolExecutionContext.signal lets individual tools opt-in to honoring cancellation for long-running work (tools that ignore it still work — the runner aborts at the next round boundary). - SessionCompatible and adaptSessionToEngineRuntime forward the signal to the underlying @stello-ai/session contract. - Engine.turn / Engine.stream / Orchestrator / StelloAgent passthrough preserves signal via TurnRunnerOptions spread; no field stripping. Bumps @stello-ai/core 0.7.2 → 0.8.0 (additive). Closes part B of #60. --- packages/core/package.json | 2 +- .../__tests__/session-runtime.test.ts | 44 ++++- packages/core/src/adapters/session-runtime.ts | 22 ++- .../src/agent/__tests__/stello-agent.test.ts | 60 +++++- .../__tests__/turn-runner-abort.test.ts | 171 ++++++++++++++++++ .../src/engine/__tests__/turn-runner.test.ts | 8 +- packages/core/src/engine/stello-engine.ts | 15 +- packages/core/src/engine/turn-runner.ts | 68 ++++++- .../__tests__/default-engine-factory.test.ts | 2 +- packages/core/src/types/tool.ts | 6 + 10 files changed, 374 insertions(+), 24 deletions(-) create mode 100644 packages/core/src/engine/__tests__/turn-runner-abort.test.ts diff --git a/packages/core/package.json b/packages/core/package.json index 2586456..92be153 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@stello-ai/core", - "version": "0.7.2", + "version": "0.8.0", "description": "The first open-source conversation topology engine", "license": "Apache-2.0", "author": "Stello Contributors", diff --git a/packages/core/src/adapters/__tests__/session-runtime.test.ts b/packages/core/src/adapters/__tests__/session-runtime.test.ts index 97bc9d4..d209da5 100644 --- a/packages/core/src/adapters/__tests__/session-runtime.test.ts +++ b/packages/core/src/adapters/__tests__/session-runtime.test.ts @@ -49,7 +49,7 @@ describe('session-runtime adapters', () => { const raw = await runtime.send('hello'); const parsed = sessionSendResultParser.parse(raw); - expect(session.send).toHaveBeenCalledWith('hello'); + expect(session.send).toHaveBeenCalledWith('hello', undefined); expect(runtime.meta.turnCount).toBe(3); expect(parsed.toolCalls[0]).toEqual({ id: 't1', @@ -146,6 +146,48 @@ describe('session-runtime adapters', () => { }); }); + it('adapter forwards signal to underlying SessionCompatible.send', async () => { + const session = { + meta: { id: 's1', status: 'active' as const }, + send: vi.fn().mockResolvedValue({ content: 'ok', toolCalls: [] }), + messages: vi.fn().mockResolvedValue([]), + consolidate: vi.fn(), + setTools: vi.fn(), + }; + const runtime = await adaptSessionToEngineRuntime(session, {}); + + const controller = new AbortController(); + await runtime.send('hi', { signal: controller.signal }); + expect(session.send).toHaveBeenCalledWith('hi', { signal: controller.signal }); + }); + + it('adapter forwards signal to underlying SessionCompatible.stream', async () => { + const streamSource = { + result: Promise.resolve({ content: 'ok', toolCalls: [] }), + async *[Symbol.asyncIterator]() { + yield 'a'; + }, + }; + const session = { + meta: { id: 's1', status: 'active' as const }, + send: vi.fn(), + stream: vi.fn(() => streamSource), + messages: vi.fn().mockResolvedValue([]), + consolidate: vi.fn(), + setTools: vi.fn(), + }; + const runtime = await adaptSessionToEngineRuntime(session, {}); + + const controller = new AbortController(); + const stream = runtime.stream!('hi', { signal: controller.signal }); + for await (const _ of stream) { + // drain + } + await stream.result; + + expect(session.stream).toHaveBeenCalledWith('hi', { signal: controller.signal }); + }); + it('adapter exposes tools getter and forwards setTools to underlying Session', async () => { const sessionTools: Array<{ name: string; description: string; inputSchema: object }> = [{ name: 'a', description: 'd', inputSchema: {} }]; const setToolsSpy = vi.fn((t) => { diff --git a/packages/core/src/adapters/session-runtime.ts b/packages/core/src/adapters/session-runtime.ts index a6cb877..33a34db 100644 --- a/packages/core/src/adapters/session-runtime.ts +++ b/packages/core/src/adapters/session-runtime.ts @@ -59,15 +59,25 @@ export interface SessionCompatibleForkOptions { compressFn?: SessionCompatibleCompressFn; } +/** Session.send / Session.stream 的可选运行时参数(结构兼容 @stello-ai/session) */ +export interface SessionCompatibleSendOptions { + /** AbortSignal — abort 时底层 LLM 调用应被取消 */ + signal?: AbortSignal; +} + /** 结构兼容 @stello-ai/session 的 Session */ export interface SessionCompatible { meta: { id: string; status: 'active' | 'archived'; }; - send(content: string): Promise; + send( + content: string, + options?: SessionCompatibleSendOptions, + ): Promise; stream?( - content: string + content: string, + options?: SessionCompatibleSendOptions, ): AsyncIterable & { result: Promise }; messages(): Promise>; consolidate(): Promise; @@ -159,8 +169,8 @@ export async function adaptSessionToEngineRuntime( get turnCount() { return turnCount; }, - async send(input: string): Promise { - const result = await session.send(input); + async send(input: string, sendOptions?: SessionCompatibleSendOptions): Promise { + const result = await session.send(input, sendOptions); turnCount += 1; return (options.serializeResult ?? serializeSessionSendResult)(result); }, @@ -175,8 +185,8 @@ export async function adaptSessionToEngineRuntime( }, ...(session.stream ? { - stream(input: string) { - const source = session.stream!(input); + stream(input: string, sendOptions?: SessionCompatibleSendOptions) { + const source = session.stream!(input, sendOptions); return { result: (async () => { const result = await source.result; diff --git a/packages/core/src/agent/__tests__/stello-agent.test.ts b/packages/core/src/agent/__tests__/stello-agent.test.ts index b9315c8..0474a99 100644 --- a/packages/core/src/agent/__tests__/stello-agent.test.ts +++ b/packages/core/src/agent/__tests__/stello-agent.test.ts @@ -84,7 +84,7 @@ describe('StelloAgent', () => { const result = await agent.turn('root', 'hello'); expect(agent.sessions).toBeDefined(); - expect(runtimeSession.send).toHaveBeenCalledWith('hello'); + expect(runtimeSession.send).toHaveBeenCalledWith('hello', { signal: undefined }); expect(result.turn.finalContent).toContain('"content":"done"'); }); @@ -118,6 +118,62 @@ describe('StelloAgent', () => { expect(result.turn.finalContent).toContain('"content":"done"') }); + it('agent.stream(input, { signal }) 透传到 runtime session 并在 abort 时让 result reject', async () => { + const controller = new AbortController() + const runtimeSession = { + id: 'root', + meta: { id: 'root', turnCount: 0, status: 'active' as const }, + turnCount: 0, + send: vi.fn(), + stream: vi.fn((_input: string, opts?: { signal?: AbortSignal }) => { + let rejectResult: (err: unknown) => void = () => {} + const result = new Promise((_resolve, reject) => { rejectResult = reject }) + result.catch(() => {}) + return { + result, + async *[Symbol.asyncIterator]() { + try { + for (const chunk of ['a', 'b', 'c']) { + if (opts?.signal?.aborted) { + const err = new DOMException('aborted', 'AbortError') + rejectResult(err) + throw err + } + await new Promise((r) => setTimeout(r, 5)) + yield chunk + } + } catch (err) { + rejectResult(err) + throw err + } + }, + } + }), + consolidate: vi.fn(), + setTools: vi.fn(), + } + + const agent = createStelloAgent(baseConfig({ runtimeSession })) + const stream = await agent.stream('root', 'hello', { signal: controller.signal }) + + const collected: string[] = [] + const iter = (async () => { + try { + for await (const chunk of stream) { + collected.push(chunk) + if (collected.length === 1) controller.abort() + } + } catch { + // expected: iterator re-throws AbortError + } + })() + + await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' }) + await iter + + expect(runtimeSession.stream).toHaveBeenCalledWith('hello', { signal: controller.signal }) + }) + it('默认树形拓扑:子节点 fork 出的新节点挂在自己下面', async () => { const childSession = { ...rootSession, @@ -400,7 +456,7 @@ describe('StelloAgent', () => { const result = await agent.turn('root', 'hello'); - expect(session.send).toHaveBeenCalledWith('hello'); + expect(session.send).toHaveBeenCalledWith('hello', { signal: undefined }); expect(result.turn.rawResponse).toContain('"content":"done"'); expect(result.turn.toolCallsExecuted).toBe(1); }); diff --git a/packages/core/src/engine/__tests__/turn-runner-abort.test.ts b/packages/core/src/engine/__tests__/turn-runner-abort.test.ts new file mode 100644 index 0000000..7d0b7fc --- /dev/null +++ b/packages/core/src/engine/__tests__/turn-runner-abort.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, it, vi } from 'vitest' +import { TurnRunner, type ToolCallParser } from '../turn-runner' + +const parser: ToolCallParser = { + parse(raw) { + return JSON.parse(raw) as { content: string | null; toolCalls: Array<{ id?: string; name: string; args: Record }> } + }, +} + +describe('TurnRunner.run AbortSignal', () => { + it('signal abort 在轮间生效,下一轮 send 不再发起', async () => { + const controller = new AbortController() + const session = { + id: 's1', + send: vi + .fn<(input: string, options?: { signal?: AbortSignal }) => Promise>() + .mockResolvedValueOnce( + JSON.stringify({ + content: null, + toolCalls: [{ id: '1', name: 'read', args: { path: 'a' } }], + }), + ), + } + const tools = { + executeTool: vi.fn().mockImplementation(async () => { + controller.abort() + return { success: true, data: 'ok' } + }), + } + const onToolResult = vi.fn() + + const runner = new TurnRunner(parser) + await expect( + runner.run(session, 'hello', tools, { + signal: controller.signal, + onToolResult, + }), + ).rejects.toMatchObject({ name: 'AbortError' }) + + // 第二轮 session.send 不应被调用(signal 在 round 边界检查) + expect(session.send).toHaveBeenCalledTimes(1) + // tool 执行后立刻 abort,onToolResult 不应触发(避免 phantom result) + expect(onToolResult).not.toHaveBeenCalled() + }) + + it('已 abort 的 signal 立即拒绝,不调用 session.send', async () => { + const controller = new AbortController() + controller.abort() + const session = { id: 's1', send: vi.fn() } + const tools = { executeTool: vi.fn() } + + const runner = new TurnRunner(parser) + await expect( + runner.run(session, 'hello', tools, { signal: controller.signal }), + ).rejects.toMatchObject({ name: 'AbortError' }) + + expect(session.send).not.toHaveBeenCalled() + }) + + it('signal 透传到 session.send 与 tools.executeTool', async () => { + const controller = new AbortController() + const session = { + id: 's1', + send: vi + .fn<(input: string, options?: { signal?: AbortSignal }) => Promise>() + .mockResolvedValueOnce( + JSON.stringify({ + content: null, + toolCalls: [{ id: '1', name: 'read', args: {} }], + }), + ) + .mockResolvedValueOnce(JSON.stringify({ content: 'done', toolCalls: [] })), + } + const tools = { + executeTool: vi + .fn<(name: string, args: Record, id?: string, options?: { signal?: AbortSignal }) => Promise<{ success: boolean; data?: unknown }>>() + .mockResolvedValue({ success: true, data: 'x' }), + } + + const runner = new TurnRunner(parser) + await runner.run(session, 'hi', tools, { signal: controller.signal }) + + // session.send 第一参数是 input;第二参数应携带 signal + expect(session.send).toHaveBeenCalledWith('hi', expect.objectContaining({ signal: controller.signal })) + // tools.executeTool 第四参数应携带 signal + expect(tools.executeTool).toHaveBeenCalledWith( + 'read', + {}, + '1', + expect.objectContaining({ signal: controller.signal }), + ) + }) +}) + +describe('TurnRunner.runStream AbortSignal', () => { + it('运行中 abort 后 result reject 为 AbortError,且后续不再 send', async () => { + const controller = new AbortController() + // 模拟 session.stream 的真实行为:iterator 抛 AbortError,result 也 reject。 + function makeMockStream(chunks: string[]) { + let resolveResult: (raw: string) => void = () => {} + let rejectResult: (err: unknown) => void = () => {} + const result = new Promise((resolve, reject) => { + resolveResult = resolve + rejectResult = reject + }) + result.catch(() => {}) + return { + result, + async *[Symbol.asyncIterator]() { + try { + for (const chunk of chunks) { + if (controller.signal.aborted) { + const err = new DOMException('aborted', 'AbortError') + rejectResult(err) + throw err + } + await new Promise((r) => setTimeout(r, 5)) + yield chunk + } + resolveResult(JSON.stringify({ content: chunks.join(''), toolCalls: [] })) + } catch (err) { + rejectResult(err) + throw err + } + }, + } + } + const session = { + id: 's1', + stream: vi.fn(() => makeMockStream(['a', 'b', 'c'])), + send: vi.fn(), + } + const tools = { executeTool: vi.fn() } + + const runner = new TurnRunner(parser) + const stream = runner.runStream(session, 'hi', tools, { signal: controller.signal }) + + const collected: string[] = [] + const iter = (async () => { + try { + for await (const chunk of stream) { + collected.push(chunk) + if (collected.length === 1) controller.abort() + } + } catch { + // iterator re-throws AbortError per plan; consumer-side ok + } + })() + + await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' }) + await iter + + expect(session.send).not.toHaveBeenCalled() + }) + + it('已 abort signal 让 runStream 立即让 result reject', async () => { + const controller = new AbortController() + controller.abort() + const session = { + id: 's1', + send: vi.fn(), + stream: vi.fn(), + } + const tools = { executeTool: vi.fn() } + + const runner = new TurnRunner(parser) + const stream = runner.runStream(session, 'hi', tools, { signal: controller.signal }) + await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' }) + expect(session.stream).not.toHaveBeenCalled() + }) +}) diff --git a/packages/core/src/engine/__tests__/turn-runner.test.ts b/packages/core/src/engine/__tests__/turn-runner.test.ts index f7e6e1b..f699de1 100644 --- a/packages/core/src/engine/__tests__/turn-runner.test.ts +++ b/packages/core/src/engine/__tests__/turn-runner.test.ts @@ -21,7 +21,7 @@ describe('TurnRunner', () => { const result = await runner.run(session, 'hello', tools); expect(session.send).toHaveBeenCalledTimes(1); - expect(session.send).toHaveBeenCalledWith('hello'); + expect(session.send).toHaveBeenCalledWith('hello', { signal: undefined }); expect(result.finalContent).toBe('final'); expect(result.toolRoundCount).toBe(0); expect(result.toolCallsExecuted).toBe(0); @@ -48,7 +48,7 @@ describe('TurnRunner', () => { const result = await runner.run(session, 'hello', tools); expect(session.send).toHaveBeenCalledTimes(2); - expect(tools.executeTool).toHaveBeenCalledWith('read', { path: 'core.name' }, '1'); + expect(tools.executeTool).toHaveBeenCalledWith('read', { path: 'core.name' }, '1', { signal: undefined }); expect(session.send.mock.calls[1]?.[0]).toContain('"toolResults"'); expect(result.finalContent).toBe('done'); expect(result.toolRoundCount).toBe(1); @@ -79,8 +79,8 @@ describe('TurnRunner', () => { const result = await runner.run(session, 'hello', tools); expect(tools.executeTool.mock.calls).toEqual([ - ['read', { path: 'core.name' }, undefined], - ['list', { scope: 'ui' }, undefined], + ['read', { path: 'core.name' }, undefined, { signal: undefined }], + ['list', { scope: 'ui' }, undefined, { signal: undefined }], ]); expect(result.toolCallsExecuted).toBe(2); }); diff --git a/packages/core/src/engine/stello-engine.ts b/packages/core/src/engine/stello-engine.ts index 69887f5..17286ec 100644 --- a/packages/core/src/engine/stello-engine.ts +++ b/packages/core/src/engine/stello-engine.ts @@ -36,6 +36,12 @@ import { type TurnRunnerStreamResult, } from './turn-runner'; +/** Engine 调用 session.send/stream 时的运行时选项 */ +export interface EngineRuntimeSessionCallOptions { + /** AbortSignal — 透传给底层 session.send/stream 与 LLM 调用 */ + signal?: AbortSignal; +} + /** 供 Engine 使用的运行时 Session 契约 */ export interface EngineRuntimeSession { /** 供日志和 hooks 使用的稳定标识 */ @@ -49,9 +55,12 @@ export interface EngineRuntimeSession { /** 当前已完成轮次 */ turnCount: number; /** 运行一次单条对话 */ - send(input: string): Promise; + send(input: string, options?: EngineRuntimeSessionCallOptions): Promise; /** 可选:流式运行一次单条对话 */ - stream?(input: string): AsyncIterable & { result: Promise }; + stream?( + input: string, + options?: EngineRuntimeSessionCallOptions, + ): AsyncIterable & { result: Promise }; /** fork 子 session,返回子 session 的 runtime */ fork?(options: SessionCompatibleForkOptions): Promise; /** 由 Session 自己完成 L2/L3 -> memory 的整理 */ @@ -435,12 +444,14 @@ export class StelloEngineImpl implements StelloEngine { name: string, args: Record, toolCallId?: string, + options?: { signal?: AbortSignal }, ): Promise { const ctx: ToolExecutionContext = { agent: this.agent, sessionId: this.session.id, toolCallId, toolName: name, + ...(options?.signal !== undefined && { signal: options.signal }), }; return this.tools.executeTool(name, args, ctx); } diff --git a/packages/core/src/engine/turn-runner.ts b/packages/core/src/engine/turn-runner.ts index 4bc2dbb..09c80e7 100644 --- a/packages/core/src/engine/turn-runner.ts +++ b/packages/core/src/engine/turn-runner.ts @@ -18,14 +18,29 @@ export interface ParsedTurnResponse { toolCalls: ToolCall[]; } +/** Session 调用的运行时选项 */ +export interface TurnRunnerSessionCallOptions { + /** AbortSignal — 透传给 session.send/stream,进而透传给 LLM 调用 */ + signal?: AbortSignal; +} + /** 单个 Session 的最小运行时契约 */ export interface TurnRunnerSession { /** Session 标识 */ id: string; /** 执行一次单条对话 */ - send(input: string): Promise; + send(input: string, options?: TurnRunnerSessionCallOptions): Promise; /** 可选:流式执行一次单条对话 */ - stream?(input: string): AsyncIterable & { result: Promise }; + stream?( + input: string, + options?: TurnRunnerSessionCallOptions, + ): AsyncIterable & { result: Promise }; +} + +/** Tool 调用的运行时选项 */ +export interface TurnRunnerToolCallOptions { + /** AbortSignal — tool 可读取以中断长任务(HTTP、subprocess 等) */ + signal?: AbortSignal; } /** Tool 执行器的最小契约 */ @@ -35,6 +50,7 @@ export interface TurnRunnerToolExecutor { name: string, args: Record, toolCallId?: string, + options?: TurnRunnerToolCallOptions, ): Promise; } @@ -52,6 +68,12 @@ export interface TurnRunnerOptions { onToolCall?: (toolCall: ToolCall) => Promise | void; /** 工具调用后的观察回调 */ onToolResult?: (result: ToolCallResult) => Promise | void; + /** + * AbortSignal — abort 后下一轮边界(含 send / tool 执行前后)抛 AbortError, + * 同时透传给 session.send/stream 与 tools.executeTool。 + * Tools 不消费 ctx.signal 时,runner 会等本轮 tool 自然返回,再在边界处抛。 + */ + signal?: AbortSignal; } /** 单个工具调用的执行结果 */ @@ -113,7 +135,8 @@ export class TurnRunner { let lastRawResponse = ''; while (true) { - lastRawResponse = await session.send(currentInput); + options.signal?.throwIfAborted(); + lastRawResponse = await session.send(currentInput, { signal: options.signal }); const parsed = this.parser.parse(lastRawResponse); if (parsed.toolCalls.length === 0) { @@ -131,9 +154,17 @@ export class TurnRunner { const toolResults = []; for (const toolCall of parsed.toolCalls) { + options.signal?.throwIfAborted(); await options.onToolCall?.(toolCall); - const result = await tools.executeTool(toolCall.name, toolCall.args, toolCall.id); + const result = await tools.executeTool( + toolCall.name, + toolCall.args, + toolCall.id, + { signal: options.signal }, + ); toolCallsExecuted += 1; + // tool 结果收集后立刻检查 signal — 已 abort 时不下发 phantom onToolResult。 + options.signal?.throwIfAborted(); const toolResult: ToolCallResult = { toolCallId: toolCall.id ?? null, toolName: toolCall.name, @@ -165,6 +196,19 @@ export class TurnRunner { tools: TurnRunnerToolExecutor, options: TurnRunnerOptions = {}, ): TurnRunnerStreamResult { + // pre-flight:已 abort 时直接返回 reject 的 result + 立刻抛错的 iterator + if (options.signal?.aborted) { + const aborted = Promise.reject(new DOMException('aborted', 'AbortError')) + // 安抚 unhandledRejection:消费方通过 `result` 或 iterator 任一感知即可。 + aborted.catch(() => {}) + return { + result: aborted as Promise, + async *[Symbol.asyncIterator]() { + throw new DOMException('aborted', 'AbortError') + }, + } + } + if (!session.stream) { const result = this.run(session, input, tools, options) return { @@ -178,12 +222,13 @@ export class TurnRunner { } } - const source = session.stream(input) + const source = session.stream(input, { signal: options.signal }) const result = this.finishFromStreamResult(session, source.result, tools, options) return { result, async *[Symbol.asyncIterator]() { + // 重新抛出 AbortError(而不是静默关闭),让调用方明确感知取消语义。 for await (const chunk of source) { yield chunk } @@ -201,6 +246,7 @@ export class TurnRunner { let toolRoundCount = 0 let toolCallsExecuted = 0 let lastRawResponse = await rawResult + options.signal?.throwIfAborted() let parsed = this.parser.parse(lastRawResponse) while (parsed.toolCalls.length > 0) { @@ -210,9 +256,16 @@ export class TurnRunner { const toolResults = [] for (const toolCall of parsed.toolCalls) { + options.signal?.throwIfAborted() await options.onToolCall?.(toolCall) - const result = await tools.executeTool(toolCall.name, toolCall.args, toolCall.id) + const result = await tools.executeTool( + toolCall.name, + toolCall.args, + toolCall.id, + { signal: options.signal }, + ) toolCallsExecuted += 1 + options.signal?.throwIfAborted() const toolResult: ToolCallResult = { toolCallId: toolCall.id ?? null, toolName: toolCall.name, @@ -226,7 +279,8 @@ export class TurnRunner { } toolRoundCount += 1 - lastRawResponse = await session.send(JSON.stringify({ toolResults })) + options.signal?.throwIfAborted() + lastRawResponse = await session.send(JSON.stringify({ toolResults }), { signal: options.signal }) parsed = this.parser.parse(lastRawResponse) } diff --git a/packages/core/src/orchestrator/__tests__/default-engine-factory.test.ts b/packages/core/src/orchestrator/__tests__/default-engine-factory.test.ts index eeb80c9..60308a7 100644 --- a/packages/core/src/orchestrator/__tests__/default-engine-factory.test.ts +++ b/packages/core/src/orchestrator/__tests__/default-engine-factory.test.ts @@ -56,7 +56,7 @@ describe('DefaultEngineFactory', () => { const result = await engine.turn('hello'); expect(engine.sessionId).toBe('s1'); - expect(runtimeSession.send).toHaveBeenCalledWith('hello'); + expect(runtimeSession.send).toHaveBeenCalledWith('hello', { signal: undefined }); expect(result.turn.rawResponse).toContain('"content":"done"'); }); diff --git a/packages/core/src/types/tool.ts b/packages/core/src/types/tool.ts index 00b11a1..3142bdb 100644 --- a/packages/core/src/types/tool.ts +++ b/packages/core/src/types/tool.ts @@ -10,4 +10,10 @@ export interface ToolExecutionContext { toolCallId?: string /** This tool's own name (debug, generic wrappers) */ toolName: string + /** + * AbortSignal of the in-flight turn. Tools may opt-in to honor it for + * long-running work (HTTP, subprocess, etc.). Tools that ignore the signal + * still work — the runner aborts at the next round boundary. + */ + signal?: AbortSignal } From 922fe47c2c5afcfc8af986719d3505613e6f8cc2 Mon Sep 17 00:00:00 2001 From: uchouT Date: Wed, 29 Apr 2026 00:21:20 +0800 Subject: [PATCH 3/3] test: avoid unused-var lint error in for-await drain loops --- .../core/src/adapters/__tests__/session-runtime.test.ts | 6 ++++-- packages/session/src/__tests__/abort.test.ts | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/core/src/adapters/__tests__/session-runtime.test.ts b/packages/core/src/adapters/__tests__/session-runtime.test.ts index d209da5..3fe6918 100644 --- a/packages/core/src/adapters/__tests__/session-runtime.test.ts +++ b/packages/core/src/adapters/__tests__/session-runtime.test.ts @@ -180,11 +180,13 @@ describe('session-runtime adapters', () => { const controller = new AbortController(); const stream = runtime.stream!('hi', { signal: controller.signal }); - for await (const _ of stream) { - // drain + const drained: string[] = []; + for await (const chunk of stream) { + drained.push(chunk); } await stream.result; + expect(drained).toEqual(['a']); expect(session.stream).toHaveBeenCalledWith('hi', { signal: controller.signal }); }); diff --git a/packages/session/src/__tests__/abort.test.ts b/packages/session/src/__tests__/abort.test.ts index bf03a53..9b056d5 100644 --- a/packages/session/src/__tests__/abort.test.ts +++ b/packages/session/src/__tests__/abort.test.ts @@ -147,11 +147,13 @@ describe('Session.stream() AbortSignal', () => { const controller = new AbortController() const stream = session.stream('hello', { signal: controller.signal }) - for await (const _ of stream) { - // drain + const drained: string[] = [] + for await (const chunk of stream) { + drained.push(chunk) } await stream.result + expect(drained.length).toBeGreaterThan(0) expect(llm.calls[0]!.signal).toBe(controller.signal) }) })