Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@stello-ai/core",
"version": "0.7.2",
"version": "0.8.0",
"description": "The first open-source conversation topology engine",
"license": "Apache-2.0",
"author": "Stello Contributors",
Expand Down
46 changes: 45 additions & 1 deletion packages/core/src/adapters/__tests__/session-runtime.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ describe('session-runtime adapters', () => {
const raw = await runtime.send('hello');
const parsed = sessionSendResultParser.parse(raw);

expect(session.send).toHaveBeenCalledWith('hello');
expect(session.send).toHaveBeenCalledWith('hello', undefined);
expect(runtime.meta.turnCount).toBe(3);
expect(parsed.toolCalls[0]).toEqual({
id: 't1',
Expand Down Expand Up @@ -146,6 +146,50 @@ describe('session-runtime adapters', () => {
});
});

it('adapter forwards signal to underlying SessionCompatible.send', async () => {
const session = {
meta: { id: 's1', status: 'active' as const },
send: vi.fn().mockResolvedValue({ content: 'ok', toolCalls: [] }),
messages: vi.fn().mockResolvedValue([]),
consolidate: vi.fn(),
setTools: vi.fn(),
};
const runtime = await adaptSessionToEngineRuntime(session, {});

const controller = new AbortController();
await runtime.send('hi', { signal: controller.signal });
expect(session.send).toHaveBeenCalledWith('hi', { signal: controller.signal });
});

it('adapter forwards signal to underlying SessionCompatible.stream', async () => {
const streamSource = {
result: Promise.resolve({ content: 'ok', toolCalls: [] }),
async *[Symbol.asyncIterator]() {
yield 'a';
},
};
const session = {
meta: { id: 's1', status: 'active' as const },
send: vi.fn(),
stream: vi.fn(() => streamSource),
messages: vi.fn().mockResolvedValue([]),
consolidate: vi.fn(),
setTools: vi.fn(),
};
const runtime = await adaptSessionToEngineRuntime(session, {});

const controller = new AbortController();
const stream = runtime.stream!('hi', { signal: controller.signal });
const drained: string[] = [];
for await (const chunk of stream) {
drained.push(chunk);
}
await stream.result;

expect(drained).toEqual(['a']);
expect(session.stream).toHaveBeenCalledWith('hi', { signal: controller.signal });
});

it('adapter exposes tools getter and forwards setTools to underlying Session', async () => {
const sessionTools: Array<{ name: string; description: string; inputSchema: object }> = [{ name: 'a', description: 'd', inputSchema: {} }];
const setToolsSpy = vi.fn((t) => {
Expand Down
22 changes: 16 additions & 6 deletions packages/core/src/adapters/session-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,25 @@ export interface SessionCompatibleForkOptions {
compressFn?: SessionCompatibleCompressFn;
}

/** Session.send / Session.stream 的可选运行时参数(结构兼容 @stello-ai/session) */
export interface SessionCompatibleSendOptions {
/** AbortSignal — abort 时底层 LLM 调用应被取消 */
signal?: AbortSignal;
}

/** 结构兼容 @stello-ai/session 的 Session */
export interface SessionCompatible {
meta: {
id: string;
status: 'active' | 'archived';
};
send(content: string): Promise<SessionCompatibleSendResult>;
send(
content: string,
options?: SessionCompatibleSendOptions,
): Promise<SessionCompatibleSendResult>;
stream?(
content: string
content: string,
options?: SessionCompatibleSendOptions,
): AsyncIterable<string> & { result: Promise<SessionCompatibleSendResult> };
messages(): Promise<Array<{ role: string; content: string; timestamp?: string }>>;
consolidate(): Promise<void>;
Expand Down Expand Up @@ -159,8 +169,8 @@ export async function adaptSessionToEngineRuntime(
get turnCount() {
return turnCount;
},
async send(input: string): Promise<string> {
const result = await session.send(input);
async send(input: string, sendOptions?: SessionCompatibleSendOptions): Promise<string> {
const result = await session.send(input, sendOptions);
turnCount += 1;
return (options.serializeResult ?? serializeSessionSendResult)(result);
},
Expand All @@ -175,8 +185,8 @@ export async function adaptSessionToEngineRuntime(
},
...(session.stream
? {
stream(input: string) {
const source = session.stream!(input);
stream(input: string, sendOptions?: SessionCompatibleSendOptions) {
const source = session.stream!(input, sendOptions);
return {
result: (async () => {
const result = await source.result;
Expand Down
60 changes: 58 additions & 2 deletions packages/core/src/agent/__tests__/stello-agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ describe('StelloAgent', () => {
const result = await agent.turn('root', 'hello');

expect(agent.sessions).toBeDefined();
expect(runtimeSession.send).toHaveBeenCalledWith('hello');
expect(runtimeSession.send).toHaveBeenCalledWith('hello', { signal: undefined });
expect(result.turn.finalContent).toContain('"content":"done"');
});

Expand Down Expand Up @@ -118,6 +118,62 @@ describe('StelloAgent', () => {
expect(result.turn.finalContent).toContain('"content":"done"')
});

it('agent.stream(input, { signal }) 透传到 runtime session 并在 abort 时让 result reject', async () => {
const controller = new AbortController()
const runtimeSession = {
id: 'root',
meta: { id: 'root', turnCount: 0, status: 'active' as const },
turnCount: 0,
send: vi.fn(),
stream: vi.fn((_input: string, opts?: { signal?: AbortSignal }) => {
let rejectResult: (err: unknown) => void = () => {}
const result = new Promise<string>((_resolve, reject) => { rejectResult = reject })
result.catch(() => {})
return {
result,
async *[Symbol.asyncIterator]() {
try {
for (const chunk of ['a', 'b', 'c']) {
if (opts?.signal?.aborted) {
const err = new DOMException('aborted', 'AbortError')
rejectResult(err)
throw err
}
await new Promise((r) => setTimeout(r, 5))
yield chunk
}
} catch (err) {
rejectResult(err)
throw err
}
},
}
}),
consolidate: vi.fn(),
setTools: vi.fn(),
}

const agent = createStelloAgent(baseConfig({ runtimeSession }))
const stream = await agent.stream('root', 'hello', { signal: controller.signal })

const collected: string[] = []
const iter = (async () => {
try {
for await (const chunk of stream) {
collected.push(chunk)
if (collected.length === 1) controller.abort()
}
} catch {
// expected: iterator re-throws AbortError
}
})()

await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' })
await iter

expect(runtimeSession.stream).toHaveBeenCalledWith('hello', { signal: controller.signal })
})

it('默认树形拓扑:子节点 fork 出的新节点挂在自己下面', async () => {
const childSession = {
...rootSession,
Expand Down Expand Up @@ -400,7 +456,7 @@ describe('StelloAgent', () => {

const result = await agent.turn('root', 'hello');

expect(session.send).toHaveBeenCalledWith('hello');
expect(session.send).toHaveBeenCalledWith('hello', { signal: undefined });
expect(result.turn.rawResponse).toContain('"content":"done"');
expect(result.turn.toolCallsExecuted).toBe(1);
});
Expand Down
171 changes: 171 additions & 0 deletions packages/core/src/engine/__tests__/turn-runner-abort.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import { describe, expect, it, vi } from 'vitest'
import { TurnRunner, type ToolCallParser } from '../turn-runner'

const parser: ToolCallParser = {
parse(raw) {
return JSON.parse(raw) as { content: string | null; toolCalls: Array<{ id?: string; name: string; args: Record<string, unknown> }> }
},
}

describe('TurnRunner.run AbortSignal', () => {
it('signal abort 在轮间生效,下一轮 send 不再发起', async () => {
const controller = new AbortController()
const session = {
id: 's1',
send: vi
.fn<(input: string, options?: { signal?: AbortSignal }) => Promise<string>>()
.mockResolvedValueOnce(
JSON.stringify({
content: null,
toolCalls: [{ id: '1', name: 'read', args: { path: 'a' } }],
}),
),
}
const tools = {
executeTool: vi.fn().mockImplementation(async () => {
controller.abort()
return { success: true, data: 'ok' }
}),
}
const onToolResult = vi.fn()

const runner = new TurnRunner(parser)
await expect(
runner.run(session, 'hello', tools, {
signal: controller.signal,
onToolResult,
}),
).rejects.toMatchObject({ name: 'AbortError' })

// 第二轮 session.send 不应被调用(signal 在 round 边界检查)
expect(session.send).toHaveBeenCalledTimes(1)
// tool 执行后立刻 abort,onToolResult 不应触发(避免 phantom result)
expect(onToolResult).not.toHaveBeenCalled()
})

it('已 abort 的 signal 立即拒绝,不调用 session.send', async () => {
const controller = new AbortController()
controller.abort()
const session = { id: 's1', send: vi.fn() }
const tools = { executeTool: vi.fn() }

const runner = new TurnRunner(parser)
await expect(
runner.run(session, 'hello', tools, { signal: controller.signal }),
).rejects.toMatchObject({ name: 'AbortError' })

expect(session.send).not.toHaveBeenCalled()
})

it('signal 透传到 session.send 与 tools.executeTool', async () => {
const controller = new AbortController()
const session = {
id: 's1',
send: vi
.fn<(input: string, options?: { signal?: AbortSignal }) => Promise<string>>()
.mockResolvedValueOnce(
JSON.stringify({
content: null,
toolCalls: [{ id: '1', name: 'read', args: {} }],
}),
)
.mockResolvedValueOnce(JSON.stringify({ content: 'done', toolCalls: [] })),
}
const tools = {
executeTool: vi
.fn<(name: string, args: Record<string, unknown>, id?: string, options?: { signal?: AbortSignal }) => Promise<{ success: boolean; data?: unknown }>>()
.mockResolvedValue({ success: true, data: 'x' }),
}

const runner = new TurnRunner(parser)
await runner.run(session, 'hi', tools, { signal: controller.signal })

// session.send 第一参数是 input;第二参数应携带 signal
expect(session.send).toHaveBeenCalledWith('hi', expect.objectContaining({ signal: controller.signal }))
// tools.executeTool 第四参数应携带 signal
expect(tools.executeTool).toHaveBeenCalledWith(
'read',
{},
'1',
expect.objectContaining({ signal: controller.signal }),
)
})
})

describe('TurnRunner.runStream AbortSignal', () => {
it('运行中 abort 后 result reject 为 AbortError,且后续不再 send', async () => {
const controller = new AbortController()
// 模拟 session.stream 的真实行为:iterator 抛 AbortError,result 也 reject。
function makeMockStream(chunks: string[]) {
let resolveResult: (raw: string) => void = () => {}
let rejectResult: (err: unknown) => void = () => {}
const result = new Promise<string>((resolve, reject) => {
resolveResult = resolve
rejectResult = reject
})
result.catch(() => {})
return {
result,
async *[Symbol.asyncIterator]() {
try {
for (const chunk of chunks) {
if (controller.signal.aborted) {
const err = new DOMException('aborted', 'AbortError')
rejectResult(err)
throw err
}
await new Promise((r) => setTimeout(r, 5))
yield chunk
}
resolveResult(JSON.stringify({ content: chunks.join(''), toolCalls: [] }))
} catch (err) {
rejectResult(err)
throw err
}
},
}
}
const session = {
id: 's1',
stream: vi.fn(() => makeMockStream(['a', 'b', 'c'])),
send: vi.fn(),
}
const tools = { executeTool: vi.fn() }

const runner = new TurnRunner(parser)
const stream = runner.runStream(session, 'hi', tools, { signal: controller.signal })

const collected: string[] = []
const iter = (async () => {
try {
for await (const chunk of stream) {
collected.push(chunk)
if (collected.length === 1) controller.abort()
}
} catch {
// iterator re-throws AbortError per plan; consumer-side ok
}
})()

await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' })
await iter

expect(session.send).not.toHaveBeenCalled()
})

it('已 abort signal 让 runStream 立即让 result reject', async () => {
const controller = new AbortController()
controller.abort()
const session = {
id: 's1',
send: vi.fn(),
stream: vi.fn(),
}
const tools = { executeTool: vi.fn() }

const runner = new TurnRunner(parser)
const stream = runner.runStream(session, 'hi', tools, { signal: controller.signal })
await expect(stream.result).rejects.toMatchObject({ name: 'AbortError' })
expect(session.stream).not.toHaveBeenCalled()
})
})
8 changes: 4 additions & 4 deletions packages/core/src/engine/__tests__/turn-runner.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ describe('TurnRunner', () => {
const result = await runner.run(session, 'hello', tools);

expect(session.send).toHaveBeenCalledTimes(1);
expect(session.send).toHaveBeenCalledWith('hello');
expect(session.send).toHaveBeenCalledWith('hello', { signal: undefined });
expect(result.finalContent).toBe('final');
expect(result.toolRoundCount).toBe(0);
expect(result.toolCallsExecuted).toBe(0);
Expand All @@ -48,7 +48,7 @@ describe('TurnRunner', () => {
const result = await runner.run(session, 'hello', tools);

expect(session.send).toHaveBeenCalledTimes(2);
expect(tools.executeTool).toHaveBeenCalledWith('read', { path: 'core.name' }, '1');
expect(tools.executeTool).toHaveBeenCalledWith('read', { path: 'core.name' }, '1', { signal: undefined });
expect(session.send.mock.calls[1]?.[0]).toContain('"toolResults"');
expect(result.finalContent).toBe('done');
expect(result.toolRoundCount).toBe(1);
Expand Down Expand Up @@ -79,8 +79,8 @@ describe('TurnRunner', () => {
const result = await runner.run(session, 'hello', tools);

expect(tools.executeTool.mock.calls).toEqual([
['read', { path: 'core.name' }, undefined],
['list', { scope: 'ui' }, undefined],
['read', { path: 'core.name' }, undefined, { signal: undefined }],
['list', { scope: 'ui' }, undefined, { signal: undefined }],
]);
expect(result.toolCallsExecuted).toBe(2);
});
Expand Down
Loading