diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 21ab4f556..08f205592 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -24,6 +24,7 @@ import { } from '../../types.js'; import { getDisplayName } from '../../shared/metadataUtils.js'; import { Ajv } from 'ajv'; +import { InMemoryTaskStore } from '../../experimental/tasks/stores/in-memory.js'; // Create readline interface for user input const readline = createInterface({ @@ -65,6 +66,7 @@ function printHelp(): void { console.log(' greet [name] - Call the greet tool'); console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); console.log(' collect-info [type] - Test form elicitation with collect-user-info tool (contact/preferences/feedback)'); + console.log(' collect-info-task [type] - Test bidirectional task support (server+client tasks) with elicitation'); console.log(' start-notifications [interval] [count] - Start periodic notifications'); console.log(' run-notifications-tool-with-resumability [interval] [count] - Run notification tool with resumability'); console.log(' list-prompts - List available prompts'); @@ -131,6 +133,10 @@ function commandLoop(): void { await callCollectInfoTool(args[1] || 'contact'); break; + case 'collect-info-task': + await callCollectInfoWithTask(args[1] || 'contact'); + break; + case 'start-notifications': { const interval = args[1] ? parseInt(args[1], 10) : 2000; const count = args[2] ? parseInt(args[2], 10) : 10; @@ -232,7 +238,10 @@ async function connect(url?: string): Promise { console.log(`Connecting to ${serverUrl}...`); try { - // Create a new client with form elicitation capability + // Create task store for client-side task support + const clientTaskStore = new InMemoryTaskStore(); + + // Create a new client with form elicitation capability and task support client = new Client( { name: 'example-client', @@ -242,25 +251,46 @@ async function connect(url?: string): Promise { capabilities: { elicitation: { form: {} + }, + tasks: { + requests: { + elicitation: { + create: {} + } + } } - } + }, + taskStore: clientTaskStore } ); client.onerror = error => { console.error('\x1b[31mClient error:', error, '\x1b[0m'); }; - // Set up elicitation request handler with proper validation - client.setRequestHandler(ElicitRequestSchema, async request => { + // Set up elicitation request handler with proper validation and task support + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { if (request.params.mode !== 'form') { throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } console.log('\nšŸ”” Elicitation (form) Request Received:'); console.log(`Message: ${request.params.message}`); console.log(`Related Task: ${request.params._meta?.[RELATED_TASK_META_KEY]?.taskId}`); + console.log(`Task Creation Requested: ${request.params.task ? 'yes' : 'no'}`); console.log('Requested Schema:'); console.log(JSON.stringify(request.params.requestedSchema, null, 2)); + // Helper to return result, optionally creating a task if requested + const returnResult = async (result: { action: 'accept' | 'decline' | 'cancel'; content?: Record }) => { + if (request.params.task && extra.taskStore) { + // Create a task and store the result + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + console.log(`šŸ“‹ Created client-side task: ${task.taskId}`); + return { task }; + } + return result; + }; + const schema = request.params.requestedSchema; const properties = schema.properties; const required = schema.required || []; @@ -381,7 +411,7 @@ async function connect(url?: string): Promise { } if (inputCancelled) { - return { action: 'cancel' }; + return returnResult({ action: 'cancel' }); } // If we didn't complete all fields due to an error, try again @@ -394,7 +424,7 @@ async function connect(url?: string): Promise { continue; } else { console.log('Maximum attempts reached. Declining request.'); - return { action: 'decline' }; + return returnResult({ action: 'decline' }); } } @@ -412,7 +442,7 @@ async function connect(url?: string): Promise { continue; } else { console.log('Maximum attempts reached. Declining request.'); - return { action: 'decline' }; + return returnResult({ action: 'decline' }); } } @@ -427,24 +457,24 @@ async function connect(url?: string): Promise { }); if (confirmAnswer === 'yes' || confirmAnswer === 'y') { - return { + return returnResult({ action: 'accept', content - }; + }); } else if (confirmAnswer === 'cancel' || confirmAnswer === 'c') { - return { action: 'cancel' }; + return returnResult({ action: 'cancel' }); } else if (confirmAnswer === 'no' || confirmAnswer === 'n') { if (attempts < maxAttempts) { console.log('Please re-enter the information...'); continue; } else { - return { action: 'decline' }; + return returnResult({ action: 'decline' }); } } } console.log('Maximum attempts reached. Declining request.'); - return { action: 'decline' }; + return returnResult({ action: 'decline' }); }); transport = new StreamableHTTPClientTransport(new URL(serverUrl), { @@ -641,6 +671,12 @@ async function callCollectInfoTool(infoType: string): Promise { await callTool('collect-user-info', { infoType }); } +async function callCollectInfoWithTask(infoType: string): Promise { + console.log(`\nšŸ”„ Testing bidirectional task support with collect-user-info-task tool (${infoType})...`); + console.log('This will create a task on the server, which will elicit input and create a task on the client.\n'); + await callToolTask('collect-user-info-task', { infoType }); +} + async function startNotifications(interval: number, count: number): Promise { console.log(`Starting notification stream: interval=${interval}ms, count=${count || 'unlimited'}`); await callTool('start-notification-stream', { interval, count }); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 3500ac066..750b70d84 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -8,6 +8,7 @@ import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; import { createMcpExpressApp } from '../../server/index.js'; import { CallToolResult, + ElicitResult, ElicitResultSchema, GetPromptResult, isInitializeRequest, @@ -500,6 +501,114 @@ const getServer = () => { } ); + // Register a tool that demonstrates bidirectional task support: + // Server creates a task, then elicits input from client using elicitInputStream + // Using the experimental tasks API - WARNING: may change without notice + server.experimental.tasks.registerToolTask( + 'collect-user-info-task', + { + title: 'Collect Info with Task', + description: 'Collects user info via elicitation with task support using elicitInputStream', + inputSchema: { + infoType: z.enum(['contact', 'preferences']).describe('Type of information to collect').default('contact') + } + }, + { + async createTask({ infoType }, { taskStore, taskRequestedTtl }) { + // Create the server-side task + const task = await taskStore.createTask({ + ttl: taskRequestedTtl + }); + + // Perform async work that makes a nested elicitation request using elicitInputStream + (async () => { + try { + const message = infoType === 'contact' ? 'Please provide your contact information' : 'Please set your preferences'; + + // Define schemas with proper typing for PrimitiveSchemaDefinition + const contactSchema: { + type: 'object'; + properties: Record; + required: string[]; + } = { + type: 'object', + properties: { + name: { type: 'string', title: 'Full Name', description: 'Your full name' }, + email: { type: 'string', title: 'Email', description: 'Your email address' } + }, + required: ['name', 'email'] + }; + + const preferencesSchema: { + type: 'object'; + properties: Record; + required: string[]; + } = { + type: 'object', + properties: { + theme: { type: 'string', title: 'Theme', enum: ['light', 'dark', 'auto'] }, + notifications: { type: 'boolean', title: 'Enable Notifications', default: true } + }, + required: ['theme'] + }; + + const requestedSchema = infoType === 'contact' ? contactSchema : preferencesSchema; + + // Use elicitInputStream to elicit input from client + // This demonstrates the streaming elicitation API + // Access via server.server to get the underlying Server instance + const stream = server.server.experimental.tasks.elicitInputStream({ + mode: 'form', + message, + requestedSchema + }); + + let elicitResult: ElicitResult | undefined; + for await (const msg of stream) { + if (msg.type === 'result') { + elicitResult = msg.result as ElicitResult; + } else if (msg.type === 'error') { + throw msg.error; + } + } + + if (!elicitResult) { + throw new Error('No result received from elicitation'); + } + + let resultText: string; + if (elicitResult.action === 'accept') { + resultText = `Collected ${infoType} info: ${JSON.stringify(elicitResult.content, null, 2)}`; + } else if (elicitResult.action === 'decline') { + resultText = `User declined to provide ${infoType} information`; + } else { + resultText = 'User cancelled the request'; + } + + await taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: resultText }] + }); + } catch (error) { + console.error('Error in collect-user-info-task:', error); + await taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: `Error: ${error}` }], + isError: true + }); + } + })(); + + return { task }; + }, + async getTask(_args, { taskId, taskStore }) { + return await taskStore.getTask(taskId); + }, + async getTaskResult(_args, { taskId, taskStore }) { + const result = await taskStore.getTaskResult(taskId); + return result as CallToolResult; + } + } + ); + return server; }; diff --git a/src/experimental/tasks/server.ts b/src/experimental/tasks/server.ts index a4150a8d7..b062b02b6 100644 --- a/src/experimental/tasks/server.ts +++ b/src/experimental/tasks/server.ts @@ -9,7 +9,21 @@ import type { Server } from '../../server/index.js'; import type { RequestOptions } from '../../shared/protocol.js'; import type { ResponseMessage } from '../../shared/responseMessage.js'; import type { AnySchema, SchemaOutput } from '../../server/zod-compat.js'; -import type { ServerRequest, Notification, Request, Result, GetTaskResult, ListTasksResult, CancelTaskResult } from '../../types.js'; +import type { + ServerRequest, + Notification, + Request, + Result, + GetTaskResult, + ListTasksResult, + CancelTaskResult, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult +} from '../../types.js'; +import { CreateMessageResultSchema, ElicitResultSchema } from '../../types.js'; /** * Experimental task features for low-level MCP servers. @@ -60,6 +74,189 @@ export class ExperimentalServerTasks< return (this._server as unknown as ServerWithRequestStream).requestStream(request, resultSchema, options); } + /** + * Sends a sampling request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * For task-augmented requests, yields 'taskCreated' and 'taskStatus' messages + * before the final result. + * + * @example + * ```typescript + * const stream = server.experimental.tasks.createMessageStream({ + * messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + * maxTokens: 100 + * }, { + * onprogress: (progress) => { + * // Handle streaming tokens via progress notifications + * console.log('Progress:', progress.message); + * } + * }); + * + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('Final result:', message.result); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param params - The sampling request parameters + * @param options - Optional request options (timeout, signal, task creation params, onprogress, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + * + * @experimental + */ + createMessageStream( + params: CreateMessageRequestParams, + options?: RequestOptions + ): AsyncGenerator, void, void> { + // Access client capabilities via the server + type ServerWithCapabilities = { + getClientCapabilities(): { sampling?: { tools?: boolean } } | undefined; + }; + const clientCapabilities = (this._server as unknown as ServerWithCapabilities).getClientCapabilities(); + + // Capability check - only required when tools/toolChoice are provided + if (params.tools || params.toolChoice) { + if (!clientCapabilities?.sampling?.tools) { + throw new Error('Client does not support sampling tools capability.'); + } + } + + // Message structure validation - always validate tool_use/tool_result pairs. + // These may appear even without tools/toolChoice in the current request when + // a previous sampling request returned tool_use and this is a follow-up with results. + if (params.messages.length > 0) { + const lastMessage = params.messages[params.messages.length - 1]; + const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; + const hasToolResults = lastContent.some(c => c.type === 'tool_result'); + + const previousMessage = params.messages.length > 1 ? params.messages[params.messages.length - 2] : undefined; + const previousContent = previousMessage + ? Array.isArray(previousMessage.content) + ? previousMessage.content + : [previousMessage.content] + : []; + const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use'); + + if (hasToolResults) { + if (lastContent.some(c => c.type !== 'tool_result')) { + throw new Error('The last message must contain only tool_result content if any is present'); + } + if (!hasPreviousToolUse) { + throw new Error('tool_result blocks are not matching any tool_use from the previous message'); + } + } + if (hasPreviousToolUse) { + type ToolUseContent = { type: 'tool_use'; id: string }; + type ToolResultContent = { type: 'tool_result'; toolUseId: string }; + const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => (c as ToolUseContent).id)); + const toolResultIds = new Set( + lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId) + ); + if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) { + throw new Error('ids of tool_result blocks and tool_use blocks from previous message do not match'); + } + } + } + + const request = { + method: 'sampling/createMessage' as const, + params + }; + return this.requestStream(request, CreateMessageResultSchema, options); + } + + /** + * Sends an elicitation request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * For task-augmented requests (especially URL-based elicitation), yields 'taskCreated' + * and 'taskStatus' messages before the final result. + * + * @example + * ```typescript + * const stream = server.experimental.tasks.elicitInputStream({ + * mode: 'url', + * message: 'Please authenticate', + * elicitationId: 'auth-123', + * url: 'https://example.com/auth' + * }, { + * task: { ttl: 300000 } // Task-augmented for long-running auth flow + * }); + * + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('User action:', message.result.action); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param params - The elicitation request parameters + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + * + * @experimental + */ + elicitInputStream( + params: ElicitRequestFormParams | ElicitRequestURLParams, + options?: RequestOptions + ): AsyncGenerator, void, void> { + // Access client capabilities via the server + type ServerWithCapabilities = { + getClientCapabilities(): { elicitation?: { form?: boolean; url?: boolean } } | undefined; + }; + const clientCapabilities = (this._server as unknown as ServerWithCapabilities).getClientCapabilities(); + + const mode = (params.mode ?? 'form') as 'form' | 'url'; + + // Capability check based on mode + switch (mode) { + case 'url': + if (!clientCapabilities?.elicitation?.url) { + throw new Error('Client does not support url elicitation.'); + } + break; + case 'form': + if (!clientCapabilities?.elicitation?.form) { + throw new Error('Client does not support form elicitation.'); + } + break; + } + + // Normalize params to ensure mode is set + const normalizedParams = + mode === 'form' && params.mode !== 'form' ? { ...(params as ElicitRequestFormParams), mode: 'form' as const } : params; + + const request = { + method: 'elicitation/create' as const, + params: normalizedParams + }; + return this.requestStream(request, ElicitResultSchema, options); + } + /** * Gets the current status of a task. * diff --git a/src/server/index.test.ts b/src/server/index.test.ts index c01e638d0..e03c73c89 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -3,6 +3,8 @@ import supertest from 'supertest'; import { Client } from '../client/index.js'; import { InMemoryTransport } from '../inMemory.js'; import type { Transport } from '../shared/transport.js'; +import { toArrayAsync, type ResponseMessage } from '../shared/responseMessage.js'; +import type { CreateMessageResult, ElicitResult, Task } from '../types.js'; import { createMcpExpressApp } from './index.js'; import { CreateMessageRequestSchema, @@ -1926,6 +1928,236 @@ describe('createMessage validation', () => { }); }); +describe('createMessageStream', () => { + test('should throw when tools are provided without sampling.tools capability', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + role: 'assistant', + content: { type: 'text', text: 'Response' }, + model: 'test-model' + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(() => { + server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }); + }).toThrow('Client does not support sampling tools capability'); + }); + + test('should throw when tool_result has no matching tool_use in previous message', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + role: 'assistant', + content: { type: 'text', text: 'Response' }, + model: 'test-model' + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(() => { + server.experimental.tasks.createMessageStream({ + messages: [ + { role: 'user', content: { type: 'text', text: 'Hello' } }, + { + role: 'user', + content: [{ type: 'tool_result', toolUseId: 'test-id', content: [{ type: 'text', text: 'result' }] }] + } + ], + maxTokens: 100 + }); + }).toThrow('tool_result blocks are not matching any tool_use from the previous message'); + }); + + describe('terminal message guarantees', () => { + test('should yield exactly one terminal message for successful request', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + role: 'assistant', + content: { type: 'text', text: 'Response' }, + model: 'test-model' + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); + + const allMessages = await toArrayAsync(stream); + + expect(allMessages.length).toBe(1); + expect(allMessages[0].type).toBe('result'); + + const taskMessages = allMessages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); + expect(taskMessages.length).toBe(0); + }); + + test('should yield error as terminal message when client returns error', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => { + throw new Error('Simulated client error'); + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); + + const allMessages = await toArrayAsync(stream); + + expect(allMessages.length).toBe(1); + expect(allMessages[0].type).toBe('error'); + }); + + test('should yield exactly one terminal message with result', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: 'test-model', + role: 'assistant' as const, + content: { type: 'text' as const, text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], + maxTokens: 100 + }); + + const messages = await toArrayAsync(stream); + const terminalMessages = messages.filter(m => m.type === 'result' || m.type === 'error'); + + expect(terminalMessages.length).toBe(1); + + const lastMessage = messages[messages.length - 1]; + expect(lastMessage.type === 'result' || lastMessage.type === 'error').toBe(true); + + if (lastMessage.type === 'result') { + expect((lastMessage.result as CreateMessageResult).content).toBeDefined(); + } + }); + }); + + describe('non-task request minimality', () => { + test('should yield only result message for non-task request', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: 'test-model', + role: 'assistant' as const, + content: { type: 'text' as const, text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], + maxTokens: 100 + }); + + const messages = await toArrayAsync(stream); + + const taskMessages = messages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); + expect(taskMessages.length).toBe(0); + + const resultMessages = messages.filter(m => m.type === 'result'); + expect(resultMessages.length).toBe(1); + + expect(messages.length).toBe(1); + }); + }); + + describe('task-augmented request handling', () => { + test('should yield taskCreated and result for task-augmented request', async () => { + const clientTaskStore = new InMemoryTaskStore(); + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + const client = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + sampling: {}, + tasks: { + requests: { + sampling: { createMessage: {} } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { + const result = { + model: 'test-model', + role: 'assistant' as const, + content: { type: 'text' as const, text: 'Task response' } + }; + + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + return { task }; + } + return result; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.createMessageStream( + { + messages: [{ role: 'user', content: { type: 'text', text: 'Task-augmented message' } }], + maxTokens: 100 + }, + { task: { ttl: 60000 } } + ); + + const messages = await toArrayAsync(stream); + + // Should have taskCreated and result + expect(messages.length).toBeGreaterThanOrEqual(2); + + // First message should be taskCreated + expect(messages[0].type).toBe('taskCreated'); + const taskCreated = messages[0] as { type: 'taskCreated'; task: Task }; + expect(taskCreated.task.taskId).toBeDefined(); + + // Last message should be result + const lastMessage = messages[messages.length - 1]; + expect(lastMessage.type).toBe('result'); + if (lastMessage.type === 'result') { + expect((lastMessage.result as CreateMessageResult).model).toBe('test-model'); + } + + clientTaskStore.cleanup(); + }); + }); +}); + describe('createMessage backwards compatibility', () => { test('createMessage without tools returns single content (backwards compat)', async () => { const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); @@ -3280,3 +3512,303 @@ test('should respect client task capabilities', async () => { clientTaskStore.cleanup(); }); + +describe('elicitInputStream', () => { + let server: Server; + let client: Client; + let clientTransport: ReturnType[0]; + let serverTransport: ReturnType[1]; + + beforeEach(async () => { + server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + client = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + elicitation: { + form: {}, + url: {} + } + } + } + ); + + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + }); + + afterEach(async () => { + await server.close().catch(() => {}); + await client.close().catch(() => {}); + }); + + test('should throw when client does not support form elicitation', async () => { + // Create client without form elicitation capability + const noFormClient = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + elicitation: { + url: {} + } + } + } + ); + + const [noFormClientTransport, noFormServerTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([noFormClient.connect(noFormClientTransport), server.connect(noFormServerTransport)]); + + expect(() => { + server.experimental.tasks.elicitInputStream({ + mode: 'form', + message: 'Enter data', + requestedSchema: { type: 'object', properties: {} } + }); + }).toThrow('Client does not support form elicitation.'); + + await noFormClient.close().catch(() => {}); + }); + + test('should throw when client does not support url elicitation', async () => { + // Create client without url elicitation capability + const noUrlClient = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + elicitation: { + form: {} + } + } + } + ); + + const [noUrlClientTransport, noUrlServerTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([noUrlClient.connect(noUrlClientTransport), server.connect(noUrlServerTransport)]); + + expect(() => { + server.experimental.tasks.elicitInputStream({ + mode: 'url', + message: 'Open URL', + elicitationId: 'test-123', + url: 'https://example.com/auth' + }); + }).toThrow('Client does not support url elicitation.'); + + await noUrlClient.close().catch(() => {}); + }); + + test('should default to form mode when mode is not specified', async () => { + const requestStreamSpy = vi.spyOn(server.experimental.tasks, 'requestStream'); + + client.setRequestHandler(ElicitRequestSchema, () => ({ + action: 'accept', + content: { value: 'test' } + })); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Call without explicit mode + const params = { + message: 'Enter value', + requestedSchema: { + type: 'object' as const, + properties: { value: { type: 'string' as const } } + } + }; + + const stream = server.experimental.tasks.elicitInputStream( + params as Parameters[0] + ); + await toArrayAsync(stream); + + // Verify mode was normalized to 'form' + expect(requestStreamSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'elicitation/create', + params: expect.objectContaining({ mode: 'form' }) + }), + ElicitResultSchema, + undefined + ); + }); + + test('should yield error as terminal message when client returns error', async () => { + client.setRequestHandler(ElicitRequestSchema, () => { + throw new Error('Simulated client error'); + }); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.elicitInputStream({ + mode: 'form', + message: 'Enter data', + requestedSchema: { + type: 'object', + properties: { value: { type: 'string' } } + } + }); + + const allMessages = await toArrayAsync(stream); + + expect(allMessages.length).toBe(1); + expect(allMessages[0].type).toBe('error'); + }); + + // For any streaming elicitation request, the AsyncGenerator yields exactly one terminal + // message (either 'result' or 'error') as its final message. + describe('terminal message guarantees', () => { + test.each([ + { action: 'accept' as const, content: { data: 'test-value' } }, + { action: 'decline' as const, content: undefined }, + { action: 'cancel' as const, content: undefined } + ])('should yield exactly one terminal message for action: $action', async ({ action, content }) => { + client.setRequestHandler(ElicitRequestSchema, () => ({ + action, + content + })); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const stream = server.experimental.tasks.elicitInputStream({ + mode: 'form', + message: 'Test message', + requestedSchema: { + type: 'object', + properties: { data: { type: 'string' } } + } + }); + + const messages = await toArrayAsync(stream); + + // Count terminal messages (result or error) + const terminalMessages = messages.filter(m => m.type === 'result' || m.type === 'error'); + + expect(terminalMessages.length).toBe(1); + + // Verify terminal message is the last message + const lastMessage = messages[messages.length - 1]; + expect(lastMessage.type === 'result' || lastMessage.type === 'error').toBe(true); + + // Verify result content matches expected action + if (lastMessage.type === 'result') { + expect((lastMessage.result as ElicitResult).action).toBe(action); + } + }); + }); + + // For any non-task elicitation request, the generator yields exactly one 'result' message + // (or 'error' if the request fails), with no 'taskCreated' or 'taskStatus' messages. + describe('non-task request minimality', () => { + test.each([ + { action: 'accept' as const, content: { value: 'test' } }, + { action: 'decline' as const, content: undefined }, + { action: 'cancel' as const, content: undefined } + ])('should yield only result message for non-task request with action: $action', async ({ action, content }) => { + client.setRequestHandler(ElicitRequestSchema, () => ({ + action, + content + })); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Non-task request (no task option) + const stream = server.experimental.tasks.elicitInputStream({ + mode: 'form', + message: 'Non-task request', + requestedSchema: { + type: 'object', + properties: { value: { type: 'string' } } + } + }); + + const messages = await toArrayAsync(stream); + + // Verify no taskCreated or taskStatus messages + const taskMessages = messages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); + expect(taskMessages.length).toBe(0); + + // Verify exactly one result message + const resultMessages = messages.filter(m => m.type === 'result'); + expect(resultMessages.length).toBe(1); + + // Verify total message count is 1 + expect(messages.length).toBe(1); + }); + }); + + // For any task-augmented elicitation request, the generator should yield at least one + // 'taskCreated' message followed by 'taskStatus' messages before yielding the final + // result or error. + describe('task-augmented request handling', () => { + test('should yield taskCreated and result for task-augmented request', async () => { + const clientTaskStore = new InMemoryTaskStore(); + const taskClient = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + elicitation: { form: {} }, + tasks: { + requests: { + elicitation: { create: {} } + } + } + }, + taskStore: clientTaskStore + } + ); + + taskClient.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept' as const, + content: { username: 'task-user' } + }; + + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ ttl: extra.taskRequestedTtl }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + return { task }; + } + return result; + }); + + const [taskClientTransport, taskServerTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([taskClient.connect(taskClientTransport), server.connect(taskServerTransport)]); + + const stream = server.experimental.tasks.elicitInputStream( + { + mode: 'form', + message: 'Task-augmented request', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } }, + required: ['username'] + } + }, + { task: { ttl: 60000 } } + ); + + const messages = await toArrayAsync(stream); + + // Should have taskCreated and result + expect(messages.length).toBeGreaterThanOrEqual(2); + + // First message should be taskCreated + expect(messages[0].type).toBe('taskCreated'); + const taskCreated = messages[0] as { type: 'taskCreated'; task: Task }; + expect(taskCreated.task.taskId).toBeDefined(); + + // Last message should be result + const lastMessage = messages[messages.length - 1]; + expect(lastMessage.type).toBe('result'); + if (lastMessage.type === 'result') { + expect((lastMessage.result as ElicitResult).action).toBe('accept'); + expect((lastMessage.result as ElicitResult).content).toEqual({ username: 'task-user' }); + } + + clientTaskStore.cleanup(); + await taskClient.close().catch(() => {}); + }); + }); +});