Skip to content

Commit 4ed61b6

Browse files
committed
feat: add streaming methods for elicitation and createMessage
1 parent a6ee2cb commit 4ed61b6

File tree

2 files changed

+730
-1
lines changed

2 files changed

+730
-1
lines changed

src/experimental/tasks/server.ts

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,21 @@ import type { Server } from '../../server/index.js';
99
import type { RequestOptions } from '../../shared/protocol.js';
1010
import type { ResponseMessage } from '../../shared/responseMessage.js';
1111
import type { AnySchema, SchemaOutput } from '../../server/zod-compat.js';
12-
import type { ServerRequest, Notification, Request, Result, GetTaskResult, ListTasksResult, CancelTaskResult } from '../../types.js';
12+
import type {
13+
ServerRequest,
14+
Notification,
15+
Request,
16+
Result,
17+
GetTaskResult,
18+
ListTasksResult,
19+
CancelTaskResult,
20+
CreateMessageRequestParams,
21+
CreateMessageResult,
22+
ElicitRequestFormParams,
23+
ElicitRequestURLParams,
24+
ElicitResult
25+
} from '../../types.js';
26+
import { CreateMessageResultSchema, ElicitResultSchema } from '../../types.js';
1327

1428
/**
1529
* Experimental task features for low-level MCP servers.
@@ -60,6 +74,189 @@ export class ExperimentalServerTasks<
6074
return (this._server as unknown as ServerWithRequestStream).requestStream(request, resultSchema, options);
6175
}
6276

77+
/**
78+
* Sends a sampling request and returns an AsyncGenerator that yields response messages.
79+
* The generator is guaranteed to end with either a 'result' or 'error' message.
80+
*
81+
* For task-augmented requests, yields 'taskCreated' and 'taskStatus' messages
82+
* before the final result.
83+
*
84+
* @example
85+
* ```typescript
86+
* const stream = server.experimental.tasks.createMessageStream({
87+
* messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }],
88+
* maxTokens: 100
89+
* }, {
90+
* onprogress: (progress) => {
91+
* // Handle streaming tokens via progress notifications
92+
* console.log('Progress:', progress.message);
93+
* }
94+
* });
95+
*
96+
* for await (const message of stream) {
97+
* switch (message.type) {
98+
* case 'taskCreated':
99+
* console.log('Task created:', message.task.taskId);
100+
* break;
101+
* case 'taskStatus':
102+
* console.log('Task status:', message.task.status);
103+
* break;
104+
* case 'result':
105+
* console.log('Final result:', message.result);
106+
* break;
107+
* case 'error':
108+
* console.error('Error:', message.error);
109+
* break;
110+
* }
111+
* }
112+
* ```
113+
*
114+
* @param params - The sampling request parameters
115+
* @param options - Optional request options (timeout, signal, task creation params, onprogress, etc.)
116+
* @returns AsyncGenerator that yields ResponseMessage objects
117+
*
118+
* @experimental
119+
*/
120+
createMessageStream(
121+
params: CreateMessageRequestParams,
122+
options?: RequestOptions
123+
): AsyncGenerator<ResponseMessage<CreateMessageResult>, void, void> {
124+
// Access client capabilities via the server
125+
type ServerWithCapabilities = {
126+
getClientCapabilities(): { sampling?: { tools?: boolean } } | undefined;
127+
};
128+
const clientCapabilities = (this._server as unknown as ServerWithCapabilities).getClientCapabilities();
129+
130+
// Capability check - only required when tools/toolChoice are provided
131+
if (params.tools || params.toolChoice) {
132+
if (!clientCapabilities?.sampling?.tools) {
133+
throw new Error('Client does not support sampling tools capability.');
134+
}
135+
}
136+
137+
// Message structure validation - always validate tool_use/tool_result pairs.
138+
// These may appear even without tools/toolChoice in the current request when
139+
// a previous sampling request returned tool_use and this is a follow-up with results.
140+
if (params.messages.length > 0) {
141+
const lastMessage = params.messages[params.messages.length - 1];
142+
const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content];
143+
const hasToolResults = lastContent.some(c => c.type === 'tool_result');
144+
145+
const previousMessage = params.messages.length > 1 ? params.messages[params.messages.length - 2] : undefined;
146+
const previousContent = previousMessage
147+
? Array.isArray(previousMessage.content)
148+
? previousMessage.content
149+
: [previousMessage.content]
150+
: [];
151+
const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use');
152+
153+
if (hasToolResults) {
154+
if (lastContent.some(c => c.type !== 'tool_result')) {
155+
throw new Error('The last message must contain only tool_result content if any is present');
156+
}
157+
if (!hasPreviousToolUse) {
158+
throw new Error('tool_result blocks are not matching any tool_use from the previous message');
159+
}
160+
}
161+
if (hasPreviousToolUse) {
162+
type ToolUseContent = { type: 'tool_use'; id: string };
163+
type ToolResultContent = { type: 'tool_result'; toolUseId: string };
164+
const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => (c as ToolUseContent).id));
165+
const toolResultIds = new Set(
166+
lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId)
167+
);
168+
if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) {
169+
throw new Error('ids of tool_result blocks and tool_use blocks from previous message do not match');
170+
}
171+
}
172+
}
173+
174+
const request = {
175+
method: 'sampling/createMessage' as const,
176+
params
177+
};
178+
return this.requestStream(request, CreateMessageResultSchema, options);
179+
}
180+
181+
/**
182+
* Sends an elicitation request and returns an AsyncGenerator that yields response messages.
183+
* The generator is guaranteed to end with either a 'result' or 'error' message.
184+
*
185+
* For task-augmented requests (especially URL-based elicitation), yields 'taskCreated'
186+
* and 'taskStatus' messages before the final result.
187+
*
188+
* @example
189+
* ```typescript
190+
* const stream = server.experimental.tasks.elicitInputStream({
191+
* mode: 'url',
192+
* message: 'Please authenticate',
193+
* elicitationId: 'auth-123',
194+
* url: 'https://example.com/auth'
195+
* }, {
196+
* task: { ttl: 300000 } // Task-augmented for long-running auth flow
197+
* });
198+
*
199+
* for await (const message of stream) {
200+
* switch (message.type) {
201+
* case 'taskCreated':
202+
* console.log('Task created:', message.task.taskId);
203+
* break;
204+
* case 'taskStatus':
205+
* console.log('Task status:', message.task.status);
206+
* break;
207+
* case 'result':
208+
* console.log('User action:', message.result.action);
209+
* break;
210+
* case 'error':
211+
* console.error('Error:', message.error);
212+
* break;
213+
* }
214+
* }
215+
* ```
216+
*
217+
* @param params - The elicitation request parameters
218+
* @param options - Optional request options (timeout, signal, task creation params, etc.)
219+
* @returns AsyncGenerator that yields ResponseMessage objects
220+
*
221+
* @experimental
222+
*/
223+
elicitInputStream(
224+
params: ElicitRequestFormParams | ElicitRequestURLParams,
225+
options?: RequestOptions
226+
): AsyncGenerator<ResponseMessage<ElicitResult>, void, void> {
227+
// Access client capabilities via the server
228+
type ServerWithCapabilities = {
229+
getClientCapabilities(): { elicitation?: { form?: boolean; url?: boolean } } | undefined;
230+
};
231+
const clientCapabilities = (this._server as unknown as ServerWithCapabilities).getClientCapabilities();
232+
233+
const mode = (params.mode ?? 'form') as 'form' | 'url';
234+
235+
// Capability check based on mode
236+
switch (mode) {
237+
case 'url':
238+
if (!clientCapabilities?.elicitation?.url) {
239+
throw new Error('Client does not support url elicitation.');
240+
}
241+
break;
242+
case 'form':
243+
if (!clientCapabilities?.elicitation?.form) {
244+
throw new Error('Client does not support form elicitation.');
245+
}
246+
break;
247+
}
248+
249+
// Normalize params to ensure mode is set
250+
const normalizedParams =
251+
mode === 'form' && params.mode !== 'form' ? { ...(params as ElicitRequestFormParams), mode: 'form' as const } : params;
252+
253+
const request = {
254+
method: 'elicitation/create' as const,
255+
params: normalizedParams
256+
};
257+
return this.requestStream(request, ElicitResultSchema, options);
258+
}
259+
63260
/**
64261
* Gets the current status of a task.
65262
*

0 commit comments

Comments
 (0)