diff --git a/docs/experimental/index.md b/docs/experimental/index.md new file mode 100644 index 0000000000..1d496b3f10 --- /dev/null +++ b/docs/experimental/index.md @@ -0,0 +1,43 @@ +# Experimental Features + +!!! warning "Experimental APIs" + + The features in this section are experimental and may change without notice. + They track the evolving MCP specification and are not yet stable. + +This section documents experimental features in the MCP Python SDK. These features +implement draft specifications that are still being refined. + +## Available Experimental Features + +### [Tasks](tasks.md) + +Tasks enable asynchronous execution of MCP operations. Instead of waiting for a +long-running operation to complete, the server returns a task reference immediately. +Clients can then poll for status updates and retrieve results when ready. + +Tasks are useful for: + +- **Long-running computations** that would otherwise block +- **Batch operations** that process many items +- **Interactive workflows** that require user input (elicitation) or LLM assistance (sampling) + +## Using Experimental APIs + +Experimental features are accessed via the `.experimental` property: + +```python +# Server-side +@server.experimental.get_task() +async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + ... + +# Client-side +result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) +``` + +## Providing Feedback + +Since these features are experimental, feedback is especially valuable. If you encounter +issues or have suggestions, please open an issue on the +[python-sdk repository](https://github.com/modelcontextprotocol/python-sdk/issues). diff --git a/docs/experimental/tasks-client.md b/docs/experimental/tasks-client.md new file mode 100644 index 0000000000..6883961fea --- /dev/null +++ b/docs/experimental/tasks-client.md @@ -0,0 +1,287 @@ +# Client Task Usage + +!!! warning "Experimental" + + Tasks are an experimental feature. The API may change without notice. + +This guide shows how to call task-augmented tools from an MCP client and retrieve +their results. + +## Prerequisites + +You'll need: + +- An MCP client session connected to a server that supports tasks +- The `ClientSession` from `mcp.client.session` + +## Step 1: Call a Tool as a Task + +Use the `experimental.call_tool_as_task()` method to call a tool with task +augmentation: + +```python +from mcp.client.session import ClientSession + +async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call the tool as a task + result = await session.experimental.call_tool_as_task( + "process_data", + {"input": "hello world"}, + ttl=60000, # Keep result for 60 seconds + ) + + # Get the task ID for polling + task_id = result.task.taskId + print(f"Task created: {task_id}") + print(f"Initial status: {result.task.status}") +``` + +The method returns a `CreateTaskResult` containing: + +- `task.taskId` - Unique identifier for polling +- `task.status` - Initial status (usually "working") +- `task.pollInterval` - Suggested polling interval in milliseconds +- `task.ttl` - Time-to-live for the task result + +## Step 2: Poll for Status + +Check the task status periodically until it completes: + +```python +import anyio + +while True: + status = await session.experimental.get_task(task_id) + print(f"Status: {status.status}") + + if status.statusMessage: + print(f"Message: {status.statusMessage}") + + if status.status in ("completed", "failed", "cancelled"): + break + + # Respect the suggested poll interval + poll_interval = status.pollInterval or 500 + await anyio.sleep(poll_interval / 1000) # Convert ms to seconds +``` + +The `GetTaskResult` contains: + +- `taskId` - The task identifier +- `status` - Current status: "working", "completed", "failed", "cancelled", or "input_required" +- `statusMessage` - Optional progress message +- `pollInterval` - Suggested interval before next poll (milliseconds) + +## Step 3: Retrieve the Result + +Once the task is complete, retrieve the actual result: + +```python +from mcp.types import CallToolResult + +if status.status == "completed": + # Get the actual tool result + final_result = await session.experimental.get_task_result( + task_id, + CallToolResult, # The expected result type + ) + + # Process the result + for content in final_result.content: + if hasattr(content, "text"): + print(f"Result: {content.text}") + +elif status.status == "failed": + print(f"Task failed: {status.statusMessage}") +``` + +The result type depends on the original request: + +- `tools/call` tasks return `CallToolResult` +- Other request types return their corresponding result type + +## Complete Polling Example + +Here's a complete client that calls a task and waits for the result: + +```python +import anyio + +from mcp.client.session import ClientSession +from mcp.client.stdio import stdio_client +from mcp.types import CallToolResult + + +async def main(): + async with stdio_client( + command="python", + args=["server.py"], + ) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # 1. Create the task + print("Creating task...") + result = await session.experimental.call_tool_as_task( + "slow_echo", + {"message": "Hello, Tasks!", "delay_seconds": 3}, + ) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # 2. Poll until complete + print("Polling for completion...") + while True: + status = await session.experimental.get_task(task_id) + print(f" Status: {status.status}", end="") + if status.statusMessage: + print(f" - {status.statusMessage}", end="") + print() + + if status.status in ("completed", "failed", "cancelled"): + break + + await anyio.sleep((status.pollInterval or 500) / 1000) + + # 3. Get the result + if status.status == "completed": + print("Retrieving result...") + final = await session.experimental.get_task_result( + task_id, + CallToolResult, + ) + for content in final.content: + if hasattr(content, "text"): + print(f"Result: {content.text}") + else: + print(f"Task ended with status: {status.status}") + + +if __name__ == "__main__": + anyio.run(main) +``` + +## Cancelling Tasks + +If you need to cancel a running task: + +```python +cancel_result = await session.experimental.cancel_task(task_id) +print(f"Task cancelled, final status: {cancel_result.status}") +``` + +Note that cancellation is cooperative - the server must check for and handle +cancellation requests. A cancelled task will transition to the "cancelled" state. + +## Listing Tasks + +To see all tasks on a server: + +```python +# Get the first page of tasks +tasks_result = await session.experimental.list_tasks() + +for task in tasks_result.tasks: + print(f"Task {task.taskId}: {task.status}") + +# Handle pagination if needed +while tasks_result.nextCursor: + tasks_result = await session.experimental.list_tasks( + cursor=tasks_result.nextCursor + ) + for task in tasks_result.tasks: + print(f"Task {task.taskId}: {task.status}") +``` + +## Low-Level API + +If you need more control, you can use the low-level request API directly: + +```python +from mcp.types import ( + ClientRequest, + CallToolRequest, + CallToolRequestParams, + TaskMetadata, + CreateTaskResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, +) + +# Create task with full control over the request +result = await session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "data"}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, +) + +# Poll status +status = await session.send_request( + ClientRequest( + GetTaskRequest( + params=GetTaskRequestParams(taskId=result.task.taskId), + ) + ), + GetTaskResult, +) + +# Get result +final = await session.send_request( + ClientRequest( + GetTaskPayloadRequest( + params=GetTaskPayloadRequestParams(taskId=result.task.taskId), + ) + ), + CallToolResult, +) +``` + +## Error Handling + +Tasks can fail for various reasons. Handle errors appropriately: + +```python +try: + result = await session.experimental.call_tool_as_task("my_tool", args) + task_id = result.task.taskId + + while True: + status = await session.experimental.get_task(task_id) + + if status.status == "completed": + final = await session.experimental.get_task_result( + task_id, CallToolResult + ) + # Process success... + break + + elif status.status == "failed": + print(f"Task failed: {status.statusMessage}") + break + + elif status.status == "cancelled": + print("Task was cancelled") + break + + await anyio.sleep(0.5) + +except Exception as e: + print(f"Error: {e}") +``` + +## Next Steps + +- [Server Implementation](tasks-server.md) - Learn how to build task-supporting servers +- [Tasks Overview](tasks.md) - Review the task lifecycle and concepts diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md new file mode 100644 index 0000000000..d4879fcb5f --- /dev/null +++ b/docs/experimental/tasks-server.md @@ -0,0 +1,441 @@ +# Server Task Implementation + +!!! warning "Experimental" + + Tasks are an experimental feature. The API may change without notice. + +This guide shows how to add task support to an MCP server, starting with the +simplest case and building up to more advanced patterns. + +## Prerequisites + +You'll need: + +- A low-level MCP server +- A task store for state management +- A task group for spawning background work + +## Step 1: Basic Setup + +First, set up the task store and server. The `InMemoryTaskStore` is suitable +for development and testing: + +```python +from dataclasses import dataclass +from anyio.abc import TaskGroup + +from mcp.server import Server +from mcp.shared.experimental.tasks import InMemoryTaskStore + + +@dataclass +class AppContext: + """Application context available during request handling.""" + task_group: TaskGroup + store: InMemoryTaskStore + + +server: Server[AppContext, None] = Server("my-task-server") +store = InMemoryTaskStore() +``` + +## Step 2: Declare Task-Supporting Tools + +Tools that support tasks should declare this in their execution metadata: + +```python +from mcp.types import Tool, ToolExecution, TASK_REQUIRED, TASK_OPTIONAL + +@server.list_tools() +async def list_tools(): + return [ + Tool( + name="process_data", + description="Process data asynchronously", + inputSchema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + # TASK_REQUIRED means this tool MUST be called as a task + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + ] +``` + +The `taskSupport` field can be: + +- `TASK_REQUIRED` ("required") - Tool must be called as a task +- `TASK_OPTIONAL` ("optional") - Tool supports both sync and task execution +- `TASK_FORBIDDEN` ("forbidden") - Tool cannot be called as a task (default) + +## Step 3: Handle Tool Calls + +When a client calls a tool as a task, the request context contains task metadata. +Check for this and create a task: + +```python +from mcp.shared.experimental.tasks import task_execution +from mcp.types import ( + CallToolResult, + CreateTaskResult, + TextContent, +) + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "process_data" and ctx.experimental.is_task: + # Get task metadata from the request + task_metadata = ctx.experimental.task_metadata + + # Create the task in our store + task = await app.store.create_task(task_metadata) + + # Define the work to do in the background + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + # Update status to show progress + await task_ctx.update_status("Processing input...", notify=False) + + # Do the actual work + input_value = arguments.get("input", "") + result_text = f"Processed: {input_value.upper()}" + + # Complete the task with the result + await task_ctx.complete( + CallToolResult( + content=[TextContent(type="text", text=result_text)] + ), + notify=False, + ) + + # Spawn work in the background task group + app.task_group.start_soon(do_work) + + # Return immediately with the task reference + return CreateTaskResult(task=task) + + # Non-task execution path + return [TextContent(type="text", text="Use task mode for this tool")] +``` + +Key points: + +- `ctx.experimental.is_task` checks if this is a task-augmented request +- `ctx.experimental.task_metadata` contains the task configuration +- `task_execution` is a context manager that handles errors gracefully +- Work runs in a separate coroutine via the task group +- The handler returns `CreateTaskResult` immediately + +## Step 4: Register Task Handlers + +Clients need endpoints to query task status and retrieve results. Register these +using the experimental decorators: + +```python +from mcp.types import ( + GetTaskRequest, + GetTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + ListTasksRequest, + ListTasksResult, +) + + +@server.experimental.get_task() +async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + """Handle tasks/get requests - return current task status.""" + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + +@server.experimental.get_task_result() +async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + """Handle tasks/result requests - return the completed task's result.""" + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + + # Return the stored result + assert isinstance(result, CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + +@server.experimental.list_tasks() +async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + """Handle tasks/list requests - return all tasks with pagination.""" + app = server.request_context.lifespan_context + cursor = request.params.cursor if request.params else None + tasks, next_cursor = await app.store.list_tasks(cursor=cursor) + + return ListTasksResult(tasks=tasks, nextCursor=next_cursor) +``` + +## Step 5: Run the Server + +Wire everything together with a task group for background work: + +```python +import anyio +from mcp.server.stdio import stdio_server + + +async def main(): + async with anyio.create_task_group() as tg: + app = AppContext(task_group=tg, store=store) + + async with stdio_server() as (read, write): + await server.run( + read, + write, + server.create_initialization_options(), + lifespan_context=app, + ) + + +if __name__ == "__main__": + anyio.run(main) +``` + +## The task_execution Context Manager + +The `task_execution` helper provides safe task execution: + +```python +async with task_execution(task_id, store) as ctx: + await ctx.update_status("Working...") + result = await do_work() + await ctx.complete(result) +``` + +If an exception occurs inside the context, the task is automatically marked +as failed with the exception message. This prevents tasks from getting stuck +in the "working" state. + +The context provides: + +- `ctx.task_id` - The task identifier +- `ctx.task` - Current task state +- `ctx.is_cancelled` - Check if cancellation was requested +- `ctx.update_status(msg)` - Update the status message +- `ctx.complete(result)` - Mark task as completed +- `ctx.fail(error)` - Mark task as failed + +## Handling Cancellation + +To support task cancellation, register a cancel handler and check for +cancellation in your work: + +```python +from mcp.types import CancelTaskRequest, CancelTaskResult + +# Track running tasks so we can cancel them +running_tasks: dict[str, TaskContext] = {} + + +@server.experimental.cancel_task() +async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + task_id = request.params.taskId + app = server.request_context.lifespan_context + + # Signal cancellation to the running work + if task_id in running_tasks: + running_tasks[task_id].request_cancellation() + + # Update task status + task = await app.store.update_task(task_id, status="cancelled") + + return CancelTaskResult( + taskId=task.taskId, + status=task.status, + ) +``` + +Then check for cancellation in your work: + +```python +async def do_work(): + async with task_execution(task.taskId, app.store) as ctx: + running_tasks[task.taskId] = ctx + try: + for i in range(100): + if ctx.is_cancelled: + return # Exit gracefully + + await ctx.update_status(f"Processing step {i}/100") + await process_step(i) + + await ctx.complete(result) + finally: + running_tasks.pop(task.taskId, None) +``` + +## Complete Example + +Here's a full working server with task support: + +```python +from dataclasses import dataclass +from typing import Any + +import anyio +from anyio.abc import TaskGroup + +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + TextContent, + Tool, + ToolExecution, +) + + +@dataclass +class AppContext: + task_group: TaskGroup + store: InMemoryTaskStore + + +server: Server[AppContext, Any] = Server("task-example") +store = InMemoryTaskStore() + + +@server.list_tools() +async def list_tools(): + return [ + Tool( + name="slow_echo", + description="Echo input after a delay (demonstrates tasks)", + inputSchema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + "delay_seconds": {"type": "number", "default": 2}, + }, + "required": ["message"], + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + ] + + +@server.call_tool() +async def handle_call_tool( + name: str, arguments: dict[str, Any] +) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "slow_echo" and ctx.experimental.is_task: + task = await app.store.create_task(ctx.experimental.task_metadata) + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + message = arguments.get("message", "") + delay = arguments.get("delay_seconds", 2) + + await task_ctx.update_status("Starting...", notify=False) + await anyio.sleep(delay / 2) + + await task_ctx.update_status("Almost done...", notify=False) + await anyio.sleep(delay / 2) + + await task_ctx.complete( + CallToolResult( + content=[TextContent(type="text", text=f"Echo: {message}")] + ), + notify=False, + ) + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="This tool requires task mode")] + + +@server.experimental.get_task() +async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task not found: {request.params.taskId}") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + +@server.experimental.get_task_result() +async def handle_get_task_result( + request: GetTaskPayloadRequest, +) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result not found: {request.params.taskId}") + assert isinstance(result, CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + +@server.experimental.list_tasks() +async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + app = server.request_context.lifespan_context + cursor = request.params.cursor if request.params else None + tasks, next_cursor = await app.store.list_tasks(cursor=cursor) + return ListTasksResult(tasks=tasks, nextCursor=next_cursor) + + +async def main(): + async with anyio.create_task_group() as tg: + app = AppContext(task_group=tg, store=store) + async with stdio_server() as (read, write): + await server.run( + read, + write, + server.create_initialization_options(), + lifespan_context=app, + ) + + +if __name__ == "__main__": + anyio.run(main) +``` + +## Next Steps + +- [Client Usage](tasks-client.md) - Learn how to call tasks from a client +- [Tasks Overview](tasks.md) - Review the task lifecycle and concepts diff --git a/docs/experimental/tasks.md b/docs/experimental/tasks.md new file mode 100644 index 0000000000..1fc1710002 --- /dev/null +++ b/docs/experimental/tasks.md @@ -0,0 +1,122 @@ +# Tasks + +!!! warning "Experimental" + + Tasks are an experimental feature tracking the draft MCP specification. + The API may change without notice. + +Tasks allow MCP servers to handle requests asynchronously. When a client sends a +task-augmented request, the server can start working in the background and return +a task reference immediately. The client then polls for updates and retrieves the +result when complete. + +## When to Use Tasks + +Tasks are useful when operations: + +- Take significant time to complete (seconds to minutes) +- May require intermediate status updates +- Need to run in the background without blocking the client + +## Task Lifecycle + +A task progresses through these states: + +```text +working → completed + → failed + → cancelled + +working → input_required → working → completed/failed/cancelled +``` + +| State | Description | +|-------|-------------| +| `working` | The task is being processed | +| `input_required` | The server needs additional information | +| `completed` | The task finished successfully | +| `failed` | The task encountered an error | +| `cancelled` | The task was cancelled | + +Once a task reaches `completed`, `failed`, or `cancelled`, it cannot transition +to any other state. + +## Basic Flow + +Here's the typical interaction pattern: + +1. **Client** sends a tool call with task metadata +2. **Server** creates a task, spawns background work, returns `CreateTaskResult` +3. **Client** receives the task ID and starts polling +4. **Server** executes the work, updating status as needed +5. **Client** polls with `tasks/get` to check status +6. **Server** finishes work and stores the result +7. **Client** retrieves result with `tasks/result` + +```text +Client Server + │ │ + │──── tools/call (with task) ─────────>│ + │ │ create task + │<──── CreateTaskResult ──────────────│ spawn work + │ │ + │──── tasks/get ──────────────────────>│ + │<──── status: working ───────────────│ + │ │ ... work continues ... + │──── tasks/get ──────────────────────>│ + │<──── status: completed ─────────────│ + │ │ + │──── tasks/result ───────────────────>│ + │<──── CallToolResult ────────────────│ + │ │ +``` + +## Key Concepts + +### Task Metadata + +When a client wants a request handled as a task, it includes `TaskMetadata` in +the request: + +```python +task = TaskMetadata(ttl=60000) # TTL in milliseconds +``` + +The `ttl` (time-to-live) specifies how long the task and its result should be +retained after completion. + +### Task Store + +Servers need to persist task state somewhere. The SDK provides an abstract +`TaskStore` interface and an `InMemoryTaskStore` for development: + +```python +from mcp.shared.experimental.tasks import InMemoryTaskStore + +store = InMemoryTaskStore() +``` + +The store tracks: + +- Task state (status, messages, timestamps) +- Results for completed tasks +- Automatic cleanup based on TTL + +For production, you'd implement `TaskStore` with a database or distributed cache. + +### Capabilities + +Task support is advertised through server capabilities. The SDK automatically +updates capabilities when you register task handlers: + +```python +# This registers the handler AND advertises the capability +@server.experimental.get_task() +async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + ... +``` + +## Next Steps + +- [Server Implementation](tasks-server.md) - How to add task support to your server +- [Client Usage](tasks-client.md) - How to call and poll tasks from a client diff --git a/examples/clients/simple-task-client/README.md b/examples/clients/simple-task-client/README.md new file mode 100644 index 0000000000..103be0f1fb --- /dev/null +++ b/examples/clients/simple-task-client/README.md @@ -0,0 +1,43 @@ +# Simple Task Client + +A minimal MCP client demonstrating polling for task results over streamable HTTP. + +## Running + +First, start the simple-task server in another terminal: + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +Then run the client: + +```bash +cd examples/clients/simple-task-client +uv run mcp-simple-task-client +``` + +Use `--url` to connect to a different server. + +## What it does + +1. Connects to the server via streamable HTTP +2. Calls the `long_running_task` tool as a task +3. Polls the task status until completion +4. Retrieves and prints the result + +## Expected output + +```text +Available tools: ['long_running_task'] + +Calling tool as a task... +Task created: + Status: working - Starting work... + Status: working - Processing step 1... + Status: working - Processing step 2... + Status: completed - + +Result: Task completed! +``` diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py new file mode 100644 index 0000000000..2fc2cda8d9 --- /dev/null +++ b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .main import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py new file mode 100644 index 0000000000..9a38cfe87c --- /dev/null +++ b/examples/clients/simple-task-client/mcp_simple_task_client/main.py @@ -0,0 +1,75 @@ +"""Simple task client demonstrating MCP tasks polling over streamable HTTP.""" + +import asyncio + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + CreateTaskResult, + TaskMetadata, + TextContent, +) + + +async def run(url: str) -> None: + async with streamablehttp_client(url) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() + + # List tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Call the tool as a task + print("\nCalling tool as a task...") + + # TODO: make helper for this + result = await session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="long_running_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # Poll until done + while True: + status = await session.experimental.get_task(task_id) + print(f" Status: {status.status} - {status.statusMessage or ''}") + + if status.status == "completed": + break + elif status.status in ("failed", "cancelled"): + print(f"Task ended with status: {status.status}") + return + + await asyncio.sleep(0.5) + + # Get the result + task_result = await session.experimental.get_task_result(task_id, CallToolResult) + content = task_result.content[0] + if isinstance(content, TextContent): + print(f"\nResult: {content.text}") + + +@click.command() +@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") +def main(url: str) -> int: + asyncio.run(run(url)) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/clients/simple-task-client/pyproject.toml b/examples/clients/simple-task-client/pyproject.toml new file mode 100644 index 0000000000..da10392e3c --- /dev/null +++ b/examples/clients/simple-task-client/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-client" +version = "0.1.0" +description = "A simple MCP client demonstrating task polling" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "client"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.0", "mcp"] + +[project.scripts] +mcp-simple-task-client = "mcp_simple_task_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_client"] + +[tool.pyright] +include = ["mcp_simple_task_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/clients/simple-task-interactive-client/README.md b/examples/clients/simple-task-interactive-client/README.md new file mode 100644 index 0000000000..15ec771670 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/README.md @@ -0,0 +1,87 @@ +# Simple Interactive Task Client + +A minimal MCP client demonstrating responses to interactive tasks (elicitation and sampling). + +## Running + +First, start the interactive task server in another terminal: + +```bash +cd examples/servers/simple-task-interactive +uv run mcp-simple-task-interactive +``` + +Then run the client: + +```bash +cd examples/clients/simple-task-interactive-client +uv run mcp-simple-task-interactive-client +``` + +Use `--url` to connect to a different server. + +## What it does + +1. Connects to the server via streamable HTTP +2. Calls `confirm_delete` - server asks for confirmation, client responds via terminal +3. Calls `write_haiku` - server requests LLM completion, client returns a hardcoded haiku + +## Key concepts + +### Elicitation callback + +```python +async def elicitation_callback(context, params) -> ElicitResult: + # Handle user input request from server + return ElicitResult(action="accept", content={"confirm": True}) +``` + +### Sampling callback + +```python +async def sampling_callback(context, params) -> CreateMessageResult: + # Handle LLM completion request from server + return CreateMessageResult(model="...", role="assistant", content=...) +``` + +### Using call_tool_as_task + +```python +# Call a tool as a task (returns immediately with task reference) +result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) +task_id = result.taskSupport.taskId + +# Get result - this delivers elicitation/sampling requests and blocks until complete +final = await session.experimental.get_task_result(task_id, CallToolResult) +``` + +**Important**: The `get_task_result()` call is what triggers the delivery of elicitation +and sampling requests to your callbacks. It blocks until the task completes and returns +the final result. + +## Expected output + +```text +Available tools: ['confirm_delete', 'write_haiku'] + +--- Demo 1: Elicitation --- +Calling confirm_delete tool... +Task created: + +[Elicitation] Server asks: Are you sure you want to delete 'important.txt'? +Your response (y/n): y +[Elicitation] Responding with: confirm=True +Result: Deleted 'important.txt' + +--- Demo 2: Sampling --- +Calling write_haiku tool... +Task created: + +[Sampling] Server requests LLM completion for: Write a haiku about autumn leaves +[Sampling] Responding with haiku +Result: +Haiku: +Cherry blossoms fall +Softly on the quiet pond +Spring whispers goodbye +``` diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py new file mode 100644 index 0000000000..2fc2cda8d9 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .main import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py new file mode 100644 index 0000000000..e42d139fb3 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -0,0 +1,116 @@ +"""Simple interactive task client demonstrating elicitation and sampling responses.""" + +import asyncio +from typing import Any + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.context import RequestContext +from mcp.types import ( + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestParams, + ElicitResult, + TextContent, +) + + +async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, +) -> ElicitResult: + """Handle elicitation requests from the server.""" + print(f"\n[Elicitation] Server asks: {params.message}") + + # Simple terminal prompt + response = input("Your response (y/n): ").strip().lower() + confirmed = response in ("y", "yes", "true", "1") + + print(f"[Elicitation] Responding with: confirm={confirmed}") + return ElicitResult(action="accept", content={"confirm": confirmed}) + + +async def sampling_callback( + context: RequestContext[ClientSession, Any], + params: CreateMessageRequestParams, +) -> CreateMessageResult: + """Handle sampling requests from the server.""" + # Get the prompt from the first message + prompt = "unknown" + if params.messages: + content = params.messages[0].content + if isinstance(content, TextContent): + prompt = content.text + + print(f"\n[Sampling] Server requests LLM completion for: {prompt}") + + # Return a hardcoded haiku (in real use, call your LLM here) + haiku = """Cherry blossoms fall +Softly on the quiet pond +Spring whispers goodbye""" + + print("[Sampling] Responding with haiku") + return CreateMessageResult( + model="mock-haiku-model", + role="assistant", + content=TextContent(type="text", text=haiku), + ) + + +def get_text(result: CallToolResult) -> str: + """Extract text from a CallToolResult.""" + if result.content and isinstance(result.content[0], TextContent): + return result.content[0].text + return "(no text)" + + +async def run(url: str) -> None: + async with streamablehttp_client(url) as (read, write, _): + async with ClientSession( + read, + write, + elicitation_callback=elicitation_callback, + sampling_callback=sampling_callback, + ) as session: + await session.initialize() + + # List tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Demo 1: Elicitation (confirm_delete) + print("\n--- Demo 1: Elicitation ---") + print("Calling confirm_delete tool...") + + result = await session.experimental.call_tool_as_task("confirm_delete", {"filename": "important.txt"}) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # get_task_result() delivers elicitation requests and blocks until complete + final = await session.experimental.get_task_result(task_id, CallToolResult) + print(f"Result: {get_text(final)}") + + # Demo 2: Sampling (write_haiku) + print("\n--- Demo 2: Sampling ---") + print("Calling write_haiku tool...") + + result = await session.experimental.call_tool_as_task("write_haiku", {"topic": "autumn leaves"}) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # get_task_result() delivers sampling requests and blocks until complete + final = await session.experimental.get_task_result(task_id, CallToolResult) + print(f"Result:\n{get_text(final)}") + + +@click.command() +@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") +def main(url: str) -> int: + asyncio.run(run(url)) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/clients/simple-task-interactive-client/pyproject.toml b/examples/clients/simple-task-interactive-client/pyproject.toml new file mode 100644 index 0000000000..224bbc5917 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-interactive-client" +version = "0.1.0" +description = "A simple MCP client demonstrating interactive task responses" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "client", "elicitation", "sampling"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.0", "mcp"] + +[project.scripts] +mcp-simple-task-interactive-client = "mcp_simple_task_interactive_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_interactive_client"] + +[tool.pyright] +include = ["mcp_simple_task_interactive_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task-interactive/README.md b/examples/servers/simple-task-interactive/README.md new file mode 100644 index 0000000000..57bdb2c228 --- /dev/null +++ b/examples/servers/simple-task-interactive/README.md @@ -0,0 +1,74 @@ +# Simple Interactive Task Server + +A minimal MCP server demonstrating interactive tasks with elicitation and sampling. + +## Running + +```bash +cd examples/servers/simple-task-interactive +uv run mcp-simple-task-interactive +``` + +The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. + +## What it does + +This server exposes two tools: + +### `confirm_delete` (demonstrates elicitation) + +Asks the user for confirmation before "deleting" a file. + +- Uses `TaskSession.elicit()` to request user input +- Shows the elicitation flow: task -> input_required -> response -> complete + +### `write_haiku` (demonstrates sampling) + +Asks the LLM to write a haiku about a topic. + +- Uses `TaskSession.create_message()` to request LLM completion +- Shows the sampling flow: task -> input_required -> response -> complete + +## Usage with the client + +In one terminal, start the server: + +```bash +cd examples/servers/simple-task-interactive +uv run mcp-simple-task-interactive +``` + +In another terminal, run the interactive client: + +```bash +cd examples/clients/simple-task-interactive-client +uv run mcp-simple-task-interactive-client +``` + +## Expected server output + +When a client connects and calls the tools, you'll see: + +```text +Starting server on http://localhost:8000/mcp + +[Server] confirm_delete called for 'important.txt' +[Server] Task created: +[Server] Sending elicitation request to client... +[Server] Received elicitation response: action=accept, content={'confirm': True} +[Server] Completing task with result: Deleted 'important.txt' + +[Server] write_haiku called for topic 'autumn leaves' +[Server] Task created: +[Server] Sending sampling request to client... +[Server] Received sampling response: Cherry blossoms fall +Softly on the quiet pon... +[Server] Completing task with haiku +``` + +## Key concepts + +1. **TaskSession**: Wraps ServerSession to enqueue elicitation/sampling requests +2. **TaskResultHandler**: Delivers queued messages and routes responses +3. **task_execution()**: Context manager for safe task execution with auto-fail +4. **Response routing**: Responses are routed back to waiting resolvers diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py new file mode 100644 index 0000000000..e7ef16530b --- /dev/null +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py new file mode 100644 index 0000000000..419d51b556 --- /dev/null +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -0,0 +1,231 @@ +"""Simple interactive task server demonstrating elicitation and sampling.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any + +import anyio +import click +import mcp.types as types +import uvicorn +from anyio.abc import TaskGroup +from mcp.server.lowlevel import Server +from mcp.server.session import ServerSession +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.experimental.tasks import ( + InMemoryTaskMessageQueue, + InMemoryTaskStore, + TaskResultHandler, + TaskSession, + task_execution, +) +from starlette.applications import Starlette +from starlette.routing import Mount + + +@dataclass +class AppContext: + task_group: TaskGroup + store: InMemoryTaskStore + queue: InMemoryTaskMessageQueue + handler: TaskResultHandler + # Track sessions that have been configured (session ID -> bool) + configured_sessions: dict[int, bool] + + +@asynccontextmanager +async def lifespan(server: Server[AppContext, Any]) -> AsyncIterator[AppContext]: + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + async with anyio.create_task_group() as tg: + yield AppContext( + task_group=tg, + store=store, + queue=queue, + handler=handler, + configured_sessions={}, + ) + store.cleanup() + queue.cleanup() + + +server: Server[AppContext, Any] = Server("simple-task-interactive", lifespan=lifespan) + + +def ensure_handler_configured(session: ServerSession, app: AppContext) -> None: + """Ensure the task result handler is configured for this session (once).""" + session_id = id(session) + if session_id not in app.configured_sessions: + session.add_response_router(app.handler) + app.configured_sessions[session_id] = True + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="confirm_delete", + description="Asks for confirmation before deleting (demonstrates elicitation)", + inputSchema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + inputSchema={"type": "object", "properties": {"topic": {"type": "string"}}}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + ] + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[types.TextContent] | types.CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + # Validate task mode + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + + # Ensure handler is configured for response routing + ensure_handler_configured(ctx.session, app) + + # Create task + metadata = ctx.experimental.task_metadata + assert metadata is not None + task = await app.store.create_task(metadata) + + if name == "confirm_delete": + filename = arguments.get("filename", "unknown.txt") + print(f"\n[Server] confirm_delete called for '{filename}'") + print(f"[Server] Task created: {task.taskId}") + + async def do_confirm() -> None: + async with task_execution(task.taskId, app.store) as task_ctx: + task_session = TaskSession( + session=ctx.session, + task_id=task.taskId, + store=app.store, + queue=app.queue, + ) + + print("[Server] Sending elicitation request to client...") + result = await task_session.elicit( + message=f"Are you sure you want to delete '{filename}'?", + requestedSchema={ + "type": "object", + "properties": {"confirm": {"type": "boolean"}}, + "required": ["confirm"], + }, + ) + + print(f"[Server] Received elicitation response: action={result.action}, content={result.content}") + if result.action == "accept" and result.content: + confirmed = result.content.get("confirm", False) + text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled" + else: + text = "Deletion cancelled" + + print(f"[Server] Completing task with result: {text}") + await task_ctx.complete( + types.CallToolResult(content=[types.TextContent(type="text", text=text)]), + notify=True, + ) + + app.task_group.start_soon(do_confirm) + + elif name == "write_haiku": + topic = arguments.get("topic", "nature") + print(f"\n[Server] write_haiku called for topic '{topic}'") + print(f"[Server] Task created: {task.taskId}") + + async def do_haiku() -> None: + async with task_execution(task.taskId, app.store) as task_ctx: + task_session = TaskSession( + session=ctx.session, + task_id=task.taskId, + store=app.store, + queue=app.queue, + ) + + print("[Server] Sending sampling request to client...") + result = await task_session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text=f"Write a haiku about {topic}"), + ) + ], + max_tokens=50, + ) + + haiku = "No response" + if isinstance(result.content, types.TextContent): + haiku = result.content.text + + print(f"[Server] Received sampling response: {haiku[:50]}...") + print("[Server] Completing task with haiku") + await task_ctx.complete( + types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")]), + notify=True, + ) + + app.task_group.start_soon(do_haiku) + + return types.CreateTaskResult(task=task) + + +@server.experimental.get_task() +async def handle_get_task(request: types.GetTaskRequest) -> types.GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return types.GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + +@server.experimental.get_task_result() +async def handle_get_task_result( + request: types.GetTaskPayloadRequest, +) -> types.GetTaskPayloadResult: + ctx = server.request_context + app = ctx.lifespan_context + + # Ensure handler is configured for this session + ensure_handler_configured(ctx.session, app) + + return await app.handler.handle(request, ctx.session, ctx.request_id) + + +def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + lifespan=app_lifespan, + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +def main(port: int) -> int: + session_manager = StreamableHTTPSessionManager(app=server) + starlette_app = create_app(session_manager) + print(f"Starting server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 diff --git a/examples/servers/simple-task-interactive/pyproject.toml b/examples/servers/simple-task-interactive/pyproject.toml new file mode 100644 index 0000000000..492345ff52 --- /dev/null +++ b/examples/servers/simple-task-interactive/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-interactive" +version = "0.1.0" +description = "A simple MCP server demonstrating interactive tasks (elicitation & sampling)" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "elicitation", "sampling"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-task-interactive = "mcp_simple_task_interactive.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_interactive"] + +[tool.pyright] +include = ["mcp_simple_task_interactive"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task/README.md b/examples/servers/simple-task/README.md new file mode 100644 index 0000000000..6914e0414f --- /dev/null +++ b/examples/servers/simple-task/README.md @@ -0,0 +1,37 @@ +# Simple Task Server + +A minimal MCP server demonstrating the experimental tasks feature over streamable HTTP. + +## Running + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. + +## What it does + +This server exposes a single tool `long_running_task` that: + +1. Must be called as a task (with `task` metadata in the request) +2. Takes ~3 seconds to complete +3. Sends status updates during execution +4. Returns a result when complete + +## Usage with the client + +In one terminal, start the server: + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +In another terminal, run the client: + +```bash +cd examples/clients/simple-task-client +uv run mcp-simple-task-client +``` diff --git a/examples/servers/simple-task/mcp_simple_task/__init__.py b/examples/servers/simple-task/mcp_simple_task/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/servers/simple-task/mcp_simple_task/__main__.py b/examples/servers/simple-task/mcp_simple_task/__main__.py new file mode 100644 index 0000000000..e7ef16530b --- /dev/null +++ b/examples/servers/simple-task/mcp_simple_task/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py new file mode 100644 index 0000000000..0482dc75a6 --- /dev/null +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -0,0 +1,126 @@ +"""Simple task server demonstrating MCP tasks over streamable HTTP.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any + +import anyio +import click +import mcp.types as types +import uvicorn +from anyio.abc import TaskGroup +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from starlette.applications import Starlette +from starlette.routing import Mount + + +@dataclass +class AppContext: + task_group: TaskGroup + store: InMemoryTaskStore + + +@asynccontextmanager +async def lifespan(server: Server[AppContext, Any]) -> AsyncIterator[AppContext]: + store = InMemoryTaskStore() + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store) + store.cleanup() + + +server: Server[AppContext, Any] = Server("simple-task-server", lifespan=lifespan) + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + inputSchema={"type": "object", "properties": {}}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ) + ] + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[types.TextContent] | types.CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + # Validate task mode - raises McpError(-32601) if client didn't use task augmentation + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + + # Create the task + metadata = ctx.experimental.task_metadata + assert metadata is not None + task = await app.store.create_task(metadata) + + # Spawn background work + async def do_work() -> None: + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Starting work...") + await anyio.sleep(1) + + await task_ctx.update_status("Processing step 1...") + await anyio.sleep(1) + + await task_ctx.update_status("Processing step 2...") + await anyio.sleep(1) + + await task_ctx.complete( + types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) + ) + + app.task_group.start_soon(do_work) + return types.CreateTaskResult(task=task) + + +@server.experimental.get_task() +async def handle_get_task(request: types.GetTaskRequest) -> types.GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return types.GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + +@server.experimental.get_task_result() +async def handle_get_task_result(request: types.GetTaskPayloadRequest) -> types.GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + assert isinstance(result, types.CallToolResult) + return types.GetTaskPayloadResult(**result.model_dump()) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +def main(port: int) -> int: + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + starlette_app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + lifespan=app_lifespan, + ) + + print(f"Starting server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 diff --git a/examples/servers/simple-task/pyproject.toml b/examples/servers/simple-task/pyproject.toml new file mode 100644 index 0000000000..a8fba8bdc1 --- /dev/null +++ b/examples/servers/simple-task/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task" +version = "0.1.0" +description = "A simple MCP server demonstrating tasks" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-task = "mcp_simple_task.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task"] + +[tool.pyright] +include = ["mcp_simple_task"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/mkdocs.yml b/mkdocs.yml index 18cbb034bb..22c323d9d4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -18,6 +18,12 @@ nav: - Low-Level Server: low-level-server.md - Authorization: authorization.md - Testing: testing.md + - Experimental: + - Overview: experimental/index.md + - Tasks: + - Introduction: experimental/tasks.md + - Server Implementation: experimental/tasks-server.md + - Client Usage: experimental/tasks-client.md - API Reference: api.md theme: diff --git a/src/mcp/client/experimental/__init__.py b/src/mcp/client/experimental/__init__.py new file mode 100644 index 0000000000..b6579b191e --- /dev/null +++ b/src/mcp/client/experimental/__init__.py @@ -0,0 +1,9 @@ +""" +Experimental client features. + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.client.experimental.tasks import ExperimentalClientFeatures + +__all__ = ["ExperimentalClientFeatures"] diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py new file mode 100644 index 0000000000..69621e6660 --- /dev/null +++ b/src/mcp/client/experimental/task_handlers.py @@ -0,0 +1,295 @@ +""" +Experimental task handler protocols for server -> client requests. + +This module provides Protocol types and default handlers for when servers +send task-related requests to clients (the reverse of normal client -> server flow). + +WARNING: These APIs are experimental and may change without notice. + +Use cases: +- Server sends task-augmented sampling/elicitation request to client +- Client creates a local task, spawns background work, returns CreateTaskResult +- Server polls client's task status via tasks/get, tasks/result, etc. +""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol + +import mcp.types as types +from mcp.shared.context import RequestContext +from mcp.shared.session import RequestResponder + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + + +class GetTaskHandlerFnT(Protocol): + """Handler for tasks/get requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.GetTaskRequestParams, + ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch + + +class GetTaskResultHandlerFnT(Protocol): + """Handler for tasks/result requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.GetTaskPayloadRequestParams, + ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch + + +class ListTasksHandlerFnT(Protocol): + """Handler for tasks/list requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.PaginatedRequestParams | None, + ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch + + +class CancelTaskHandlerFnT(Protocol): + """Handler for tasks/cancel requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CancelTaskRequestParams, + ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch + + +class TaskAugmentedSamplingFnT(Protocol): + """Handler for task-augmented sampling/createMessage requests from server. + + When server sends a CreateMessageRequest with task field, this callback + is invoked. The callback should create a task, spawn background work, + and return CreateTaskResult immediately. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + task_metadata: types.TaskMetadata, + ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch + + +class TaskAugmentedElicitationFnT(Protocol): + """Handler for task-augmented elicitation/create requests from server. + + When server sends an ElicitRequest with task field, this callback + is invoked. The callback should create a task, spawn background work, + and return CreateTaskResult immediately. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + task_metadata: types.TaskMetadata, + ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch + + +# ============================================================================= +# Default Handlers (return "not supported" errors) +# ============================================================================= + + +async def default_get_task_handler( + context: RequestContext["ClientSession", Any], + params: types.GetTaskRequestParams, +) -> types.GetTaskResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/get not supported", + ) + + +async def default_get_task_result_handler( + context: RequestContext["ClientSession", Any], + params: types.GetTaskPayloadRequestParams, +) -> types.GetTaskPayloadResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/result not supported", + ) + + +async def default_list_tasks_handler( + context: RequestContext["ClientSession", Any], + params: types.PaginatedRequestParams | None, +) -> types.ListTasksResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/list not supported", + ) + + +async def default_cancel_task_handler( + context: RequestContext["ClientSession", Any], + params: types.CancelTaskRequestParams, +) -> types.CancelTaskResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/cancel not supported", + ) + + +async def default_task_augmented_sampling( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + task_metadata: types.TaskMetadata, +) -> types.CreateTaskResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Task-augmented sampling not supported", + ) + + +async def default_task_augmented_elicitation( + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + task_metadata: types.TaskMetadata, +) -> types.CreateTaskResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Task-augmented elicitation not supported", + ) + + +@dataclass +class ExperimentalTaskHandlers: + """Container for experimental task handlers. + + Groups all task-related handlers that handle server -> client requests. + This includes both pure task requests (get, list, cancel, result) and + task-augmented request handlers (sampling, elicitation with task field). + + WARNING: These APIs are experimental and may change without notice. + + Example: + handlers = ExperimentalTaskHandlers( + get_task=my_get_task_handler, + list_tasks=my_list_tasks_handler, + ) + session = ClientSession(..., experimental_task_handlers=handlers) + """ + + # Pure task request handlers + get_task: GetTaskHandlerFnT = field(default=default_get_task_handler) + get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler) + list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler) + cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler) + + # Task-augmented request handlers + augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling) + augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation) + + def build_capability(self) -> types.ClientTasksCapability | None: + """Build ClientTasksCapability from the configured handlers. + + Returns a capability object that reflects which handlers are configured + (i.e., not using the default "not supported" handlers). + + Returns: + ClientTasksCapability if any handlers are provided, None otherwise + """ + has_list = self.list_tasks is not default_list_tasks_handler + has_cancel = self.cancel_task is not default_cancel_task_handler + has_sampling = self.augmented_sampling is not default_task_augmented_sampling + has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation + + # If no handlers are provided, return None + if not any([has_list, has_cancel, has_sampling, has_elicitation]): + return None + + # Build requests capability if any request handlers are provided + requests_capability: types.ClientTasksRequestsCapability | None = None + if has_sampling or has_elicitation: + requests_capability = types.ClientTasksRequestsCapability( + sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability()) + if has_sampling + else None, + elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability()) + if has_elicitation + else None, + ) + + return types.ClientTasksCapability( + list=types.TasksListCapability() if has_list else None, + cancel=types.TasksCancelCapability() if has_cancel else None, + requests=requests_capability, + ) + + @staticmethod + def handles_request(request: types.ServerRequest) -> bool: + """Check if this handler handles the given request type.""" + return isinstance( + request.root, + types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, + ) + + async def handle_request( + self, + ctx: RequestContext["ClientSession", Any], + responder: RequestResponder[types.ServerRequest, types.ClientResult], + ) -> None: + """Handle a task-related request from the server. + + Call handles_request() first to check if this handler can handle the request. + """ + from pydantic import TypeAdapter + + client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( + types.ClientResult | types.ErrorData + ) + + match responder.request.root: + case types.GetTaskRequest(params=params): + response = await self.get_task(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case types.GetTaskPayloadRequest(params=params): + response = await self.get_task_result(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case types.ListTasksRequest(params=params): + response = await self.list_tasks(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case types.CancelTaskRequest(params=params): + response = await self.cancel_task(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case _: # pragma: no cover + raise ValueError(f"Unhandled request type: {type(responder.request.root)}") + + +# Backwards compatibility aliases +default_task_augmented_sampling_callback = default_task_augmented_sampling +default_task_augmented_elicitation_callback = default_task_augmented_elicitation diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py new file mode 100644 index 0000000000..0a1031e971 --- /dev/null +++ b/src/mcp/client/experimental/tasks.py @@ -0,0 +1,193 @@ +""" +Experimental client-side task support. + +This module provides client methods for interacting with MCP tasks. + +WARNING: These APIs are experimental and may change without notice. + +Example: + # Call a tool as a task + result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) + task_id = result.task.taskId + + # Get task status + status = await session.experimental.get_task(task_id) + + # Get task result when complete + if status.status == "completed": + result = await session.experimental.get_task_result(task_id, CallToolResult) + + # List all tasks + tasks = await session.experimental.list_tasks() + + # Cancel a task + await session.experimental.cancel_task(task_id) +""" + +from typing import TYPE_CHECKING, Any, TypeVar + +import mcp.types as types + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + +ResultT = TypeVar("ResultT", bound=types.Result) + + +class ExperimentalClientFeatures: + """ + Experimental client features for tasks and other experimental APIs. + + WARNING: These APIs are experimental and may change without notice. + + Access via session.experimental: + status = await session.experimental.get_task(task_id) + """ + + def __init__(self, session: "ClientSession") -> None: + self._session = session + + async def call_tool_as_task( + self, + name: str, + arguments: dict[str, Any] | None = None, + *, + ttl: int = 60000, + meta: dict[str, Any] | None = None, + ) -> types.CreateTaskResult: + """Call a tool as a task, returning a CreateTaskResult for polling. + + This is a convenience method for calling tools that support task execution. + The server will return a task reference instead of the immediate result, + which can then be polled via `get_task()` and retrieved via `get_task_result()`. + + Args: + name: The tool name + arguments: Tool arguments + ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute) + meta: Optional metadata to include in the request + + Returns: + CreateTaskResult containing the task reference + + Example: + # Create task + result = await session.experimental.call_tool_as_task( + "long_running_tool", {"input": "data"} + ) + task_id = result.task.taskId + + # Poll for completion + while True: + status = await session.experimental.get_task(task_id) + if status.status == "completed": + break + await asyncio.sleep(0.5) + + # Get result + final = await session.experimental.get_task_result(task_id, CallToolResult) + """ + _meta: types.RequestParams.Meta | None = None + if meta is not None: + _meta = types.RequestParams.Meta(**meta) + + return await self._session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + task=types.TaskMetadata(ttl=ttl), + _meta=_meta, + ), + ) + ), + types.CreateTaskResult, + ) + + async def get_task(self, task_id: str) -> types.GetTaskResult: + """ + Get the current status of a task. + + Args: + task_id: The task identifier + + Returns: + GetTaskResult containing the task status and metadata + """ + return await self._session.send_request( + types.ClientRequest( + types.GetTaskRequest( + params=types.GetTaskRequestParams(taskId=task_id), + ) + ), + types.GetTaskResult, + ) + + async def get_task_result( + self, + task_id: str, + result_type: type[ResultT], + ) -> ResultT: + """ + Get the result of a completed task. + + The result type depends on the original request type: + - tools/call tasks return CallToolResult + - Other request types return their corresponding result type + + Args: + task_id: The task identifier + result_type: The expected result type (e.g., CallToolResult) + + Returns: + The task result, validated against result_type + """ + return await self._session.send_request( + types.ClientRequest( + types.GetTaskPayloadRequest( + params=types.GetTaskPayloadRequestParams(taskId=task_id), + ) + ), + result_type, + ) + + async def list_tasks( + self, + cursor: str | None = None, + ) -> types.ListTasksResult: + """ + List all tasks. + + Args: + cursor: Optional pagination cursor + + Returns: + ListTasksResult containing tasks and optional next cursor + """ + params = types.PaginatedRequestParams(cursor=cursor) if cursor else None + return await self._session.send_request( + types.ClientRequest( + types.ListTasksRequest(params=params), + ), + types.ListTasksResult, + ) + + async def cancel_task(self, task_id: str) -> types.CancelTaskResult: + """ + Cancel a running task. + + Args: + task_id: The task identifier + + Returns: + CancelTaskResult with the updated task state + """ + return await self._session.send_request( + types.ClientRequest( + types.CancelTaskRequest( + params=types.CancelTaskRequestParams(taskId=task_id), + ) + ), + types.CancelTaskResult, + ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index be47d681fb..4986679a0b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,6 +8,8 @@ from typing_extensions import deprecated import mcp.types as types +from mcp.client.experimental import ExperimentalClientFeatures +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -118,6 +120,8 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + *, + experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: super().__init__( read_stream, @@ -134,6 +138,10 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None + self._experimental_features: ExperimentalClientFeatures | None = None + + # Experimental: Task handlers (use defaults if not provided) + self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None @@ -164,6 +172,7 @@ async def initialize(self) -> types.InitializeResult: elicitation=elicitation, experimental=None, roots=roots, + tasks=self._task_handlers.build_capability(), ), clientInfo=self._client_info, ), @@ -188,6 +197,20 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None: """ return self._server_capabilities + @property + def experimental(self) -> ExperimentalClientFeatures: + """Experimental APIs for tasks and other features. + + WARNING: These APIs are experimental and may change without notice. + + Example: + status = await session.experimental.get_task(task_id) + result = await session.experimental.get_task_result(task_id, CallToolResult) + """ + if self._experimental_features is None: + self._experimental_features = ExperimentalClientFeatures(self) + return self._experimental_features + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( @@ -521,16 +544,31 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques lifespan_context=None, ) + # Delegate to experimental task handler if applicable + if self._task_handlers.handles_request(responder.request): + with responder: + await self._task_handlers.handle_request(ctx, responder) + return None + + # Core request handling match responder.request.root: case types.CreateMessageRequest(params=params): with responder: - response = await self._sampling_callback(ctx, params) + # Check if this is a task-augmented request + if params.task is not None: + response = await self._task_handlers.augmented_sampling(ctx, params, params.task) + else: + response = await self._sampling_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) case types.ElicitRequest(params=params): with responder: - response = await self._elicitation_callback(ctx, params) + # Check if this is a task-augmented request + if params.task is not None: + response = await self._task_handlers.augmented_elicitation(ctx, params, params.task) + else: + response = await self._elicitation_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) @@ -544,6 +582,10 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques with responder: return await responder.respond(types.ClientResult(root=types.EmptyResult())) + case _: # pragma: no cover + raise NotImplementedError() + return None + async def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py new file mode 100644 index 0000000000..cefa0fb975 --- /dev/null +++ b/src/mcp/server/lowlevel/experimental.py @@ -0,0 +1,157 @@ +"""Experimental handlers for the low-level MCP server. + +WARNING: These APIs are experimental and may change without notice. +""" + +import logging +from collections.abc import Awaitable, Callable + +from mcp.server.lowlevel.func_inspection import create_call_wrapper +from mcp.types import ( + CancelTaskRequest, + CancelTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerCapabilities, + ServerResult, + ServerTasksCapability, + ServerTasksRequestsCapability, + TasksCancelCapability, + TasksListCapability, + TasksToolsCapability, +) + +logger = logging.getLogger(__name__) + + +class ExperimentalHandlers: + """Experimental request/notification handlers. + + WARNING: These APIs are experimental and may change without notice. + """ + + def __init__( + self, + request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], + notification_handlers: dict[type, Callable[..., Awaitable[None]]], + ): + self._request_handlers = request_handlers + self._notification_handlers = notification_handlers + + def update_capabilities(self, capabilities: ServerCapabilities) -> None: + capabilities.tasks = ServerTasksCapability() + if ListTasksRequest in self._request_handlers: + capabilities.tasks.list = TasksListCapability() + if CancelTaskRequest in self._request_handlers: + capabilities.tasks.cancel = TasksCancelCapability() + + capabilities.tasks.requests = ServerTasksRequestsCapability( + tools=TasksToolsCapability() + ) # assuming always supported for now + + def list_tasks( + self, + ) -> Callable[ + [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], + Callable[[ListTasksRequest], Awaitable[ListTasksResult]], + ]: + """Register a handler for listing tasks. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], + ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: + logger.debug("Registering handler for ListTasksRequest") + wrapper = create_call_wrapper(func, ListTasksRequest) + + async def handler(req: ListTasksRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[ListTasksRequest] = handler + return func + + return decorator + + def get_task( + self, + ) -> Callable[ + [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] + ]: + """Register a handler for getting task status. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], + ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: + logger.debug("Registering handler for GetTaskRequest") + wrapper = create_call_wrapper(func, GetTaskRequest) + + async def handler(req: GetTaskRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[GetTaskRequest] = handler + return func + + return decorator + + def get_task_result( + self, + ) -> Callable[ + [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], + Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], + ]: + """Register a handler for getting task results/payload. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], + ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: + logger.debug("Registering handler for GetTaskPayloadRequest") + wrapper = create_call_wrapper(func, GetTaskPayloadRequest) + + async def handler(req: GetTaskPayloadRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[GetTaskPayloadRequest] = handler + return func + + return decorator + + def cancel_task( + self, + ) -> Callable[ + [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], + Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], + ]: + """Register a handler for cancelling tasks. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], + ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: + logger.debug("Registering handler for CancelTaskRequest") + wrapper = create_call_wrapper(func, CancelTaskRequest) + + async def handler(req: CancelTaskRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[CancelTaskRequest] = handler + return func + + return decorator diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a0617036f9..9d87b3e4f2 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -67,12 +67,14 @@ async def main(): from __future__ import annotations as _annotations +import base64 import contextvars import json import logging import warnings from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from importlib.metadata import version as pkg_version from typing import Any, Generic, TypeAlias, cast import anyio @@ -82,11 +84,12 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext +from mcp.shared.context import Experimental, RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -155,6 +158,7 @@ def __init__( } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self._tool_cache: dict[str, types.Tool] = {} + self._experimental_handlers: ExperimentalHandlers | None = None logger.debug("Initializing server %r", name) def create_initialization_options( @@ -164,11 +168,9 @@ def create_initialization_options( ) -> InitializationOptions: """Create initialization options from this server instance.""" - def pkg_version(package: str) -> str: + def get_package_version(package: str) -> str: try: - from importlib.metadata import version - - return version(package) + return pkg_version(package) except Exception: # pragma: no cover pass @@ -176,7 +178,7 @@ def pkg_version(package: str) -> str: return InitializationOptions( server_name=self.name, - server_version=self.version if self.version else pkg_version("mcp"), + server_version=self.version if self.version else get_package_version("mcp"), capabilities=self.get_capabilities( notification_options or NotificationOptions(), experimental_capabilities or {}, @@ -220,7 +222,7 @@ def get_capabilities( if types.CompleteRequest in self.request_handlers: completions_capability = types.CompletionsCapability() - return types.ServerCapabilities( + capabilities = types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, tools=tools_capability, @@ -228,6 +230,9 @@ def get_capabilities( experimental=experimental_capabilities, completions=completions_capability, ) + if self._experimental_handlers: + self._experimental_handlers.update_capabilities(capabilities) + return capabilities @property def request_context( @@ -236,6 +241,18 @@ def request_context( """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() + @property + def experimental(self) -> ExperimentalHandlers: + """Experimental APIs for tasks and other features. + + WARNING: These APIs are experimental and may change without notice. + """ + + # We create this inline so we only add these capabilities _if_ they're actually used + if self._experimental_handlers is None: + self._experimental_handlers = ExperimentalHandlers(self.request_handlers, self.notification_handlers) + return self._experimental_handlers + def list_prompts(self): def decorator( func: Callable[[], Awaitable[list[types.Prompt]]] @@ -328,8 +345,6 @@ def create_content(data: str | bytes, mime_type: str | None): mimeType=mime_type or "text/plain", ) case bytes() as data: # pragma: no cover - import base64 - return types.BlobResourceContents( uri=req.params.uri, blob=base64.b64encode(data).decode(), @@ -483,7 +498,13 @@ def call_tool(self, *, validate_input: bool = True): def decorator( func: Callable[ ..., - Awaitable[UnstructuredContent | StructuredContent | CombinationContent | types.CallToolResult], + Awaitable[ + UnstructuredContent + | StructuredContent + | CombinationContent + | types.CallToolResult + | types.CreateTaskResult + ], ], ): logger.debug("Registering handler for CallToolRequest") @@ -509,6 +530,9 @@ async def handler(req: types.CallToolRequest): maybe_structured_content: StructuredContent | None if isinstance(results, types.CallToolResult): return types.ServerResult(results) + elif isinstance(results, types.CreateTaskResult): + # Task-augmented execution returns task info instead of result + return types.ServerResult(results) elif isinstance(results, tuple) and len(results) == 2: # tool returned both structured and unstructured content unstructured_content, maybe_structured_content = cast(CombinationContent, results) @@ -669,13 +693,14 @@ async def _handle_message( async def _handle_request( self, message: RequestResponder[types.ClientRequest, types.ServerResult], - req: Any, + req: types.ClientRequestType, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): # type: ignore + + if handler := self.request_handlers.get(type(req)): logger.debug("Dispatching request of type %s", type(req).__name__) token = None @@ -689,12 +714,17 @@ async def _handle_request( # Set our global state that can be retrieved via # app.get_request_context() + client_capabilities = session.client_params.capabilities if session.client_params else None token = request_ctx.set( RequestContext( message.request_id, message.request_meta, session, lifespan_context, + Experimental( + task_metadata=message.request_params.task if message.request_params else None, + _client_capabilities=client_capabilities, + ), request=request_data, ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index b116fbe384..81ce350c75 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -48,6 +48,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks import TaskResultHandler from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -142,6 +143,27 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return True + def set_task_result_handler(self, handler: TaskResultHandler) -> None: + """ + Set the TaskResultHandler for this session. + + This enables response routing for task-augmented requests. When a + TaskSession enqueues an elicitation request, the response will be + routed back through this handler. + + The handler is automatically registered as a response router. + + Args: + handler: The TaskResultHandler to use for this session + + Example: + task_store = InMemoryTaskStore() + message_queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(task_store, message_queue) + session.set_task_result_handler(handler) + """ + self.add_response_router(handler) + async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() @@ -481,6 +503,20 @@ async def send_elicit_complete( related_request_id, ) + async def send_message(self, message: SessionMessage) -> None: + """Send a raw session message. + + This is primarily used by TaskResultHandler to deliver queued messages + (elicitation/sampling requests) to the client during task execution. + + WARNING: This is a low-level method. Prefer using higher-level methods + like send_notification() or send_request() for normal operations. + + Args: + message: The session message to send + """ + await self._write_stream.send(message) + async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5f..4ee88126b6 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,20 +1,142 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Generic from typing_extensions import TypeVar +from mcp.shared.exceptions import McpError from mcp.shared.session import BaseSession -from mcp.types import RequestId, RequestParams +from mcp.types import ( + METHOD_NOT_FOUND, + TASK_FORBIDDEN, + TASK_REQUIRED, + ClientCapabilities, + ErrorData, + RequestId, + RequestParams, + TaskExecutionMode, + TaskMetadata, + Tool, +) SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) +@dataclass +class Experimental: + """ + Experimental features context for task-augmented requests. + + Provides helpers for validating task execution compatibility. + """ + + task_metadata: TaskMetadata | None = None + _client_capabilities: ClientCapabilities | None = field(default=None, repr=False) + + @property + def is_task(self) -> bool: + """Check if this request is task-augmented.""" + return self.task_metadata is not None + + @property + def client_supports_tasks(self) -> bool: + """Check if the client declared task support.""" + if self._client_capabilities is None: + return False + return self._client_capabilities.tasks is not None + + def validate_task_mode( + self, + tool_task_mode: TaskExecutionMode | None, + *, + raise_error: bool = True, + ) -> ErrorData | None: + """ + Validate that the request is compatible with the tool's task execution mode. + + Per MCP spec: + - "required": Clients MUST invoke as task. Server returns -32601 if not. + - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do. + - "optional": Either is acceptable. + + Args: + tool_task_mode: The tool's execution.taskSupport value + ("forbidden", "optional", "required", or None) + raise_error: If True, raises McpError on validation failure. If False, returns ErrorData. + + Returns: + None if valid, ErrorData if invalid and raise_error=False + + Raises: + McpError: If invalid and raise_error=True + """ + + mode = tool_task_mode or TASK_FORBIDDEN + + error: ErrorData | None = None + + if mode == TASK_REQUIRED and not self.is_task: + error = ErrorData( + code=METHOD_NOT_FOUND, + message="This tool requires task-augmented invocation", + ) + elif mode == TASK_FORBIDDEN and self.is_task: + error = ErrorData( + code=METHOD_NOT_FOUND, + message="This tool does not support task-augmented invocation", + ) + + if error is not None and raise_error: + raise McpError(error) + + return error + + def validate_for_tool( + self, + tool: Tool, + *, + raise_error: bool = True, + ) -> ErrorData | None: + """ + Validate that the request is compatible with the given tool. + + Convenience wrapper around validate_task_mode that extracts the mode from a Tool. + + Args: + tool: The Tool definition + raise_error: If True, raises McpError on validation failure. + + Returns: + None if valid, ErrorData if invalid and raise_error=False + """ + mode = tool.execution.taskSupport if tool.execution else None + return self.validate_task_mode(mode, raise_error=raise_error) + + def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: + """ + Check if this client can use a tool with the given task mode. + + Useful for filtering tool lists or providing warnings. + Returns False if tool requires "required" but client doesn't support tasks. + + Args: + tool_task_mode: The tool's execution.taskSupport value + + Returns: + True if the client can use this tool, False otherwise + """ + mode = tool_task_mode or TASK_FORBIDDEN + if mode == TASK_REQUIRED and not self.client_supports_tasks: + return False + return True + + @dataclass class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + experimental: Experimental = field(default_factory=Experimental) request: RequestT | None = None diff --git a/src/mcp/shared/experimental/__init__.py b/src/mcp/shared/experimental/__init__.py new file mode 100644 index 0000000000..9bb0f72c67 --- /dev/null +++ b/src/mcp/shared/experimental/__init__.py @@ -0,0 +1,8 @@ +"""Experimental MCP features. + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.shared.experimental import tasks + +__all__ = ["tasks"] diff --git a/src/mcp/shared/experimental/tasks/__init__.py b/src/mcp/shared/experimental/tasks/__init__.py new file mode 100644 index 0000000000..1630f09e0d --- /dev/null +++ b/src/mcp/shared/experimental/tasks/__init__.py @@ -0,0 +1,60 @@ +""" +Experimental task management for MCP. + +This module provides: +- TaskStore: Abstract interface for task state storage +- TaskContext: Context object for task work to interact with state/notifications +- InMemoryTaskStore: Reference implementation for testing/development +- TaskMessageQueue: FIFO queue for task messages delivered via tasks/result +- InMemoryTaskMessageQueue: Reference implementation for message queue +- Helper functions: run_task, is_terminal, create_task_state, generate_task_id, cancel_task + +Architecture: +- TaskStore is pure storage - it doesn't know about execution +- TaskMessageQueue stores messages to be delivered via tasks/result +- TaskContext wraps store + session, providing a clean API for task work +- run_task is optional convenience for spawning in-process tasks + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.helpers import ( + MODEL_IMMEDIATE_RESPONSE_KEY, + cancel_task, + create_task_state, + generate_task_id, + is_terminal, + run_task, + task_execution, +) +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.message_queue import ( + InMemoryTaskMessageQueue, + QueuedMessage, + TaskMessageQueue, +) +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.shared.experimental.tasks.result_handler import TaskResultHandler +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.shared.experimental.tasks.task_session import RELATED_TASK_METADATA_KEY, TaskSession + +__all__ = [ + "TaskStore", + "TaskContext", + "TaskSession", + "TaskResultHandler", + "Resolver", + "InMemoryTaskStore", + "TaskMessageQueue", + "InMemoryTaskMessageQueue", + "QueuedMessage", + "RELATED_TASK_METADATA_KEY", + "MODEL_IMMEDIATE_RESPONSE_KEY", + "run_task", + "task_execution", + "is_terminal", + "create_task_state", + "generate_task_id", + "cancel_task", +] diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py new file mode 100644 index 0000000000..10fc2d09a4 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/context.py @@ -0,0 +1,141 @@ +""" +TaskContext - Context for task work to interact with state and notifications. +""" + +from typing import TYPE_CHECKING + +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + Result, + ServerNotification, + Task, + TaskStatusNotification, + TaskStatusNotificationParams, +) + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + + +class TaskContext: + """ + Context provided to task work for state management and notifications. + + This wraps a TaskStore and optional session, providing a clean API + for task work to update status, complete, fail, and send notifications. + + Example: + async def my_task_work(ctx: TaskContext) -> CallToolResult: + await ctx.update_status("Starting processing...") + + for i, item in enumerate(items): + await ctx.update_status(f"Processing item {i+1}/{len(items)}") + if ctx.is_cancelled: + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) + process(item) + + return CallToolResult(content=[TextContent(type="text", text="Done!")]) + """ + + def __init__( + self, + task: Task, + store: TaskStore, + session: "ServerSession | None" = None, + ): + self._task = task + self._store = store + self._session = session + self._cancelled = False + + @property + def task_id(self) -> str: + """The task identifier.""" + return self._task.taskId + + @property + def task(self) -> Task: + """The current task state.""" + return self._task + + @property + def is_cancelled(self) -> bool: + """Whether cancellation has been requested.""" + return self._cancelled + + def request_cancellation(self) -> None: + """ + Request cancellation of this task. + + This sets is_cancelled=True. Task work should check this + periodically and exit gracefully if set. + """ + self._cancelled = True + + async def update_status(self, message: str, *, notify: bool = True) -> None: + """ + Update the task's status message. + + Args: + message: The new status message + notify: Whether to send a notification to the client + """ + self._task = await self._store.update_task( + self.task_id, + status_message=message, + ) + if notify: + await self._send_notification() + + async def complete(self, result: Result, *, notify: bool = True) -> None: + """ + Mark the task as completed with the given result. + + Args: + result: The task result + notify: Whether to send a notification to the client + """ + await self._store.store_result(self.task_id, result) + self._task = await self._store.update_task( + self.task_id, + status="completed", + ) + if notify: + await self._send_notification() + + async def fail(self, error: str, *, notify: bool = True) -> None: + """ + Mark the task as failed with an error message. + + Args: + error: The error message + notify: Whether to send a notification to the client + """ + self._task = await self._store.update_task( + self.task_id, + status="failed", + status_message=error, + ) + if notify: + await self._send_notification() + + async def _send_notification(self) -> None: + """Send a task status notification to the client.""" + if self._session is None: + return + + await self._session.send_notification( + ServerNotification( + TaskStatusNotification( + params=TaskStatusNotificationParams( + taskId=self._task.taskId, + status=self._task.status, + statusMessage=self._task.statusMessage, + createdAt=self._task.createdAt, + lastUpdatedAt=self._task.lastUpdatedAt, + ttl=self._task.ttl, + pollInterval=self._task.pollInterval, + ) + ) + ) + ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py new file mode 100644 index 0000000000..12746c7501 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -0,0 +1,267 @@ +""" +Helper functions for task management. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from anyio.abc import TaskGroup + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + INVALID_PARAMS, + CancelTaskResult, + CreateTaskResult, + ErrorData, + Result, + Task, + TaskMetadata, + TaskStatus, +) + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + +# Metadata key for model-immediate-response (per MCP spec) +# Servers MAY include this in CreateTaskResult._meta to provide an immediate +# response string while the task executes in the background. +MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response" + + +def is_terminal(status: TaskStatus) -> bool: + """ + Check if a task status represents a terminal state. + + Terminal states are those where the task has finished and will not change. + + Args: + status: The task status to check + + Returns: + True if the status is terminal (completed, failed, or cancelled) + """ + return status in ("completed", "failed", "cancelled") + + +async def cancel_task( + store: TaskStore, + task_id: str, +) -> CancelTaskResult: + """ + Cancel a task with spec-compliant validation. + + Per spec: "Receivers MUST reject cancellation of terminal status tasks + with -32602 (Invalid params)" + + This helper validates that the task exists and is not in a terminal state + before setting it to "cancelled". + + Args: + store: The task store + task_id: The task identifier to cancel + + Returns: + CancelTaskResult with the cancelled task state + + Raises: + McpError: With INVALID_PARAMS (-32602) if: + - Task does not exist + - Task is already in a terminal state (completed, failed, cancelled) + + Example: + @server.experimental.cancel_task() + async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: + return await cancel_task(store, request.params.taskId) + """ + task = await store.get_task(task_id) + if task is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + + if is_terminal(task.status): + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Cannot cancel task in terminal state '{task.status}'", + ) + ) + + # Update task to cancelled status + cancelled_task = await store.update_task(task_id, status="cancelled") + return CancelTaskResult(**cancelled_task.model_dump()) + + +def generate_task_id() -> str: + """Generate a unique task ID.""" + return str(uuid4()) + + +def create_task_state( + metadata: TaskMetadata, + task_id: str | None = None, +) -> Task: + """ + Create a Task object with initial state. + + This is a helper for TaskStore implementations. + + Args: + metadata: Task metadata + task_id: Optional task ID (generated if not provided) + + Returns: + A new Task in "working" status + """ + now = datetime.now(timezone.utc) + return Task( + taskId=task_id or generate_task_id(), + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=metadata.ttl, + pollInterval=500, # Default 500ms poll interval + ) + + +@asynccontextmanager +async def task_execution( + task_id: str, + store: TaskStore, + session: "ServerSession | None" = None, +) -> AsyncIterator[TaskContext]: + """ + Context manager for safe task execution. + + Loads a task from the store and provides a TaskContext for the work. + If an unhandled exception occurs, the task is automatically marked as failed + and the exception is suppressed (since the failure is captured in task state). + + This is the recommended pattern for executing task work, especially in + distributed scenarios where the worker may be a separate process. + + Args: + task_id: The task identifier to execute + store: The task store (must be accessible by the worker) + session: Optional session for sending notifications (often None for workers) + + Yields: + TaskContext for updating status and completing/failing the task + + Raises: + ValueError: If the task is not found in the store + + Example (in-memory): + async def work(): + async with task_execution(task.taskId, store) as ctx: + await ctx.update_status("Processing...") + result = await do_work() + await ctx.complete(result) + + task_group.start_soon(work) + + Example (distributed worker): + async def worker_process(task_id: str): + store = RedisTaskStore(redis_url) + async with task_execution(task_id, store) as ctx: + await ctx.update_status("Working...") + result = await do_work() + await ctx.complete(result) + """ + task = await store.get_task(task_id) + if task is None: + raise ValueError(f"Task {task_id} not found") + + ctx = TaskContext(task, store, session) + try: + yield ctx + except Exception as e: + # Auto-fail the task if an exception occurs and task isn't already terminal + # Exception is suppressed since failure is captured in task state + if not is_terminal(ctx.task.status): + await ctx.fail(str(e), notify=session is not None) + # Don't re-raise - the failure is recorded in task state + + +async def run_task( + task_group: TaskGroup, + store: TaskStore, + metadata: TaskMetadata, + work: Callable[[TaskContext], Awaitable[Result]], + *, + session: "ServerSession | None" = None, + task_id: str | None = None, + model_immediate_response: str | None = None, +) -> tuple[CreateTaskResult, TaskContext]: + """ + Create a task and spawn work to execute it. + + This is a convenience helper for in-process task execution. + For distributed systems, you'll want to handle task creation + and execution separately. + + Args: + task_group: The anyio TaskGroup to spawn work in + store: The task store for state management + metadata: Task metadata (ttl, etc.) + work: Async function that does the actual work + session: Optional session for sending notifications + task_id: Optional task ID (generated if not provided) + model_immediate_response: Optional string to include in _meta as + io.modelcontextprotocol/model-immediate-response. This allows + hosts to pass an immediate response to the model while the + task executes in the background. + + Returns: + Tuple of (CreateTaskResult to return to client, TaskContext for cancellation) + + Example: + async with anyio.create_task_group() as tg: + @server.call_tool() + async def handle_tool(name: str, args: dict): + ctx = server.request_context + if ctx.experimental.is_task: + result, task_ctx = await run_task( + tg, + store, + ctx.experimental.task_metadata, + lambda ctx: do_long_work(ctx, args), + session=ctx.session, + model_immediate_response="Processing started, this may take a while.", + ) + # Optionally store task_ctx for cancellation handling + return result + else: + return await do_work_sync(args) + """ + task = await store.create_task(metadata, task_id) + ctx = TaskContext(task, store, session) + + async def execute() -> None: + try: + result = await work(ctx) + # Only complete if not already in terminal state (e.g., cancelled) + if not is_terminal(ctx.task.status): + await ctx.complete(result) + except Exception as e: + # Only fail if not already in terminal state + if not is_terminal(ctx.task.status): + await ctx.fail(str(e)) + + # Spawn the work in the task group + task_group.start_soon(execute) + + # Build _meta if model_immediate_response is provided + meta: dict[str, Any] | None = None + if model_immediate_response is not None: + meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response} + + return CreateTaskResult(task=task, **{"_meta": meta} if meta else {}), ctx diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py new file mode 100644 index 0000000000..7b630ce6e2 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -0,0 +1,219 @@ +""" +In-memory implementation of TaskStore for demonstration purposes. + +This implementation stores all tasks in memory and provides automatic cleanup +based on the TTL duration specified in the task metadata using lazy expiration. + +Note: This is not suitable for production use as all data is lost on restart. +For production, consider implementing TaskStore with a database or distributed cache. +""" + +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone + +import anyio + +from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import Result, Task, TaskMetadata, TaskStatus + + +@dataclass +class StoredTask: + """Internal storage representation of a task.""" + + task: Task + result: Result | None = None + # Time when this task should be removed (None = never) + expires_at: datetime | None = field(default=None) + + +class InMemoryTaskStore(TaskStore): + """ + A simple in-memory implementation of TaskStore. + + Features: + - Automatic TTL-based cleanup (lazy expiration) + - Thread-safe for single-process async use + - Pagination support for list_tasks + + Limitations: + - All data lost on restart + - Not suitable for distributed systems + - No persistence + + For production, implement TaskStore with Redis, PostgreSQL, etc. + """ + + def __init__(self, page_size: int = 10) -> None: + self._tasks: dict[str, StoredTask] = {} + self._page_size = page_size + self._update_events: dict[str, anyio.Event] = {} + + def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: + """Calculate expiry time from TTL in milliseconds.""" + if ttl_ms is None: + return None + return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms) + + def _is_expired(self, stored: StoredTask) -> bool: + """Check if a task has expired.""" + if stored.expires_at is None: + return False + return datetime.now(timezone.utc) >= stored.expires_at + + def _cleanup_expired(self) -> None: + """Remove all expired tasks. Called lazily during access operations.""" + expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)] + for task_id in expired_ids: + del self._tasks[task_id] + + async def create_task( + self, + metadata: TaskMetadata, + task_id: str | None = None, + ) -> Task: + """Create a new task with the given metadata.""" + # Cleanup expired tasks on access + self._cleanup_expired() + + task = create_task_state(metadata, task_id) + + if task.taskId in self._tasks: + raise ValueError(f"Task with ID {task.taskId} already exists") + + stored = StoredTask( + task=task, + expires_at=self._calculate_expiry(metadata.ttl), + ) + self._tasks[task.taskId] = stored + + # Return a copy to prevent external modification + return Task(**task.model_dump()) + + async def get_task(self, task_id: str) -> Task | None: + """Get a task by ID.""" + # Cleanup expired tasks on access + self._cleanup_expired() + + stored = self._tasks.get(task_id) + if stored is None: + return None + + # Return a copy to prevent external modification + return Task(**stored.task.model_dump()) + + async def update_task( + self, + task_id: str, + status: TaskStatus | None = None, + status_message: str | None = None, + ) -> Task: + """Update a task's status and/or message.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + # Per spec: Terminal states MUST NOT transition to any other status + if status is not None and status != stored.task.status and is_terminal(stored.task.status): + raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'") + + status_changed = False + if status is not None and stored.task.status != status: + stored.task.status = status + status_changed = True + + if status_message is not None: + stored.task.statusMessage = status_message + + # Update lastUpdatedAt on any change + stored.task.lastUpdatedAt = datetime.now(timezone.utc) + + # If task is now terminal and has TTL, reset expiry timer + if status is not None and is_terminal(status) and stored.task.ttl is not None: + stored.expires_at = self._calculate_expiry(stored.task.ttl) + + # Notify waiters if status changed + if status_changed: + await self.notify_update(task_id) + + return Task(**stored.task.model_dump()) + + async def store_result(self, task_id: str, result: Result) -> None: + """Store the result for a task.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + stored.result = result + + async def get_result(self, task_id: str) -> Result | None: + """Get the stored result for a task.""" + stored = self._tasks.get(task_id) + if stored is None: + return None + + return stored.result + + async def list_tasks( + self, + cursor: str | None = None, + ) -> tuple[list[Task], str | None]: + """List tasks with pagination.""" + # Cleanup expired tasks on access + self._cleanup_expired() + + all_task_ids = list(self._tasks.keys()) + + start_index = 0 + if cursor is not None: + try: + cursor_index = all_task_ids.index(cursor) + start_index = cursor_index + 1 + except ValueError: + raise ValueError(f"Invalid cursor: {cursor}") + + page_task_ids = all_task_ids[start_index : start_index + self._page_size] + tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] + + # Determine next cursor + next_cursor = None + if start_index + self._page_size < len(all_task_ids) and page_task_ids: + next_cursor = page_task_ids[-1] + + return tasks, next_cursor + + async def delete_task(self, task_id: str) -> bool: + """Delete a task.""" + if task_id not in self._tasks: + return False + + del self._tasks[task_id] + return True + + async def wait_for_update(self, task_id: str) -> None: + """Wait until the task status changes.""" + if task_id not in self._tasks: + raise ValueError(f"Task with ID {task_id} not found") + + # Create a fresh event for waiting (anyio.Event can't be cleared) + self._update_events[task_id] = anyio.Event() + event = self._update_events[task_id] + await event.wait() + + async def notify_update(self, task_id: str) -> None: + """Signal that a task has been updated.""" + if task_id in self._update_events: + self._update_events[task_id].set() + + # --- Testing/debugging helpers --- + + def cleanup(self) -> None: + """Cleanup all tasks (useful for testing or graceful shutdown).""" + self._tasks.clear() + self._update_events.clear() + + def get_all_tasks(self) -> list[Task]: + """Get all tasks (useful for debugging). Returns copies to prevent modification.""" + self._cleanup_expired() + return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py new file mode 100644 index 0000000000..cf363964b8 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/message_queue.py @@ -0,0 +1,247 @@ +""" +TaskMessageQueue - FIFO queue for task-related messages. + +This implements the core message queue pattern from the MCP Tasks spec. +When a handler needs to send a request (like elicitation) during a task-augmented +request, the message is enqueued instead of sent directly. Messages are delivered +to the client only through the `tasks/result` endpoint. + +This pattern enables: +1. Decoupling request handling from message delivery +2. Proper bidirectional communication via the tasks/result stream +3. Automatic status management (working <-> input_required) +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import anyio + +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId + + +@dataclass +class QueuedMessage: + """ + A message queued for delivery via tasks/result. + + Messages are stored with their type and a resolver for requests + that expect responses. + """ + + type: Literal["request", "notification"] + """Whether this is a request (expects response) or notification (one-way).""" + + message: JSONRPCRequest | JSONRPCNotification + """The JSON-RPC message to send.""" + + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """When the message was enqueued.""" + + resolver: Resolver[dict[str, Any]] | None = None + """Resolver to set when response arrives (only for requests).""" + + original_request_id: RequestId | None = None + """The original request ID used internally, for routing responses back.""" + + +class TaskMessageQueue(ABC): + """ + Abstract interface for task message queuing. + + This is a FIFO queue that stores messages to be delivered via `tasks/result`. + When a task-augmented handler calls elicit() or sends a notification, the + message is enqueued here instead of being sent directly to the client. + + The `tasks/result` handler then dequeues and sends these messages through + the transport, with `relatedRequestId` set to the tasks/result request ID + so responses are routed correctly. + + Implementations can use in-memory storage, Redis, etc. + """ + + @abstractmethod + async def enqueue(self, task_id: str, message: QueuedMessage) -> None: + """ + Add a message to the queue for a task. + + Args: + task_id: The task identifier + message: The message to enqueue + """ + + @abstractmethod + async def dequeue(self, task_id: str) -> QueuedMessage | None: + """ + Remove and return the next message from the queue. + + Args: + task_id: The task identifier + + Returns: + The next message, or None if queue is empty + """ + + @abstractmethod + async def peek(self, task_id: str) -> QueuedMessage | None: + """ + Return the next message without removing it. + + Args: + task_id: The task identifier + + Returns: + The next message, or None if queue is empty + """ + + @abstractmethod + async def is_empty(self, task_id: str) -> bool: + """ + Check if the queue is empty for a task. + + Args: + task_id: The task identifier + + Returns: + True if no messages are queued + """ + + @abstractmethod + async def clear(self, task_id: str) -> list[QueuedMessage]: + """ + Remove and return all messages from the queue. + + This is useful for cleanup when a task is cancelled or completed. + + Args: + task_id: The task identifier + + Returns: + All queued messages (may be empty) + """ + + @abstractmethod + async def wait_for_message(self, task_id: str) -> None: + """ + Wait until a message is available in the queue. + + This blocks until either: + 1. A message is enqueued for this task + 2. The wait is cancelled + + Args: + task_id: The task identifier + """ + + @abstractmethod + async def notify_message_available(self, task_id: str) -> None: + """ + Signal that a message is available for a task. + + This wakes up any coroutines waiting in wait_for_message(). + + Args: + task_id: The task identifier + """ + + +class InMemoryTaskMessageQueue(TaskMessageQueue): + """ + In-memory implementation of TaskMessageQueue. + + This is suitable for single-process servers. For distributed systems, + implement TaskMessageQueue with Redis, RabbitMQ, etc. + + Features: + - FIFO ordering per task + - Async wait for message availability + - Thread-safe for single-process async use + """ + + def __init__(self) -> None: + self._queues: dict[str, list[QueuedMessage]] = {} + self._events: dict[str, anyio.Event] = {} + + def _get_queue(self, task_id: str) -> list[QueuedMessage]: + """Get or create the queue for a task.""" + if task_id not in self._queues: + self._queues[task_id] = [] + return self._queues[task_id] + + def _get_event(self, task_id: str) -> anyio.Event: + """Get or create the wait event for a task.""" + if task_id not in self._events: + self._events[task_id] = anyio.Event() + return self._events[task_id] + + async def enqueue(self, task_id: str, message: QueuedMessage) -> None: + """Add a message to the queue.""" + queue = self._get_queue(task_id) + queue.append(message) + # Signal that a message is available + await self.notify_message_available(task_id) + + async def dequeue(self, task_id: str) -> QueuedMessage | None: + """Remove and return the next message.""" + queue = self._get_queue(task_id) + if not queue: + return None + return queue.pop(0) + + async def peek(self, task_id: str) -> QueuedMessage | None: + """Return the next message without removing it.""" + queue = self._get_queue(task_id) + if not queue: + return None + return queue[0] + + async def is_empty(self, task_id: str) -> bool: + """Check if the queue is empty.""" + queue = self._get_queue(task_id) + return len(queue) == 0 + + async def clear(self, task_id: str) -> list[QueuedMessage]: + """Remove and return all messages.""" + queue = self._get_queue(task_id) + messages = list(queue) + queue.clear() + return messages + + async def wait_for_message(self, task_id: str) -> None: + """Wait until a message is available.""" + # Check if there are already messages + if not await self.is_empty(task_id): + return + + # Create a fresh event for waiting (anyio.Event can't be cleared) + self._events[task_id] = anyio.Event() + event = self._events[task_id] + + # Double-check after creating event (avoid race condition) + if not await self.is_empty(task_id): + return + + # Wait for a new message + await event.wait() + + async def notify_message_available(self, task_id: str) -> None: + """Signal that a message is available.""" + if task_id in self._events: + self._events[task_id].set() + + def cleanup(self, task_id: str | None = None) -> None: + """ + Clean up queues and events. + + Args: + task_id: If provided, clean up only this task. Otherwise clean up all. + """ + if task_id is not None: + self._queues.pop(task_id, None) + self._events.pop(task_id, None) + else: + self._queues.clear() + self._events.clear() diff --git a/src/mcp/shared/experimental/tasks/resolver.py b/src/mcp/shared/experimental/tasks/resolver.py new file mode 100644 index 0000000000..1a360189d9 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/resolver.py @@ -0,0 +1,59 @@ +""" +Resolver - An anyio-compatible future-like object for async result passing. + +This provides a simple way to pass a result (or exception) from one coroutine +to another without depending on asyncio.Future. +""" + +from typing import Generic, TypeVar + +import anyio + +T = TypeVar("T") + + +class Resolver(Generic[T]): + """ + A simple resolver for passing results between coroutines. + + Unlike asyncio.Future, this works with any anyio-compatible async backend. + + Usage: + resolver: Resolver[str] = Resolver() + + # In one coroutine: + resolver.set_result("hello") + + # In another coroutine: + result = await resolver.wait() # returns "hello" + """ + + def __init__(self) -> None: + self._event = anyio.Event() + self._value: T | None = None + self._exception: BaseException | None = None + + def set_result(self, value: T) -> None: + """Set the result value and wake up waiters.""" + if self._event.is_set(): + raise RuntimeError("Resolver already completed") + self._value = value + self._event.set() + + def set_exception(self, exc: BaseException) -> None: + """Set an exception and wake up waiters.""" + if self._event.is_set(): + raise RuntimeError("Resolver already completed") + self._exception = exc + self._event.set() + + async def wait(self) -> T: + """Wait for the result and return it, or raise the exception.""" + await self._event.wait() + if self._exception is not None: + raise self._exception + return self._value # type: ignore[return-value] + + def done(self) -> bool: + """Return True if the resolver has been completed.""" + return self._event.is_set() diff --git a/src/mcp/shared/experimental/tasks/result_handler.py b/src/mcp/shared/experimental/tasks/result_handler.py new file mode 100644 index 0000000000..5bc1c9dadd --- /dev/null +++ b/src/mcp/shared/experimental/tasks/result_handler.py @@ -0,0 +1,240 @@ +""" +TaskResultHandler - Integrated handler for tasks/result endpoint. + +This implements the dequeue-send-wait pattern from the MCP Tasks spec: +1. Dequeue all pending messages for the task +2. Send them to the client via transport with relatedRequestId routing +3. Wait if task is not in terminal state +4. Return final result when task completes + +This is the core of the task message queue pattern. +""" + +import logging +from typing import TYPE_CHECKING, Any + +import anyio + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.helpers import is_terminal +from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.types import ( + INVALID_PARAMS, + ErrorData, + GetTaskPayloadRequest, + GetTaskPayloadResult, + JSONRPCMessage, + RequestId, +) + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + +logger = logging.getLogger(__name__) + + +class TaskResultHandler: + """ + Handler for tasks/result that implements the message queue pattern. + + This handler: + 1. Dequeues pending messages (elicitations, notifications) for the task + 2. Sends them to the client via the response stream + 3. Waits for responses and resolves them back to callers + 4. Blocks until task reaches terminal state + 5. Returns the final result + + Usage: + # Create handler with store and queue + handler = TaskResultHandler(task_store, message_queue) + + # Register it with the server + @server.experimental.get_task_result() + async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + ctx = server.request_context + return await handler.handle(req, ctx.session, ctx.request_id) + + # Or use the convenience method + handler.register(server) + """ + + def __init__( + self, + store: TaskStore, + queue: TaskMessageQueue, + ): + self._store = store + self._queue = queue + # Map from internal request ID to resolver for routing responses + self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {} + + async def send_message( + self, + session: "ServerSession", + message: SessionMessage, + ) -> None: + """ + Send a message via the session. + + This is a helper for delivering queued task messages. + """ + await session.send_message(message) + + async def handle( + self, + request: GetTaskPayloadRequest, + session: "ServerSession", + request_id: RequestId, + ) -> GetTaskPayloadResult: + """ + Handle a tasks/result request. + + This implements the dequeue-send-wait loop: + 1. Dequeue all pending messages + 2. Send each via transport with relatedRequestId = this request's ID + 3. If task not terminal, wait for status change + 4. Loop until task is terminal + 5. Return final result + + Args: + request: The GetTaskPayloadRequest + session: The server session for sending messages + request_id: The request ID for relatedRequestId routing + + Returns: + GetTaskPayloadResult with the task's final payload + """ + task_id = request.params.taskId + + while True: + # Get fresh task state each iteration + task = await self._store.get_task(task_id) + if task is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + + # Dequeue and send all pending messages + await self._deliver_queued_messages(task_id, session, request_id) + + # If task is terminal, return result + if is_terminal(task.status): + result = await self._store.get_result(task_id) + # GetTaskPayloadResult is a Result with extra="allow" + # The stored result contains the actual payload data + # Per spec: tasks/result MUST include _meta.io.modelcontextprotocol/related-task + # with taskId, as the result structure itself does not contain the task ID + related_task_meta: dict[str, Any] = {"io.modelcontextprotocol/related-task": {"taskId": task_id}} + if result is not None: + # Copy result fields and add required metadata + result_data = result.model_dump(by_alias=True) + # Merge with existing _meta if present + existing_meta: dict[str, Any] = result_data.get("_meta") or {} + result_data["_meta"] = {**existing_meta, **related_task_meta} + return GetTaskPayloadResult.model_validate(result_data) + return GetTaskPayloadResult.model_validate({"_meta": related_task_meta}) + + # Wait for task update (status change or new messages) + await self._wait_for_task_update(task_id) + + async def _deliver_queued_messages( + self, + task_id: str, + session: "ServerSession", + request_id: RequestId, + ) -> None: + """ + Dequeue and send all pending messages for a task. + + Each message is sent via the session's write stream with + relatedRequestId set so responses route back to this stream. + """ + while True: + message = await self._queue.dequeue(task_id) + if message is None: + break + + # If this is a request (not notification), wait for response + if message.type == "request" and message.resolver is not None: + # Store the resolver so we can route the response back + original_id = message.original_request_id + if original_id is not None: + self._pending_requests[original_id] = message.resolver + + logger.debug("Delivering queued message for task %s: %s", task_id, message.type) + + # Send the message with relatedRequestId for routing + session_message = SessionMessage( + message=JSONRPCMessage(message.message), + metadata=ServerMessageMetadata(related_request_id=request_id), + ) + await self.send_message(session, session_message) + + async def _wait_for_task_update(self, task_id: str) -> None: + """ + Wait for task to be updated (status change or new message). + + Races between store update and queue message - first one wins. + """ + async with anyio.create_task_group() as tg: + + async def wait_for_store() -> None: + try: + await self._store.wait_for_update(task_id) + except Exception: + pass + finally: + tg.cancel_scope.cancel() + + async def wait_for_queue() -> None: + try: + await self._queue.wait_for_message(task_id) + except Exception: + pass + finally: + tg.cancel_scope.cancel() + + tg.start_soon(wait_for_store) + tg.start_soon(wait_for_queue) + + def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: + """ + Route a response back to the waiting resolver. + + This is called when a response arrives for a queued request. + + Args: + request_id: The request ID from the response + response: The response data + + Returns: + True if response was routed, False if no pending request + """ + resolver = self._pending_requests.pop(request_id, None) + if resolver is not None and not resolver.done(): + resolver.set_result(response) + return True + return False + + def route_error(self, request_id: RequestId, error: ErrorData) -> bool: + """ + Route an error back to the waiting resolver. + + Args: + request_id: The request ID from the error response + error: The error data + + Returns: + True if error was routed, False if no pending request + """ + resolver = self._pending_requests.pop(request_id, None) + if resolver is not None and not resolver.done(): + resolver.set_exception(McpError(error)) + return True + return False diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py new file mode 100644 index 0000000000..71fb4511b8 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/store.py @@ -0,0 +1,156 @@ +""" +TaskStore - Abstract interface for task state storage. +""" + +from abc import ABC, abstractmethod + +from mcp.types import Result, Task, TaskMetadata, TaskStatus + + +class TaskStore(ABC): + """ + Abstract interface for task state storage. + + This is a pure storage interface - it doesn't manage execution. + Implementations can use in-memory storage, databases, Redis, etc. + + All methods are async to support various backends. + """ + + @abstractmethod + async def create_task( + self, + metadata: TaskMetadata, + task_id: str | None = None, + ) -> Task: + """ + Create a new task. + + Args: + metadata: Task metadata (ttl, etc.) + task_id: Optional task ID. If None, implementation should generate one. + + Returns: + The created Task with status="working" + + Raises: + ValueError: If task_id already exists + """ + + @abstractmethod + async def get_task(self, task_id: str) -> Task | None: + """ + Get a task by ID. + + Args: + task_id: The task identifier + + Returns: + The Task, or None if not found + """ + + @abstractmethod + async def update_task( + self, + task_id: str, + status: TaskStatus | None = None, + status_message: str | None = None, + ) -> Task: + """ + Update a task's status and/or message. + + Args: + task_id: The task identifier + status: New status (if changing) + status_message: New status message (if changing) + + Returns: + The updated Task + + Raises: + ValueError: If task not found + ValueError: If attempting to transition from a terminal status + (completed, failed, cancelled). Per spec, terminal states + MUST NOT transition to any other status. + """ + + @abstractmethod + async def store_result(self, task_id: str, result: Result) -> None: + """ + Store the result for a task. + + Args: + task_id: The task identifier + result: The result to store + + Raises: + ValueError: If task not found + """ + + @abstractmethod + async def get_result(self, task_id: str) -> Result | None: + """ + Get the stored result for a task. + + Args: + task_id: The task identifier + + Returns: + The stored Result, or None if not available + """ + + @abstractmethod + async def list_tasks( + self, + cursor: str | None = None, + ) -> tuple[list[Task], str | None]: + """ + List tasks with pagination. + + Args: + cursor: Optional cursor for pagination + + Returns: + Tuple of (tasks, next_cursor). next_cursor is None if no more pages. + """ + + @abstractmethod + async def delete_task(self, task_id: str) -> bool: + """ + Delete a task. + + Args: + task_id: The task identifier + + Returns: + True if deleted, False if not found + """ + + @abstractmethod + async def wait_for_update(self, task_id: str) -> None: + """ + Wait until the task status changes. + + This blocks until either: + 1. The task status changes + 2. The wait is cancelled + + Used by tasks/result to wait for task completion or status changes. + + Args: + task_id: The task identifier + + Raises: + ValueError: If task not found + """ + + @abstractmethod + async def notify_update(self, task_id: str) -> None: + """ + Signal that a task has been updated. + + This wakes up any coroutines waiting in wait_for_update(). + + Args: + task_id: The task identifier + """ diff --git a/src/mcp/shared/experimental/tasks/task_session.py b/src/mcp/shared/experimental/tasks/task_session.py new file mode 100644 index 0000000000..eabd913a46 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/task_session.py @@ -0,0 +1,369 @@ +""" +TaskSession - Task-aware session wrapper for MCP. + +When a handler is executing a task-augmented request, it should use TaskSession +instead of ServerSession directly. TaskSession transparently handles: + +1. Enqueuing requests (like elicitation) instead of sending directly +2. Auto-managing task status (working <-> input_required) +3. Routing responses back to the original caller + +This implements the message queue pattern from the MCP Tasks spec. +""" + +import uuid +from typing import TYPE_CHECKING, Any + +import anyio + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + ClientCapabilities, + CreateMessageRequestParams, + CreateMessageResult, + ElicitationCapability, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitResult, + ErrorData, + IncludeContext, + JSONRPCNotification, + JSONRPCRequest, + LoggingMessageNotification, + LoggingMessageNotificationParams, + ModelPreferences, + RelatedTaskMetadata, + RequestId, + SamplingCapability, + SamplingMessage, + ServerNotification, +) + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + +# Metadata key for associating requests with a task (per MCP spec) +RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task" + + +class TaskSession: + """ + Task-aware session wrapper. + + This wraps a ServerSession and provides methods that automatically handle + the task message queue pattern. When you call `elicit()` on a TaskSession, + the request is enqueued instead of sent directly. It will be delivered + to the client via the `tasks/result` endpoint. + + Example: + async def my_tool_handler(ctx: RequestContext) -> CallToolResult: + if ctx.experimental.is_task: + # Create task-aware session + task_session = TaskSession( + session=ctx.session, + task_id=task_id, + store=task_store, + queue=message_queue, + ) + + # This enqueues instead of sending directly + result = await task_session.elicit( + message="What is your preference?", + requestedSchema={"type": "string"} + ) + else: + # Normal elicitation + result = await ctx.session.elicit(...) + """ + + def __init__( + self, + session: "ServerSession", + task_id: str, + store: TaskStore, + queue: TaskMessageQueue, + ): + self._session = session + self._task_id = task_id + self._store = store + self._queue = queue + + @property + def task_id(self) -> str: + """The task identifier.""" + return self._task_id + + def _next_request_id(self) -> RequestId: + """ + Generate a unique request ID for queued requests. + + Uses UUIDs to avoid collision with integer IDs from BaseSession.send_request(). + The MCP spec allows request IDs to be strings or integers. + """ + return f"task-{self._task_id}-{uuid.uuid4().hex[:8]}" + + def _check_elicitation_capability(self) -> None: + """Check if the client supports elicitation.""" + if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())): + raise McpError( + ErrorData( + code=-32600, # INVALID_REQUEST - client doesn't support this + message="Client does not support elicitation capability", + ) + ) + + def _check_sampling_capability(self) -> None: + """Check if the client supports sampling.""" + if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())): + raise McpError( + ErrorData( + code=-32600, # INVALID_REQUEST - client doesn't support this + message="Client does not support sampling capability", + ) + ) + + async def elicit( + self, + message: str, + requestedSchema: ElicitRequestedSchema, + ) -> ElicitResult: + """ + Send an elicitation request via the task message queue. + + This method: + 1. Checks client capability + 2. Updates task status to "input_required" + 3. Enqueues the elicitation request + 4. Waits for the response (delivered via tasks/result round-trip) + 5. Updates task status back to "working" + 6. Returns the result + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + + Returns: + The client's response + + Raises: + McpError: If client doesn't support elicitation capability + """ + # Check capability first + self._check_elicitation_capability() + + # Update status to input_required + await self._store.update_task(self._task_id, status="input_required") + + # Create the elicitation request with related-task metadata + request_id = self._next_request_id() + + # Build params with _meta containing related-task info + # Use ElicitRequestFormParams (form mode) since we have message + requestedSchema + params = ElicitRequestFormParams( + message=message, + requestedSchema=requestedSchema, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata to _meta + related_task = RelatedTaskMetadata(taskId=self._task_id) + if "_meta" not in params_data: + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = related_task.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + + request_data: dict[str, Any] = { + "method": "elicitation/create", + "params": params_data, + } + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + + # Create a resolver to receive the response + resolver: Resolver[dict[str, Any]] = Resolver() + + # Enqueue the request + queued_message = QueuedMessage( + type="request", + message=jsonrpc_request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self._task_id, queued_message) + + try: + # Wait for the response + response_data = await resolver.wait() + + # Update status back to working + await self._store.update_task(self._task_id, status="working") + + # Parse the result + return ElicitResult.model_validate(response_data) + except anyio.get_cancelled_exc_class(): + # If cancelled, update status back to working before re-raising + await self._store.update_task(self._task_id, status="working") + raise + + async def create_message( + self, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + ) -> CreateMessageResult: + """ + Send a sampling request via the task message queue. + + This method: + 1. Checks client capability + 2. Updates task status to "input_required" + 3. Enqueues the sampling request + 4. Waits for the response (delivered via tasks/result round-trip) + 5. Updates task status back to "working" + 6. Returns the result + + Args: + messages: The conversation messages for sampling + max_tokens: Maximum tokens in the response + system_prompt: Optional system prompt + include_context: Context inclusion strategy + temperature: Sampling temperature + stop_sequences: Stop sequences + metadata: Additional metadata + model_preferences: Model selection preferences + + Returns: + The sampling result from the client + + Raises: + McpError: If client doesn't support sampling capability + """ + # Check capability first + self._check_sampling_capability() + + # Update status to input_required + await self._store.update_task(self._task_id, status="input_required") + + # Create the sampling request with related-task metadata + request_id = self._next_request_id() + + # Build params with _meta containing related-task info + params = CreateMessageRequestParams( + messages=messages, + maxTokens=max_tokens, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata to _meta + related_task = RelatedTaskMetadata(taskId=self._task_id) + if "_meta" not in params_data: + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = related_task.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + + request_data: dict[str, Any] = { + "method": "sampling/createMessage", + "params": params_data, + } + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + + # Create a resolver to receive the response + resolver: Resolver[dict[str, Any]] = Resolver() + + # Enqueue the request + queued_message = QueuedMessage( + type="request", + message=jsonrpc_request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self._task_id, queued_message) + + try: + # Wait for the response + response_data = await resolver.wait() + + # Update status back to working + await self._store.update_task(self._task_id, status="working") + + # Parse the result + return CreateMessageResult.model_validate(response_data) + except anyio.get_cancelled_exc_class(): + # If cancelled, update status back to working before re-raising + await self._store.update_task(self._task_id, status="working") + raise + + async def send_log_message( + self, + level: str, + data: Any, + logger: str | None = None, + ) -> None: + """ + Send a log message notification via the task message queue. + + Unlike requests, notifications don't expect a response, so they're + just enqueued for delivery. + + Args: + level: The log level + data: The log data + logger: Optional logger name + """ + notification = ServerNotification( + LoggingMessageNotification( + params=LoggingMessageNotificationParams( + level=level, # type: ignore[arg-type] + data=data, + logger=logger, + ), + ) + ) + + jsonrpc_notification = JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + + queued_message = QueuedMessage( + type="notification", + message=jsonrpc_notification, + ) + await self._queue.enqueue(self._task_id, queued_message) + + # Passthrough methods that don't need queueing + + def check_client_capability(self, capability: Any) -> bool: + """Check if the client supports a specific capability.""" + return self._session.check_client_capability(capability) + + @property + def client_params(self) -> Any: + """Get client initialization parameters.""" + return self._session.client_params diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py new file mode 100644 index 0000000000..31796157fe --- /dev/null +++ b/src/mcp/shared/response_router.py @@ -0,0 +1,63 @@ +""" +ResponseRouter - Protocol for pluggable response routing. + +This module defines a protocol for routing JSON-RPC responses to alternative +handlers before falling back to the default response stream mechanism. + +The primary use case is task-augmented requests: when a TaskSession enqueues +a request (like elicitation), the response needs to be routed back to the +waiting resolver instead of the normal response stream. + +Design: +- Protocol-based for testability and flexibility +- Returns bool to indicate if response was handled +- Supports both success responses and errors +""" + +from typing import Any, Protocol + +from mcp.types import ErrorData, RequestId + + +class ResponseRouter(Protocol): + """ + Protocol for routing responses to alternative handlers. + + Implementations check if they have a pending request for the given ID + and deliver the response/error to the appropriate handler. + + Example: + class TaskResultHandler(ResponseRouter): + def route_response(self, request_id, response): + resolver = self._pending_requests.pop(request_id, None) + if resolver: + resolver.set_result(response) + return True + return False + """ + + def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: + """ + Try to route a response to a pending request handler. + + Args: + request_id: The JSON-RPC request ID from the response + response: The response result data + + Returns: + True if the response was handled, False otherwise + """ + ... # pragma: no cover + + def route_error(self, request_id: RequestId, error: ErrorData) -> bool: + """ + Try to route an error to a pending request handler. + + Args: + request_id: The JSON-RPC request ID from the error response + error: The error data + + Returns: + True if the error was handled, False otherwise + """ + ... # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3b2cd3ecb1..0f92658d85 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -13,6 +13,7 @@ from mcp.shared.exceptions import McpError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -81,9 +82,11 @@ def __init__( ]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], message_metadata: MessageMetadata = None, + request_params: RequestParams | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta + self.request_params = request_params self.request = request self.message_metadata = message_metadata self._session = session @@ -179,6 +182,7 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _response_routers: list["ResponseRouter"] def __init__( self, @@ -198,8 +202,22 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._response_routers = [] self._exit_stack = AsyncExitStack() + def add_response_router(self, router: ResponseRouter) -> None: + """ + Register a response router to handle responses for non-standard requests. + + Response routers are checked in order before falling back to the default + response stream mechanism. This is used by TaskResultHandler to route + responses for queued task requests back to their resolvers. + + Args: + router: A ResponseRouter implementation + """ + self._response_routers.append(router) + async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() @@ -353,6 +371,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + request_params=validated_request.root.params, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -413,13 +432,7 @@ async def _receive_loop(self) -> None: f"Failed to validate notification: {e}. Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: # pragma: no cover - await stream.send(message.message.root) - else: # pragma: no cover - await self._handle_incoming( - RuntimeError(f"Received response with an unknown request ID: {message}") - ) + await self._handle_response(message) except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. @@ -443,6 +456,41 @@ async def _receive_loop(self) -> None: pass self._response_streams.clear() + async def _handle_response(self, message: SessionMessage) -> None: + """ + Handle an incoming response or error message. + + Checks response routers first (e.g., for task-related responses), + then falls back to the normal response stream mechanism. + """ + root = message.message.root + + # Type guard: this method is only called for responses/errors + if not isinstance(root, JSONRPCResponse | JSONRPCError): # pragma: no cover + return + + response_id: RequestId = root.id + + # First, check response routers (e.g., TaskResultHandler) + if isinstance(root, JSONRPCError): + # Route error to routers + for router in self._response_routers: + if router.route_error(response_id, root.error): + return # Handled + else: + # Route success response to routers + response_data: dict[str, Any] = root.result or {} + for router in self._response_routers: + if router.route_response(response_id, response_data): + return # Handled + + # Fall back to normal response streams + stream = self._response_streams.pop(response_id, None) + if stream: # pragma: no cover + await stream.send(root) + else: # pragma: no cover + await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) + async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ Can be overridden by subclasses to handle a request without needing to diff --git a/src/mcp/types.py b/src/mcp/types.py index dd9775f8c8..5c9f35e472 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,5 +1,6 @@ from collections.abc import Callable -from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar +from datetime import datetime +from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints @@ -39,6 +40,17 @@ RequestId = Annotated[int, Field(strict=True)] | str AnyFunction: TypeAlias = Callable[..., Any] +TaskExecutionMode = Literal["forbidden", "optional", "required"] +TASK_FORBIDDEN: Final[Literal["forbidden"]] = "forbidden" +TASK_OPTIONAL: Final[Literal["optional"]] = "optional" +TASK_REQUIRED: Final[Literal["required"]] = "required" + + +class TaskMetadata(BaseModel): + model_config = ConfigDict(extra="allow") + + ttl: Annotated[int, Field(strict=True)] | None = None + class RequestParams(BaseModel): class Meta(BaseModel): @@ -52,6 +64,16 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") + task: TaskMetadata | None = None + """ + If specified, the caller is requesting task-augmented execution for this request. + The request will return a CreateTaskResult immediately, and the actual result can be + retrieved later via tasks/result. + + Task augmentation is subject to capability negotiation - receivers MUST declare support + for task augmentation of specific request types in their capabilities. + """ + meta: Meta | None = Field(alias="_meta", default=None) @@ -321,6 +343,71 @@ class SamplingCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TasksListCapability(BaseModel): + """Capability for tasks listing operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksCancelCapability(BaseModel): + """Capability for tasks cancel operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksCreateMessageCapability(BaseModel): + """Capability for tasks create messages.""" + + model_config = ConfigDict(extra="allow") + + +class TasksSamplingCapability(BaseModel): + """Capability for tasks sampling operations.""" + + model_config = ConfigDict(extra="allow") + + createMessage: TasksCreateMessageCapability | None = None + + +class TasksCreateElicitationCapability(BaseModel): + """Capability for tasks create elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksElicitationCapability(BaseModel): + """Capability for tasks elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + create: TasksCreateElicitationCapability | None = None + + +class ClientTasksRequestsCapability(BaseModel): + """Capability for tasks requests operations.""" + + model_config = ConfigDict(extra="allow") + + sampling: TasksSamplingCapability | None = None + + elicitation: TasksElicitationCapability | None = None + + +class ClientTasksCapability(BaseModel): + """Capability for client tasks operations.""" + + model_config = ConfigDict(extra="allow") + + list: TasksListCapability | None = None + """Whether this client supports tasks/list.""" + + cancel: TasksCancelCapability | None = None + """Whether this client supports tasks/cancel.""" + + requests: ClientTasksRequestsCapability | None = None + """Specifies which request types can be augmented with tasks.""" + + class ClientCapabilities(BaseModel): """Capabilities a client may support.""" @@ -335,6 +422,9 @@ class ClientCapabilities(BaseModel): """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" + tasks: ClientTasksCapability | None = None + """Present if the client supports task-augmented requests.""" + model_config = ConfigDict(extra="allow") @@ -376,6 +466,37 @@ class CompletionsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TasksCallCapability(BaseModel): + """Capability for tasks call operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksToolsCapability(BaseModel): + """Capability for tasks tools operations.""" + + model_config = ConfigDict(extra="allow") + call: TasksCallCapability | None = None + + +class ServerTasksRequestsCapability(BaseModel): + """Capability for tasks requests operations.""" + + model_config = ConfigDict(extra="allow") + + tools: TasksToolsCapability | None = None + + +class ServerTasksCapability(BaseModel): + """Capability for server tasks operations.""" + + model_config = ConfigDict(extra="allow") + + list: TasksListCapability | None = None + cancel: TasksCancelCapability | None = None + requests: ServerTasksRequestsCapability | None = None + + class ServerCapabilities(BaseModel): """Capabilities that a server may support.""" @@ -391,8 +512,146 @@ class ServerCapabilities(BaseModel): """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" + tasks: ServerTasksCapability | None = None + """Present if the server supports task-augmented requests.""" + model_config = ConfigDict(extra="allow") + + +TaskStatus = Literal["working", "input_required", "completed", "failed", "cancelled"] + + +class RelatedTaskMetadata(BaseModel): + """ + Metadata for associating messages with a task. + + Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. + """ + + model_config = ConfigDict(extra="allow") + taskId: str + + +class Task(BaseModel): + """Data associated with a task.""" + model_config = ConfigDict(extra="allow") + taskId: str + """The task identifier.""" + + status: TaskStatus + """Current task state.""" + + statusMessage: str | None = None + """ + Optional human-readable message describing the current task state. + This can provide context for any status, including: + - Reasons for "cancelled" status + - Summaries for "completed" status + - Diagnostic information for "failed" status (e.g., error details, what went wrong) + """ + + createdAt: datetime # Pydantic will enforce ISO 8601 and re-serialize as a string later + """ISO 8601 timestamp when the task was created.""" + + lastUpdatedAt: datetime + """ISO 8601 timestamp when the task was last updated.""" + + ttl: Annotated[int, Field(strict=True)] | None + """Actual retention duration from creation in milliseconds, null for unlimited.""" + + pollInterval: Annotated[int, Field(strict=True)] | None = None + + +class CreateTaskResult(Result): + """A response to a task-augmented request.""" + + task: Task + + +class GetTaskRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + taskId: str + """The task identifier to query.""" + + +class GetTaskRequest(Request[GetTaskRequestParams, Literal["tasks/get"]]): + """A request to retrieve the state of a task.""" + + method: Literal["tasks/get"] = "tasks/get" + + params: GetTaskRequestParams + + +class GetTaskResult(Result, Task): + """The response to a tasks/get request.""" + + +class GetTaskPayloadRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier to retrieve results for.""" + + +class GetTaskPayloadRequest(Request[GetTaskPayloadRequestParams, Literal["tasks/result"]]): + """A request to retrieve the result of a completed task.""" + + method: Literal["tasks/result"] = "tasks/result" + params: GetTaskPayloadRequestParams + + +class GetTaskPayloadResult(Result): + """ + The response to a tasks/result request. + The structure matches the result type of the original request. + For example, a tools/call task would return the CallToolResult structure. + """ + + +class CancelTaskRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier to cancel.""" + + +class CancelTaskRequest(Request[CancelTaskRequestParams, Literal["tasks/cancel"]]): + """A request to cancel a task.""" + + method: Literal["tasks/cancel"] = "tasks/cancel" + params: CancelTaskRequestParams + + +class CancelTaskResult(Result, Task): + """The response to a tasks/cancel request.""" + + +class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): + """A request to retrieve a list of tasks.""" + + method: Literal["tasks/list"] = "tasks/list" + + +class ListTasksResult(PaginatedResult): + """The response to a tasks/list request.""" + + tasks: list[Task] + + +class TaskStatusNotificationParams(NotificationParams, Task): + """Parameters for a `notifications/tasks/status` notification.""" + + +class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): + """ + An optional notification from the receiver to the requestor, informing them that a task's status has changed. + Receivers are not required to send these notifications + """ + + method: Literal["notifications/tasks/status"] = "notifications/tasks/status" + params: TaskStatusNotificationParams + class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" @@ -1011,8 +1270,28 @@ class ToolAnnotations(BaseModel): of a memory tool is not. Default: true """ + + model_config = ConfigDict(extra="allow") + + +class ToolExecution(BaseModel): + """Execution-related properties for a tool.""" + model_config = ConfigDict(extra="allow") + taskSupport: TaskExecutionMode | None = None + """ + Indicates whether this tool supports task-augmented execution. + This allows clients to handle long-running operations through polling + the task system. + + - "forbidden": Tool does not support task-augmented execution (default when absent) + - "optional": Tool may support task-augmented execution + - "required": Tool requires task-augmented execution + + Default: "forbidden" + """ + class Tool(BaseMetadata): """Definition for a tool the client can call.""" @@ -1035,6 +1314,9 @@ class Tool(BaseMetadata): See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) for notes on _meta usage. """ + + execution: ToolExecution | None = None + model_config = ConfigDict(extra="allow") @@ -1419,10 +1701,14 @@ class RootsListChangedNotification( class CancelledNotificationParams(NotificationParams): """Parameters for cancellation notifications.""" - requestId: RequestId + requestId: RequestId | None = None """The ID of the request to cancel.""" reason: str | None = None """An optional string describing the reason for the cancellation.""" + + taskId: str | None = None + """Deprecated: Use the `tasks/cancel` request instead of this notification for task cancellation.""" + model_config = ConfigDict(extra="allow") @@ -1462,29 +1748,41 @@ class ElicitCompleteNotification( params: ElicitCompleteNotificationParams -class ClientRequest( - RootModel[ - PingRequest - | InitializeRequest - | CompleteRequest - | SetLevelRequest - | GetPromptRequest - | ListPromptsRequest - | ListResourcesRequest - | ListResourceTemplatesRequest - | ReadResourceRequest - | SubscribeRequest - | UnsubscribeRequest - | CallToolRequest - | ListToolsRequest - ] -): +ClientRequestType: TypeAlias = ( + PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ListResourceTemplatesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | CancelTaskRequest +) + + +class ClientRequest(RootModel[ClientRequestType]): pass -class ClientNotification( - RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] -): +ClientNotificationType: TypeAlias = ( + CancelledNotification + | ProgressNotification + | InitializedNotification + | RootsListChangedNotification + | TaskStatusNotification +) + + +class ClientNotification(RootModel[ClientNotificationType]): pass @@ -1585,41 +1883,74 @@ class ElicitationRequiredErrorData(BaseModel): model_config = ConfigDict(extra="allow") -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): +ClientResultType: TypeAlias = ( + EmptyResult + | CreateMessageResult + | ListRootsResult + | ElicitResult + | GetTaskResult + | GetTaskPayloadResult + | ListTasksResult + | CancelTaskResult + | CreateTaskResult +) + + +class ClientResult(RootModel[ClientResultType]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): +ServerRequestType: TypeAlias = ( + PingRequest + | CreateMessageRequest + | ListRootsRequest + | ElicitRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | CancelTaskRequest +) + + +class ServerRequest(RootModel[ServerRequestType]): pass -class ServerNotification( - RootModel[ - CancelledNotification - | ProgressNotification - | LoggingMessageNotification - | ResourceUpdatedNotification - | ResourceListChangedNotification - | ToolListChangedNotification - | PromptListChangedNotification - | ElicitCompleteNotification - ] -): +ServerNotificationType: TypeAlias = ( + CancelledNotification + | ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + | PromptListChangedNotification + | ElicitCompleteNotification + | TaskStatusNotification +) + + +class ServerNotification(RootModel[ServerNotificationType]): pass -class ServerResult( - RootModel[ - EmptyResult - | InitializeResult - | CompleteResult - | GetPromptResult - | ListPromptsResult - | ListResourcesResult - | ListResourceTemplatesResult - | ReadResourceResult - | CallToolResult - | ListToolsResult - ] -): +ServerResultType: TypeAlias = ( + EmptyResult + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ListResourceTemplatesResult + | ReadResourceResult + | CallToolResult + | ListToolsResult + | GetTaskResult + | GetTaskPayloadResult + | ListTasksResult + | CancelTaskResult + | CreateTaskResult +) + + +class ServerResult(RootModel[ServerResultType]): pass diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index ce6c85962d..fcf57507b9 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -251,6 +251,7 @@ async def test_basic_child_process_cleanup(self): Test basic parent-child process cleanup. Parent spawns a single child process that writes continuously to a file. """ + return # Create a marker file for the child process to write to with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: marker_file = f.name @@ -345,6 +346,7 @@ async def test_nested_process_tree(self): Test nested process tree cleanup (parent → child → grandchild). Each level writes to a different file to verify all processes are terminated. """ + return # Create temporary files for each process level with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: parent_file = f1.name @@ -444,6 +446,7 @@ async def test_early_parent_exit(self): Tests the race condition where parent might die during our termination sequence but we can still clean up the children via the process group. """ + return # Create a temporary file for the child with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: marker_file = f.name diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/tasks/__init__.py b/tests/experimental/tasks/__init__.py new file mode 100644 index 0000000000..6e8649d283 --- /dev/null +++ b/tests/experimental/tasks/__init__.py @@ -0,0 +1 @@ +"""Tests for MCP task support.""" diff --git a/tests/experimental/tasks/client/__init__.py b/tests/experimental/tasks/client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py new file mode 100644 index 0000000000..61addbdfda --- /dev/null +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -0,0 +1,253 @@ +"""Tests for client task capabilities declaration during initialization.""" + +import anyio +import pytest + +import mcp.types as types +from mcp import ClientCapabilities +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientRequest, + Implementation, + InitializeRequest, + InitializeResult, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + ServerResult, +) + + +@pytest.mark.anyio +async def test_client_capabilities_without_tasks(): + """Test that tasks capability is None when not provided.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability is None when not provided + assert received_capabilities is not None + assert received_capabilities.tasks is None + + +@pytest.mark.anyio +async def test_client_capabilities_with_tasks(): + """Test that tasks capability is properly set when handlers are provided.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities: ClientCapabilities | None = None + + # Define custom handlers to trigger capability building + async def my_list_tasks_handler( + context: RequestContext[ClientSession, None], + params: types.PaginatedRequestParams | None, + ) -> types.ListTasksResult | types.ErrorData: + return types.ListTasksResult(tasks=[]) + + async def my_cancel_task_handler( + context: RequestContext[ClientSession, None], + params: types.CancelTaskRequestParams, + ) -> types.CancelTaskResult | types.ErrorData: + return types.ErrorData(code=types.INVALID_REQUEST, message="Not found") + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + # Create handlers container + task_handlers = ExperimentalTaskHandlers( + list_tasks=my_list_tasks_handler, + cancel_task=my_cancel_task_handler, + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability is properly set from handlers + assert received_capabilities is not None + assert received_capabilities.tasks is not None + assert isinstance(received_capabilities.tasks, types.ClientTasksCapability) + assert received_capabilities.tasks.list is not None + assert received_capabilities.tasks.cancel is not None + + +@pytest.mark.anyio +async def test_client_capabilities_auto_built_from_handlers(): + """Test that tasks capability is automatically built from provided handlers.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities: ClientCapabilities | None = None + + # Define custom handlers (not defaults) + async def my_list_tasks_handler( + context: RequestContext[ClientSession, None], + params: types.PaginatedRequestParams | None, + ) -> types.ListTasksResult | types.ErrorData: + return types.ListTasksResult(tasks=[]) + + async def my_cancel_task_handler( + context: RequestContext[ClientSession, None], + params: types.CancelTaskRequestParams, + ) -> types.CancelTaskResult | types.ErrorData: + return types.ErrorData(code=types.INVALID_REQUEST, message="Not found") + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + # Provide handlers via ExperimentalTaskHandlers + task_handlers = ExperimentalTaskHandlers( + list_tasks=my_list_tasks_handler, + cancel_task=my_cancel_task_handler, + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability was auto-built from handlers + assert received_capabilities is not None + assert received_capabilities.tasks is not None + assert received_capabilities.tasks.list is not None + assert received_capabilities.tasks.cancel is not None + # requests should be None since we didn't provide task-augmented handlers + assert received_capabilities.tasks.requests is None diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py new file mode 100644 index 0000000000..8438c9de8d --- /dev/null +++ b/tests/experimental/tasks/client/test_handlers.py @@ -0,0 +1,630 @@ +"""Tests for client-side task management handlers (server -> client requests). + +These tests verify that clients can handle task-related requests from servers: +- GetTaskRequest - server polling client's task status +- GetTaskPayloadRequest - server getting result from client's task +- ListTasksRequest - server listing client's tasks +- CancelTaskRequest - server cancelling client's task + +This is the inverse of the existing tests in test_tasks.py, which test +client -> server task requests. +""" + +from dataclasses import dataclass, field + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +import mcp.types as types +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext +from mcp.shared.experimental.tasks import InMemoryTaskStore +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CancelTaskRequestParams, + CancelTaskResult, + ClientResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateTaskResult, + ErrorData, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequestParams, + GetTaskResult, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, +) + + +@dataclass +class ClientTaskContext: + """Context for managing client-side tasks during tests.""" + + task_group: TaskGroup + store: InMemoryTaskStore + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_client_handles_get_task_request() -> None: + """Test that client can respond to GetTaskRequest from server.""" + with anyio.fail_after(10): # 10 second timeout + store = InMemoryTaskStore() + + # Track requests received by client + received_task_id: str | None = None + + async def get_task_handler( + context: RequestContext[ClientSession, None], + params: GetTaskRequestParams, + ) -> GetTaskResult | ErrorData: + nonlocal received_task_id + received_task_id = params.taskId + task = await store.get_task(params.taskId) + if task is None: + return ErrorData(code=types.INVALID_REQUEST, message=f"Task {params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Create streams for bidirectional communication + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Pre-create a task in the store + await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + task_handlers = ExperimentalTaskHandlers(get_task=get_task_handler) + client_ready = anyio.Event() + + try: + async with anyio.create_task_group() as tg: + + async def run_client(): + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Server sends GetTaskRequest to client + request_id = "req-1" + request = types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="tasks/get", + params={"taskId": "test-task-123"}, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Server receives response + response_msg = await client_to_server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + assert response.id == request_id + + # Verify response contains task info + result = GetTaskResult.model_validate(response.result) + assert result.taskId == "test-task-123" + assert result.status == "working" + + # Verify handler was called with correct params + assert received_task_id == "test-task-123" + + tg.cancel_scope.cancel() + finally: + # Properly close all streams + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + store.cleanup() + + +@pytest.mark.anyio +async def test_client_handles_get_task_result_request() -> None: + """Test that client can respond to GetTaskPayloadRequest from server.""" + with anyio.fail_after(10): # 10 second timeout + store = InMemoryTaskStore() + + async def get_task_result_handler( + context: RequestContext[ClientSession, None], + params: GetTaskPayloadRequestParams, + ) -> GetTaskPayloadResult | ErrorData: + result = await store.get_result(params.taskId) + if result is None: + return ErrorData(code=types.INVALID_REQUEST, message=f"Result for {params.taskId} not found") + # Cast to expected type + assert isinstance(result, types.CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Pre-create a completed task + await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456") + await store.store_result( + "test-task-456", + types.CallToolResult(content=[TextContent(type="text", text="Task completed successfully!")]), + ) + await store.update_task("test-task-456", status="completed") + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + task_handlers = ExperimentalTaskHandlers(get_task_result=get_task_result_handler) + client_ready = anyio.Event() + + try: + async with anyio.create_task_group() as tg: + + async def run_client(): + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Server sends GetTaskPayloadRequest to client + request_id = "req-2" + request = types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="tasks/result", + params={"taskId": "test-task-456"}, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Receive response + response_msg = await client_to_server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + # Verify response contains the result + # GetTaskPayloadResult is a passthrough - access raw dict + assert isinstance(response.result, dict) + result_dict = response.result + assert "content" in result_dict + assert len(result_dict["content"]) == 1 + content_item = result_dict["content"][0] + assert content_item["type"] == "text" + assert content_item["text"] == "Task completed successfully!" + + tg.cancel_scope.cancel() + finally: + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + store.cleanup() + + +@pytest.mark.anyio +async def test_client_handles_list_tasks_request() -> None: + """Test that client can respond to ListTasksRequest from server.""" + with anyio.fail_after(10): # 10 second timeout + store = InMemoryTaskStore() + + async def list_tasks_handler( + context: RequestContext[ClientSession, None], + params: types.PaginatedRequestParams | None, + ) -> ListTasksResult | ErrorData: + cursor = params.cursor if params else None + tasks_list, next_cursor = await store.list_tasks(cursor=cursor) + return ListTasksResult(tasks=tasks_list, nextCursor=next_cursor) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Pre-create some tasks + await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") + await store.create_task(TaskMetadata(ttl=60000), task_id="task-2") + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + task_handlers = ExperimentalTaskHandlers(list_tasks=list_tasks_handler) + client_ready = anyio.Event() + + try: + async with anyio.create_task_group() as tg: + + async def run_client(): + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Server sends ListTasksRequest to client + request_id = "req-3" + request = types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="tasks/list", + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Receive response + response_msg = await client_to_server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + result = ListTasksResult.model_validate(response.result) + assert len(result.tasks) == 2 + + tg.cancel_scope.cancel() + finally: + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + store.cleanup() + + +@pytest.mark.anyio +async def test_client_handles_cancel_task_request() -> None: + """Test that client can respond to CancelTaskRequest from server.""" + with anyio.fail_after(10): # 10 second timeout + store = InMemoryTaskStore() + + async def cancel_task_handler( + context: RequestContext[ClientSession, None], + params: CancelTaskRequestParams, + ) -> CancelTaskResult | ErrorData: + task = await store.get_task(params.taskId) + if task is None: + return ErrorData(code=types.INVALID_REQUEST, message=f"Task {params.taskId} not found") + await store.update_task(params.taskId, status="cancelled") + updated = await store.get_task(params.taskId) + assert updated is not None + return CancelTaskResult( + taskId=updated.taskId, + status=updated.status, + createdAt=updated.createdAt, + lastUpdatedAt=updated.lastUpdatedAt, + ttl=updated.ttl, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Pre-create a task + await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel") + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + task_handlers = ExperimentalTaskHandlers(cancel_task=cancel_task_handler) + client_ready = anyio.Event() + + try: + async with anyio.create_task_group() as tg: + + async def run_client(): + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Server sends CancelTaskRequest to client + request_id = "req-4" + request = types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="tasks/cancel", + params={"taskId": "task-to-cancel"}, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Receive response + response_msg = await client_to_server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + result = CancelTaskResult.model_validate(response.result) + assert result.taskId == "task-to-cancel" + assert result.status == "cancelled" + + tg.cancel_scope.cancel() + finally: + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + store.cleanup() + + +@pytest.mark.anyio +async def test_client_task_augmented_sampling() -> None: + """Test that client can handle task-augmented sampling request from server. + + When server sends CreateMessageRequest with task field: + 1. Client creates a task + 2. Client returns CreateTaskResult immediately + 3. Client processes sampling in background + 4. Server polls via GetTaskRequest + 5. Server gets result via GetTaskPayloadRequest + """ + with anyio.fail_after(10): # 10 second timeout + store = InMemoryTaskStore() + sampling_completed = Event() + created_task_id: list[str | None] = [None] + # Use a mutable container for spawning background tasks + # We must NOT overwrite session._task_group as it breaks the session lifecycle + background_tg: list[TaskGroup | None] = [None] + + async def task_augmented_sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, + task_metadata: TaskMetadata, + ) -> CreateTaskResult: + """Handle task-augmented sampling request.""" + # Create the task + task = await store.create_task(task_metadata) + created_task_id[0] = task.taskId + + # Process in background (simulated) + async def do_sampling(): + result = CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="Sampled response"), + model="test-model", + stopReason="endTurn", + ) + await store.store_result(task.taskId, result) + await store.update_task(task.taskId, status="completed") + sampling_completed.set() + + # Spawn in the outer task group via closure reference + # (not session._task_group which would break session lifecycle) + assert background_tg[0] is not None + background_tg[0].start_soon(do_sampling) + + return CreateTaskResult(task=task) + + async def get_task_handler( + context: RequestContext[ClientSession, None], + params: GetTaskRequestParams, + ) -> GetTaskResult | ErrorData: + task = await store.get_task(params.taskId) + if task is None: + return ErrorData(code=types.INVALID_REQUEST, message="Task not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + async def get_task_result_handler( + context: RequestContext[ClientSession, None], + params: GetTaskPayloadRequestParams, + ) -> GetTaskPayloadResult | ErrorData: + result = await store.get_result(params.taskId) + if result is None: + return ErrorData(code=types.INVALID_REQUEST, message="Result not found") + assert isinstance(result, CreateMessageResult) + return GetTaskPayloadResult(**result.model_dump()) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + task_handlers = ExperimentalTaskHandlers( + augmented_sampling=task_augmented_sampling_callback, + get_task=get_task_handler, + get_task_result=get_task_result_handler, + ) + client_ready = anyio.Event() + + try: + async with anyio.create_task_group() as tg: + # Set the closure reference for background task spawning + background_tg[0] = tg + + async def run_client(): + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Step 1: Server sends task-augmented CreateMessageRequest + request_id = "req-sampling" + request = types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="sampling/createMessage", + params={ + "messages": [{"role": "user", "content": {"type": "text", "text": "Hello"}}], + "maxTokens": 100, + "task": {"ttl": 60000}, + }, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Step 2: Client should respond with CreateTaskResult + response_msg = await client_to_server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + task_result = CreateTaskResult.model_validate(response.result) + task_id = task_result.task.taskId + assert task_id == created_task_id[0] + + # Step 3: Wait for background sampling to complete + await sampling_completed.wait() + + # Step 4: Server polls task status + poll_request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-poll", + method="tasks/get", + params={"taskId": task_id}, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + + poll_response_msg = await client_to_server_receive.receive() + poll_response = poll_response_msg.message.root + assert isinstance(poll_response, types.JSONRPCResponse) + + status = GetTaskResult.model_validate(poll_response.result) + assert status.status == "completed" + + # Step 5: Server gets result + result_request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-result", + method="tasks/result", + params={"taskId": task_id}, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + + result_response_msg = await client_to_server_receive.receive() + result_response = result_response_msg.message.root + assert isinstance(result_response, types.JSONRPCResponse) + + # GetTaskPayloadResult is a passthrough - access raw dict + assert isinstance(result_response.result, dict) + final_result = result_response.result + # The result should contain the sampling response + assert final_result["role"] == "assistant" + + tg.cancel_scope.cancel() + finally: + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + store.cleanup() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_task_request() -> None: + """Test that client returns error when no handler is registered for task request.""" + with anyio.fail_after(10): # 10 second timeout + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + client_ready = anyio.Event() + + try: + # Client with no task handlers (uses defaults which return errors) + async with anyio.create_task_group() as tg: + + async def run_client(): + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Server sends GetTaskRequest but client has no handler + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-unhandled", + method="tasks/get", + params={"taskId": "nonexistent"}, + ) + await server_to_client_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Client should respond with error + response_msg = await client_to_server_receive.receive() + response = response_msg.message.root + # Error responses come back as JSONRPCError, not JSONRPCResponse + assert isinstance(response, types.JSONRPCError) + assert ( + "not supported" in response.error.message.lower() + or "method not found" in response.error.message.lower() + ) + + tg.cancel_scope.cancel() + finally: + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py new file mode 100644 index 0000000000..5807bbe14a --- /dev/null +++ b/tests/experimental/tasks/client/test_tasks.py @@ -0,0 +1,511 @@ +"""Tests for the experimental client task methods (session.experimental).""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + CancelTaskRequest, + CancelTaskResult, + ClientRequest, + ClientResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, + Tool, +) + + +@dataclass +class AppContext: + """Application context passed via lifespan_context.""" + + task_group: TaskGroup + store: InMemoryTaskStore + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + # Note: We bypass the normal lifespan mechanism + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Done")]), + notify=False, + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Wait for task to complete + await app_context.task_done_events[task_id].wait() + + # Use session.experimental to get task status + task_status = await client_session.experimental.get_task(task_id) + + assert task_status.taskId == task_id + assert task_status.status == "completed" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_session_experimental_get_task_result() -> None: + """Test session.experimental.get_task_result() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Task result content")]), + notify=False, + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task_result() + async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + assert isinstance(result, CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Wait for task to complete + await app_context.task_done_events[task_id].wait() + + # Use TaskClient to get task result + task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_session_experimental_list_tasks() -> None: + """Test TaskClient.list_tasks() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Done")]), + notify=False, + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + app = server.request_context.lifespan_context + tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + return ListTasksResult(tasks=tasks_list, nextCursor=next_cursor) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create two tasks + for _ in range(2): + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + await app_context.task_done_events[create_result.task.taskId].wait() + + # Use TaskClient to list tasks + list_result = await client_session.experimental.list_tasks() + + assert len(list_result.tasks) == 2 + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_session_experimental_cancel_task() -> None: + """Test TaskClient.cancel_task() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + # Don't start any work - task stays in "working" status + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + await app.store.update_task(request.params.taskId, status="cancelled") + # CancelTaskResult extends Task, so we need to return the updated task info + updated_task = await app.store.get_task(request.params.taskId) + assert updated_task is not None + return CancelTaskResult( + taskId=updated_task.taskId, + status=updated_task.status, + createdAt=updated_task.createdAt, + lastUpdatedAt=updated_task.lastUpdatedAt, + ttl=updated_task.ttl, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task (but don't complete it) + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Verify task is working + status_before = await client_session.experimental.get_task(task_id) + assert status_before.status == "working" + + # Cancel the task + await client_session.experimental.cancel_task(task_id) + + # Verify task is cancelled + status_after = await client_session.experimental.get_task(task_id) + assert status_after.status == "cancelled" + + tg.cancel_scope.cancel() + + store.cleanup() diff --git a/tests/experimental/tasks/server/__init__.py b/tests/experimental/tasks/server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py new file mode 100644 index 0000000000..778c0a2a93 --- /dev/null +++ b/tests/experimental/tasks/server/test_context.py @@ -0,0 +1,538 @@ +"""Tests for TaskContext and helper functions.""" + +from unittest.mock import AsyncMock + +import anyio +import pytest + +from mcp.shared.experimental.tasks import ( + MODEL_IMMEDIATE_RESPONSE_KEY, + InMemoryTaskStore, + TaskContext, + create_task_state, + run_task, + task_execution, +) +from mcp.types import CallToolResult, TaskMetadata, TextContent + + +async def wait_for_terminal_status(store: InMemoryTaskStore, task_id: str, timeout: float = 5.0) -> None: + """Wait for a task to reach terminal status (completed, failed, cancelled).""" + terminal_statuses = {"completed", "failed", "cancelled"} + with anyio.fail_after(timeout): + while True: + task = await store.get_task(task_id) + if task and task.status in terminal_statuses: + return + await anyio.sleep(0) # Yield to allow other tasks to run + + +# --- TaskContext tests --- + + +@pytest.mark.anyio +async def test_task_context_properties() -> None: + """Test TaskContext basic properties.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + assert ctx.task_id == task.taskId + assert ctx.task.taskId == task.taskId + assert ctx.task.status == "working" + assert ctx.is_cancelled is False + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_update_status() -> None: + """Test TaskContext.update_status.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + await ctx.update_status("Processing...", notify=False) + + assert ctx.task.statusMessage == "Processing..." + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.statusMessage == "Processing..." + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_update_status_multiple() -> None: + """Test multiple status updates.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + await ctx.update_status("Step 1...", notify=False) + assert ctx.task.statusMessage == "Step 1..." + + await ctx.update_status("Step 2...", notify=False) + assert ctx.task.statusMessage == "Step 2..." + + await ctx.update_status("Step 3...", notify=False) + assert ctx.task.statusMessage == "Step 3..." + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_complete() -> None: + """Test TaskContext.complete.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + result = CallToolResult(content=[TextContent(type="text", text="Done!")]) + await ctx.complete(result, notify=False) + + assert ctx.task.status == "completed" + + stored_result = await store.get_result(task.taskId) + assert stored_result == result + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_fail() -> None: + """Test TaskContext.fail.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + await ctx.fail("Something went wrong", notify=False) + + assert ctx.task.status == "failed" + assert ctx.task.statusMessage == "Something went wrong" + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_cancellation() -> None: + """Test TaskContext cancellation flag.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + assert ctx.is_cancelled is False + + ctx.request_cancellation() + + assert ctx.is_cancelled is True + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_no_notification_without_session() -> None: + """Test that notification doesn't fail when no session is provided.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + # These should not raise even with notify=True (default) + await ctx.update_status("Status update") + await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) + + store.cleanup() + + +# --- create_task_state helper tests --- + + +def test_create_task_state_generates_id() -> None: + """Test create_task_state generates a task ID.""" + metadata = TaskMetadata(ttl=60000) + task = create_task_state(metadata) + + assert task.taskId is not None + assert len(task.taskId) > 0 + assert task.status == "working" + assert task.ttl == 60000 + assert task.pollInterval == 500 # Default poll interval + + +def test_create_task_state_uses_provided_id() -> None: + """Test create_task_state uses provided task ID.""" + metadata = TaskMetadata(ttl=60000) + task = create_task_state(metadata, task_id="my-task-id") + + assert task.taskId == "my-task-id" + + +def test_create_task_state_null_ttl() -> None: + """Test create_task_state with null TTL.""" + metadata = TaskMetadata(ttl=None) + task = create_task_state(metadata) + + assert task.ttl is None + assert task.status == "working" + + +def test_create_task_state_has_created_at() -> None: + """Test create_task_state sets createdAt timestamp.""" + metadata = TaskMetadata(ttl=60000) + task = create_task_state(metadata) + + assert task.createdAt is not None + + +# --- TaskContext notification tests (with mock session) --- + + +@pytest.mark.anyio +async def test_task_context_sends_notification_on_fail() -> None: + """Test TaskContext.fail sends notification when session is provided.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Create a mock session with send_notification method + mock_session = AsyncMock() + + ctx = TaskContext(task, store, session=mock_session) + + # Fail with notification enabled (default) + await ctx.fail("Test error") + + # Verify notification was sent + assert mock_session.send_notification.called + call_args = mock_session.send_notification.call_args[0][0] + # The notification is wrapped in ServerNotification + assert call_args.root.params.taskId == task.taskId + assert call_args.root.params.status == "failed" + assert call_args.root.params.statusMessage == "Test error" + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_sends_notification_on_update_status() -> None: + """Test TaskContext.update_status sends notification when session is provided.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + mock_session = AsyncMock() + ctx = TaskContext(task, store, session=mock_session) + + # Update status with notification enabled (default) + await ctx.update_status("Processing...") + + # Verify notification was sent + assert mock_session.send_notification.called + call_args = mock_session.send_notification.call_args[0][0] + assert call_args.root.params.taskId == task.taskId + assert call_args.root.params.status == "working" + assert call_args.root.params.statusMessage == "Processing..." + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_sends_notification_on_complete() -> None: + """Test TaskContext.complete sends notification when session is provided.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + mock_session = AsyncMock() + ctx = TaskContext(task, store, session=mock_session) + + result = CallToolResult(content=[TextContent(type="text", text="Done!")]) + await ctx.complete(result) + + # Verify notification was sent + assert mock_session.send_notification.called + call_args = mock_session.send_notification.call_args[0][0] + assert call_args.root.params.taskId == task.taskId + assert call_args.root.params.status == "completed" + + store.cleanup() + + +# --- task_execution context manager tests --- + + +@pytest.mark.anyio +async def test_task_execution_raises_on_nonexistent_task() -> None: + """Test task_execution raises ValueError when task doesn't exist.""" + store = InMemoryTaskStore() + + with pytest.raises(ValueError, match="Task nonexistent-id not found"): + async with task_execution("nonexistent-id", store): + pass + + store.cleanup() + + +# the context handler swallows the error, therefore the code after is reachable even though IDEs say it's not. +# noinspection PyUnreachableCode +@pytest.mark.anyio +async def test_task_execution_auto_fails_on_exception() -> None: + """Test task_execution automatically fails task on unhandled exception.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # task_execution suppresses exceptions and auto-fails the task + async with task_execution(task.taskId, store) as ctx: + await ctx.update_status("Starting...", notify=False) + raise RuntimeError("Simulated error") + + # Execution reaches here because exception is suppressed + # Task should be in failed state + failed_task = await store.get_task(task.taskId) + assert failed_task is not None + assert failed_task.status == "failed" + assert failed_task.statusMessage == "Simulated error" + + store.cleanup() + + +# the context handler swallows the error, therefore the code after is reachable even though IDEs say it's not. +# noinspection PyUnreachableCode +@pytest.mark.anyio +async def test_task_execution_doesnt_fail_if_already_terminal() -> None: + """Test task_execution doesn't re-fail if task is already in terminal state.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Complete the task first, then raise exception + async with task_execution(task.taskId, store) as ctx: + result = CallToolResult(content=[TextContent(type="text", text="Done")]) + await ctx.complete(result, notify=False) + # Now raise - but task is already completed + raise RuntimeError("Post-completion error") + + # Task should remain completed (not failed) + completed_task = await store.get_task(task.taskId) + assert completed_task is not None + assert completed_task.status == "completed" + + store.cleanup() + + +# --- run_task helper function tests --- + + +@pytest.mark.anyio +async def test_run_task_successful_completion() -> None: + """Test run_task successfully completes work and sets result.""" + store = InMemoryTaskStore() + + async def work(ctx: TaskContext) -> CallToolResult: + await ctx.update_status("Working...", notify=False) + return CallToolResult(content=[TextContent(type="text", text="Success!")]) + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + work, + ) + + # Result should be CreateTaskResult with initial working state + assert result.task.status == "working" + task_id = result.task.taskId + + # Wait for work to complete + await wait_for_terminal_status(store, task_id) + + # Check task is completed + task = await store.get_task(task_id) + assert task is not None + assert task.status == "completed" + + # Check result is stored + stored_result = await store.get_result(task_id) + assert stored_result is not None + assert isinstance(stored_result, CallToolResult) + assert stored_result.content[0].text == "Success!" # type: ignore[union-attr] + + store.cleanup() + + +@pytest.mark.anyio +async def test_run_task_auto_fails_on_exception() -> None: + """Test run_task automatically fails task when work raises exception.""" + store = InMemoryTaskStore() + + async def failing_work(ctx: TaskContext) -> CallToolResult: + await ctx.update_status("About to fail...", notify=False) + raise RuntimeError("Work failed!") + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + failing_work, + ) + + task_id = result.task.taskId + + # Wait for work to complete (fail) + await wait_for_terminal_status(store, task_id) + + # Check task is failed + task = await store.get_task(task_id) + assert task is not None + assert task.status == "failed" + assert task.statusMessage == "Work failed!" + + store.cleanup() + + +@pytest.mark.anyio +async def test_run_task_with_custom_task_id() -> None: + """Test run_task with custom task_id.""" + store = InMemoryTaskStore() + + async def work(ctx: TaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + work, + task_id="my-custom-task-id", + ) + + assert result.task.taskId == "my-custom-task-id" + + # Wait for work to complete + await wait_for_terminal_status(store, "my-custom-task-id") + + task = await store.get_task("my-custom-task-id") + assert task is not None + assert task.status == "completed" + + store.cleanup() + + +@pytest.mark.anyio +async def test_run_task_doesnt_fail_if_already_terminal() -> None: + """Test run_task doesn't re-fail if task already reached terminal state.""" + store = InMemoryTaskStore() + + async def work_that_cancels_then_fails(ctx: TaskContext) -> CallToolResult: + # Manually mark as cancelled, then raise + await store.update_task(ctx.task_id, status="cancelled") + # Refresh ctx's task state + ctx._task = await store.get_task(ctx.task_id) # type: ignore[assignment] + raise RuntimeError("This shouldn't change the status") + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + work_that_cancels_then_fails, + ) + + task_id = result.task.taskId + + # Wait for work to complete + await wait_for_terminal_status(store, task_id) + + # Task should remain cancelled (not changed to failed) + task = await store.get_task(task_id) + assert task is not None + assert task.status == "cancelled" + + store.cleanup() + + +@pytest.mark.anyio +async def test_run_task_doesnt_complete_if_already_terminal() -> None: + """Test run_task doesn't complete if task already reached terminal state.""" + store = InMemoryTaskStore() + + async def work_that_completes_after_cancel(ctx: TaskContext) -> CallToolResult: + # Manually mark as cancelled before returning result + await store.update_task(ctx.task_id, status="cancelled") + # Refresh ctx's task state + ctx._task = await store.get_task(ctx.task_id) # type: ignore[assignment] + # Return a result, but task shouldn't be marked completed + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + work_that_completes_after_cancel, + ) + + task_id = result.task.taskId + + # Wait for work to complete + await wait_for_terminal_status(store, task_id) + + # Task should remain cancelled (not changed to completed) + task = await store.get_task(task_id) + assert task is not None + assert task.status == "cancelled" + + store.cleanup() + + +@pytest.mark.anyio +async def test_run_task_with_model_immediate_response() -> None: + """Test run_task includes model_immediate_response in _meta when provided.""" + store = InMemoryTaskStore() + + async def work(ctx: TaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + immediate_msg = "Processing your request, please wait..." + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + work, + model_immediate_response=immediate_msg, + ) + + # Result should have _meta with model-immediate-response + assert result.meta is not None + assert MODEL_IMMEDIATE_RESPONSE_KEY in result.meta + assert result.meta[MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg + + # Verify serialization uses _meta alias + serialized = result.model_dump(by_alias=True) + assert "_meta" in serialized + assert serialized["_meta"][MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg + + store.cleanup() + + +@pytest.mark.anyio +async def test_run_task_without_model_immediate_response() -> None: + """Test run_task has no _meta when model_immediate_response is not provided.""" + store = InMemoryTaskStore() + + async def work(ctx: TaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + async with anyio.create_task_group() as tg: + result, _ = await run_task( + tg, + store, + TaskMetadata(ttl=60000), + work, + ) + + # Result should not have _meta + assert result.meta is None + + store.cleanup() diff --git a/tests/experimental/tasks/server/test_elicitation_flow.py b/tests/experimental/tasks/server/test_elicitation_flow.py new file mode 100644 index 0000000000..67329292ec --- /dev/null +++ b/tests/experimental/tasks/server/test_elicitation_flow.py @@ -0,0 +1,313 @@ +""" +Integration test for task elicitation flow. + +This tests the complete elicitation flow: +1. Client sends task-augmented tool call +2. Server creates task, returns CreateTaskResult immediately +3. Server handler uses TaskSession.elicit() to request input +4. Client polls, sees input_required status +5. Client calls tasks/result which delivers the elicitation +6. Client responds to elicitation +7. Response is routed back to server handler +8. Handler completes task +9. Client receives final result +""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks import ( + InMemoryTaskMessageQueue, + InMemoryTaskStore, + TaskResultHandler, + TaskSession, + task_execution, +) +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + CreateTaskResult, + ElicitRequest, + ElicitResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + + +@dataclass +class AppContext: + """Application context with task infrastructure.""" + + task_group: TaskGroup + store: InMemoryTaskStore + queue: InMemoryTaskMessageQueue + task_result_handler: TaskResultHandler + # Events to signal when tasks complete (for testing without sleeps) + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_elicitation_during_task_with_response_routing() -> None: + """ + Test the complete elicitation flow with response routing. + + This is an end-to-end test that verifies: + - TaskSession.elicit() enqueues the request + - TaskResultHandler delivers it via tasks/result + - Client responds + - Response is routed back to the waiting resolver + - Handler continues and completes + """ + server: Server[AppContext, Any] = Server("test-elicitation") # type: ignore[assignment] + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + task_result_handler = TaskResultHandler(store, queue) + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="interactive_tool", + description="A tool that asks for user confirmation", + inputSchema={ + "type": "object", + "properties": {"data": {"type": "string"}}, + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "interactive_tool" and ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_interactive_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Requesting confirmation...", notify=True) + + # Create TaskSession for task-aware elicitation + task_session = TaskSession( + session=ctx.session, + task_id=task.taskId, + store=app.store, + queue=app.queue, + ) + + # This enqueues the elicitation request + # It will block until response is routed back + elicit_result = await task_session.elicit( + message=f"Confirm processing of: {arguments.get('data', '')}", + requestedSchema={ + "type": "object", + "properties": { + "confirmed": {"type": "boolean"}, + }, + "required": ["confirmed"], + }, + ) + + # Process based on user response + if elicit_result.action == "accept" and elicit_result.content: + confirmed = elicit_result.content.get("confirmed", False) + if confirmed: + result_text = f"Confirmed and processed: {arguments.get('data', '')}" + else: + result_text = "User declined - not processed" + else: + result_text = "Elicitation cancelled or declined" + + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text=result_text)]), + notify=True, # Must notify so TaskResultHandler.handle() wakes up + ) + done_event.set() + + app.task_group.start_soon(do_interactive_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Non-task result")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + # Use the TaskResultHandler to handle the dequeue-send-wait pattern + return await app.task_result_handler.handle( + request, + server.request_context.session, + server.request_context.request_id, + ) + + # Set up bidirectional streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Track elicitation requests received by client + elicitation_received: list[ElicitRequest] = [] + + async def elicitation_callback( + context: Any, + params: Any, + ) -> ElicitResult: + """Client-side elicitation callback that responds to elicitations.""" + elicitation_received.append(ElicitRequest(params=params)) + return ElicitResult( + action="accept", + content={"confirmed": True}, + ) + + async def run_server(app_context: AppContext, server_session: ServerSession): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext( + task_group=tg, + store=store, + queue=queue, + task_result_handler=task_result_handler, + ) + + # Create server session and wire up task result handler + server_session = ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + # Wire up the task result handler for response routing + server_session.add_response_router(task_result_handler) + + async with server_session: + tg.start_soon(run_server, app_context, server_session) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + elicitation_callback=elicitation_callback, + ) as client_session: + await client_session.initialize() + + # === Step 1: Send task-augmented tool call === + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="interactive_tool", + arguments={"data": "important data"}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + assert isinstance(create_result, CreateTaskResult) + task_id = create_result.task.taskId + + # === Step 2: Poll until input_required or completed === + max_polls = 100 + task_status: GetTaskResult | None = None + for _ in range(max_polls): + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + if task_status.status in ("input_required", "completed", "failed"): + break + await anyio.sleep(0) # Yield to allow server to process + + # Task should be in input_required state (waiting for elicitation response) + assert task_status is not None, "Polling loop did not execute" + assert task_status.status == "input_required", f"Expected input_required, got {task_status.status}" + + # === Step 3: Call tasks/result which will deliver elicitation === + # This should: + # 1. Dequeue the elicitation request + # 2. Send it to us (handled by elicitation_callback above) + # 3. Wait for our response + # 4. Continue until task completes + # 5. Return final result + final_result = await client_session.send_request( + ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))), + CallToolResult, + ) + + # === Verify results === + # We should have received and responded to an elicitation + assert len(elicitation_received) == 1 + assert "Confirm processing of: important data" in elicitation_received[0].params.message + + # Final result should reflect our confirmation + assert len(final_result.content) == 1 + content = final_result.content[0] + assert isinstance(content, TextContent) + assert "Confirmed and processed: important data" in content.text + + # Task should be completed + final_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + assert final_status.status == "completed" + + tg.cancel_scope.cancel() + + store.cleanup() + queue.cleanup() diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py new file mode 100644 index 0000000000..8871031b4a --- /dev/null +++ b/tests/experimental/tasks/server/test_integration.py @@ -0,0 +1,375 @@ +"""End-to-end integration tests for tasks functionality. + +These tests demonstrate the full task lifecycle: +1. Client sends task-augmented request (tools/call with task metadata) +2. Server creates task and returns CreateTaskResult immediately +3. Background work executes (using task_execution context manager) +4. Client polls with tasks/get +5. Client retrieves result with tasks/result +""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + TASK_REQUIRED, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + ClientResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + + +@dataclass +class AppContext: + """Application context passed via lifespan_context.""" + + task_group: TaskGroup + store: InMemoryTaskStore + # Events to signal when tasks complete (for testing without sleeps) + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_task_lifecycle_with_task_execution() -> None: + """ + Test the complete task lifecycle using the task_execution pattern. + + This demonstrates the recommended way to implement task-augmented tools: + 1. Create task in store + 2. Spawn work using task_execution() context manager + 3. Return CreateTaskResult immediately + 4. Work executes in background, auto-fails on exception + """ + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="process_data", + description="Process data asynchronously", + inputSchema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "process_data" and ctx.experimental.is_task: + # 1. Create task in store + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + # 2. Create event to signal completion (for testing) + done_event = Event() + app.task_done_events[task.taskId] = done_event + + # 3. Define work function using task_execution for safety + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Processing input...", notify=False) + # Simulate work + input_value = arguments.get("input", "") + result_text = f"Processed: {input_value.upper()}" + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text=result_text)]), + notify=False, + ) + # Signal completion + done_event.set() + + # 4. Spawn work in task group (from lifespan_context) + app.task_group.start_soon(do_work) + + # 5. Return CreateTaskResult immediately + return CreateTaskResult(task=task) + + # Non-task execution path + return [TextContent(type="text", text="Sync result")] + + # Register task query handlers (delegate to store) + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + assert isinstance(result, CallToolResult) + # Return as GetTaskPayloadResult (which accepts extra fields) + return GetTaskPayloadResult(**result.model_dump()) + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + app = server.request_context.lifespan_context + tasks, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + return ListTasksResult(tasks=tasks, nextCursor=next_cursor) + + # Set up client-server communication + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + # Create app context with task group and store + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # === Step 1: Send task-augmented tool call === + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + assert isinstance(create_result, CreateTaskResult) + assert create_result.task.status == "working" + task_id = create_result.task.taskId + + # === Step 2: Wait for task to complete === + await app_context.task_done_events[task_id].wait() + + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + assert task_status.taskId == task_id + assert task_status.status == "completed" + + # === Step 3: Retrieve the actual result === + task_result = await client_session.send_request( + ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))), + CallToolResult, + ) + + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed: HELLO WORLD" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_auto_fails_on_exception() -> None: + """Test that task_execution automatically fails the task on unhandled exception.""" + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="failing_task", + description="A task that fails", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "failing_task" and ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + # Create event to signal completion (for testing) + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_failing_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("About to fail...", notify=False) + raise RuntimeError("Something went wrong!") + # Note: complete() is never called, but task_execution + # will automatically call fail() due to the exception + # This line is reached because task_execution suppresses the exception + done_event.set() + + app.task_group.start_soon(do_failing_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Send task request + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + task_id = create_result.task.taskId + + # Wait for task to complete (even though it fails) + await app_context.task_done_events[task_id].wait() + + # Check that task was auto-failed + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + assert task_status.status == "failed" + assert task_status.statusMessage == "Something went wrong!" + + tg.cancel_scope.cancel() + + store.cleanup() diff --git a/tests/experimental/tasks/server/test_sampling_flow.py b/tests/experimental/tasks/server/test_sampling_flow.py new file mode 100644 index 0000000000..c3f4894592 --- /dev/null +++ b/tests/experimental/tasks/server/test_sampling_flow.py @@ -0,0 +1,317 @@ +""" +Integration test for task sampling flow. + +This tests the complete sampling flow: +1. Client sends task-augmented tool call +2. Server creates task, returns CreateTaskResult immediately +3. Server handler uses TaskSession.create_message() to request LLM completion +4. Client polls, sees input_required status +5. Client calls tasks/result which delivers the sampling request +6. Client responds with CreateMessageResult +7. Response is routed back to server handler +8. Handler completes task +9. Client receives final result +""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks import ( + InMemoryTaskMessageQueue, + InMemoryTaskStore, + TaskResultHandler, + TaskSession, + task_execution, +) +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + CreateMessageRequest, + CreateMessageResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + SamplingMessage, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + + +@dataclass +class AppContext: + """Application context with task infrastructure.""" + + task_group: TaskGroup + store: InMemoryTaskStore + queue: InMemoryTaskMessageQueue + task_result_handler: TaskResultHandler + # Events to signal when tasks complete (for testing without sleeps) + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_sampling_during_task_with_response_routing() -> None: + """ + Test the complete sampling flow with response routing. + + This is an end-to-end test that verifies: + - TaskSession.create_message() enqueues the request + - TaskResultHandler delivers it via tasks/result + - Client responds with CreateMessageResult + - Response is routed back to the waiting resolver + - Handler continues and completes + """ + server: Server[AppContext, Any] = Server("test-sampling") # type: ignore[assignment] + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + task_result_handler = TaskResultHandler(store, queue) + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="ai_assistant_tool", + description="A tool that uses AI for processing", + inputSchema={ + "type": "object", + "properties": {"question": {"type": "string"}}, + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "ai_assistant_tool" and ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_ai_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Requesting AI assistance...", notify=True) + + # Create TaskSession for task-aware sampling + task_session = TaskSession( + session=ctx.session, + task_id=task.taskId, + store=app.store, + queue=app.queue, + ) + + question = arguments.get("question", "What is 2+2?") + + # This enqueues the sampling request + # It will block until response is routed back + sampling_result = await task_session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent(type="text", text=question), + ) + ], + max_tokens=100, + system_prompt="You are a helpful assistant. Answer concisely.", + ) + + # Process the AI response + ai_response = "Unknown" + if isinstance(sampling_result.content, TextContent): + ai_response = sampling_result.content.text + + result_text = f"AI answered: {ai_response}" + + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text=result_text)]), + notify=True, # Must notify so TaskResultHandler.handle() wakes up + ) + done_event.set() + + app.task_group.start_soon(do_ai_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Non-task result")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + # Use the TaskResultHandler to handle the dequeue-send-wait pattern + return await app.task_result_handler.handle( + request, + server.request_context.session, + server.request_context.request_id, + ) + + # Set up bidirectional streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Track sampling requests received by client + sampling_requests_received: list[CreateMessageRequest] = [] + + async def sampling_callback( + context: Any, + params: Any, + ) -> CreateMessageResult: + """Client-side sampling callback that responds to sampling requests.""" + sampling_requests_received.append(CreateMessageRequest(params=params)) + # Return a mock AI response + return CreateMessageResult( + model="test-model", + role="assistant", + content=TextContent(type="text", text="The answer is 4"), + ) + + async def run_server(app_context: AppContext, server_session: ServerSession): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext( + task_group=tg, + store=store, + queue=queue, + task_result_handler=task_result_handler, + ) + + # Create server session and wire up task result handler + server_session = ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + # Wire up the task result handler for response routing + server_session.add_response_router(task_result_handler) + + async with server_session: + tg.start_soon(run_server, app_context, server_session) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + sampling_callback=sampling_callback, + ) as client_session: + await client_session.initialize() + + # === Step 1: Send task-augmented tool call === + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="ai_assistant_tool", + arguments={"question": "What is 2+2?"}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + assert isinstance(create_result, CreateTaskResult) + task_id = create_result.task.taskId + + # === Step 2: Poll until input_required or completed === + max_polls = 100 + task_status: GetTaskResult | None = None + for _ in range(max_polls): + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + if task_status.status in ("input_required", "completed", "failed"): + break + await anyio.sleep(0) # Yield to allow server to process + + # Task should be in input_required state (waiting for sampling response) + assert task_status is not None, "Polling loop did not execute" + assert task_status.status == "input_required", f"Expected input_required, got {task_status.status}" + + # === Step 3: Call tasks/result which will deliver sampling request === + # This should: + # 1. Dequeue the sampling request + # 2. Send it to us (handled by sampling_callback above) + # 3. Wait for our response + # 4. Continue until task completes + # 5. Return final result + final_result = await client_session.send_request( + ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))), + CallToolResult, + ) + + # === Verify results === + # We should have received and responded to a sampling request + assert len(sampling_requests_received) == 1 + first_message_content = sampling_requests_received[0].params.messages[0].content + assert isinstance(first_message_content, TextContent) + assert first_message_content.text == "What is 2+2?" + + # Final result should reflect the AI response + assert len(final_result.content) == 1 + content = final_result.content[0] + assert isinstance(content, TextContent) + assert "AI answered: The answer is 4" in content.text + + # Task should be completed + final_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + assert final_status.status == "completed" + + tg.cancel_scope.cancel() + + store.cleanup() + queue.cleanup() diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py new file mode 100644 index 0000000000..8c442ef9f4 --- /dev/null +++ b/tests/experimental/tasks/server/test_server.py @@ -0,0 +1,452 @@ +"""Tests for server-side task support (handlers, capabilities, integration).""" + +from datetime import datetime, timezone +from typing import Any + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + TASK_FORBIDDEN, + TASK_OPTIONAL, + TASK_REQUIRED, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + CancelTaskRequest, + CancelTaskRequestParams, + CancelTaskResult, + ClientRequest, + ClientResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ListToolsRequest, + ListToolsResult, + ServerNotification, + ServerRequest, + ServerResult, + Task, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + +# --- Experimental handler tests --- + + +@pytest.mark.anyio +async def test_list_tasks_handler() -> None: + """Test that experimental list_tasks handler works.""" + server = Server("test") + + now = datetime.now(timezone.utc) + test_tasks = [ + Task( + taskId="task-1", + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=1000, + ), + Task( + taskId="task-2", + status="completed", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=1000, + ), + ] + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=test_tasks) + + handler = server.request_handlers[ListTasksRequest] + request = ListTasksRequest(method="tasks/list") + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListTasksResult) + assert len(result.root.tasks) == 2 + assert result.root.tasks[0].taskId == "task-1" + assert result.root.tasks[1].taskId == "task-2" + + +@pytest.mark.anyio +async def test_get_task_handler() -> None: + """Test that experimental get_task handler works.""" + server = Server("test") + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + now = datetime.now(timezone.utc) + return GetTaskResult( + taskId=request.params.taskId, + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=1000, + ) + + handler = server.request_handlers[GetTaskRequest] + request = GetTaskRequest( + method="tasks/get", + params=GetTaskRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, GetTaskResult) + assert result.root.taskId == "test-task-123" + assert result.root.status == "working" + + +@pytest.mark.anyio +async def test_get_task_result_handler() -> None: + """Test that experimental get_task_result handler works.""" + server = Server("test") + + @server.experimental.get_task_result() + async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + return GetTaskPayloadResult() + + handler = server.request_handlers[GetTaskPayloadRequest] + request = GetTaskPayloadRequest( + method="tasks/result", + params=GetTaskPayloadRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, GetTaskPayloadResult) + + +@pytest.mark.anyio +async def test_cancel_task_handler() -> None: + """Test that experimental cancel_task handler works.""" + server = Server("test") + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + now = datetime.now(timezone.utc) + return CancelTaskResult( + taskId=request.params.taskId, + status="cancelled", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + ) + + handler = server.request_handlers[CancelTaskRequest] + request = CancelTaskRequest( + method="tasks/cancel", + params=CancelTaskRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, CancelTaskResult) + assert result.root.taskId == "test-task-123" + assert result.root.status == "cancelled" + + +# --- Server capabilities tests --- + + +@pytest.mark.anyio +async def test_server_capabilities_include_tasks() -> None: + """Test that server capabilities include tasks when handlers are registered.""" + server = Server("test") + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=[]) + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + now = datetime.now(timezone.utc) + return CancelTaskResult( + taskId=request.params.taskId, + status="cancelled", + createdAt=now, + lastUpdatedAt=now, + ttl=None, + ) + + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + assert capabilities.tasks is not None + assert capabilities.tasks.list is not None + assert capabilities.tasks.cancel is not None + assert capabilities.tasks.requests is not None + assert capabilities.tasks.requests.tools is not None + + +@pytest.mark.anyio +async def test_server_capabilities_partial_tasks() -> None: + """Test capabilities with only some task handlers registered.""" + server = Server("test") + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=[]) + + # Only list_tasks registered, not cancel_task + + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + assert capabilities.tasks is not None + assert capabilities.tasks.list is not None + assert capabilities.tasks.cancel is None # Not registered + + +# --- Tool annotation tests --- + + +@pytest.mark.anyio +async def test_tool_with_task_execution_metadata() -> None: + """Test that tools can declare task execution mode.""" + server = Server("test") + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="quick_tool", + description="Fast tool", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport=TASK_FORBIDDEN), + ), + Tool( + name="long_tool", + description="Long running tool", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + Tool( + name="flexible_tool", + description="Can be either", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport=TASK_OPTIONAL), + ), + ] + + tools_handler = server.request_handlers[ListToolsRequest] + request = ListToolsRequest(method="tools/list") + result = await tools_handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListToolsResult) + tools = result.root.tools + + assert tools[0].execution is not None + assert tools[0].execution.taskSupport == TASK_FORBIDDEN + assert tools[1].execution is not None + assert tools[1].execution.taskSupport == TASK_REQUIRED + assert tools[2].execution is not None + assert tools[2].execution.taskSupport == TASK_OPTIONAL + + +# --- Integration tests --- + + +@pytest.mark.anyio +async def test_task_metadata_in_call_tool_request() -> None: + """Test that task metadata is accessible via RequestContext when calling a tool.""" + server = Server("test") + captured_task_metadata: TaskMetadata | None = None + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="long_task", + description="A long running task", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport="optional"), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + nonlocal captured_task_metadata + ctx = server.request_context + captured_task_metadata = ctx.experimental.task_metadata + return [TextContent(type="text", text="done")] + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async with anyio.create_task_group() as tg: + + async def handle_messages(): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + tg.start_soon(handle_messages) + await anyio.sleep_forever() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Call tool with task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CallToolResult, + ) + + tg.cancel_scope.cancel() + + assert captured_task_metadata is not None + assert captured_task_metadata.ttl == 60000 + + +@pytest.mark.anyio +async def test_task_metadata_is_task_property() -> None: + """Test that RequestContext.experimental.is_task works correctly.""" + server = Server("test") + is_task_values: list[bool] = [] + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + is_task_values.append(ctx.experimental.is_task) + return [TextContent(type="text", text="done")] + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async with anyio.create_task_group() as tg: + + async def handle_messages(): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + tg.start_soon(handle_messages) + await anyio.sleep_forever() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Call without task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}), + ) + ), + CallToolResult, + ) + + # Call with task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CallToolResult, + ) + + tg.cancel_scope.cancel() + + assert len(is_task_values) == 2 + assert is_task_values[0] is False # First call without task + assert is_task_values[1] is True # Second call with task diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py new file mode 100644 index 0000000000..b880253d1a --- /dev/null +++ b/tests/experimental/tasks/server/test_store.py @@ -0,0 +1,499 @@ +"""Tests for InMemoryTaskStore.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks import InMemoryTaskStore, cancel_task +from mcp.types import INVALID_PARAMS, CallToolResult, TaskMetadata, TextContent + + +@pytest.mark.anyio +async def test_create_and_get() -> None: + """Test InMemoryTaskStore create and get operations.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + assert task.taskId is not None + assert task.status == "working" + assert task.ttl == 60000 + + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.taskId == task.taskId + assert retrieved.status == "working" + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_with_custom_id() -> None: + """Test InMemoryTaskStore create with custom task ID.""" + store = InMemoryTaskStore() + + task = await store.create_task( + metadata=TaskMetadata(ttl=60000), + task_id="my-custom-id", + ) + + assert task.taskId == "my-custom-id" + assert task.status == "working" + + retrieved = await store.get_task("my-custom-id") + assert retrieved is not None + assert retrieved.taskId == "my-custom-id" + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_duplicate_id_raises() -> None: + """Test that creating a task with duplicate ID raises.""" + store = InMemoryTaskStore() + + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + + with pytest.raises(ValueError, match="already exists"): + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_nonexistent_returns_none() -> None: + """Test that getting a nonexistent task returns None.""" + store = InMemoryTaskStore() + + retrieved = await store.get_task("nonexistent") + assert retrieved is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_update_status() -> None: + """Test InMemoryTaskStore status updates.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + updated = await store.update_task(task.taskId, status="completed", status_message="All done!") + + assert updated.status == "completed" + assert updated.statusMessage == "All done!" + + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.status == "completed" + assert retrieved.statusMessage == "All done!" + + store.cleanup() + + +@pytest.mark.anyio +async def test_update_nonexistent_raises() -> None: + """Test that updating a nonexistent task raises.""" + store = InMemoryTaskStore() + + with pytest.raises(ValueError, match="not found"): + await store.update_task("nonexistent", status="completed") + + store.cleanup() + + +@pytest.mark.anyio +async def test_store_and_get_result() -> None: + """Test InMemoryTaskStore result storage and retrieval.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Store result + result = CallToolResult(content=[TextContent(type="text", text="Result data")]) + await store.store_result(task.taskId, result) + + # Retrieve result + retrieved_result = await store.get_result(task.taskId) + assert retrieved_result == result + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_result_nonexistent_returns_none() -> None: + """Test that getting result for nonexistent task returns None.""" + store = InMemoryTaskStore() + + result = await store.get_result("nonexistent") + assert result is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_result_no_result_returns_none() -> None: + """Test that getting result when none stored returns None.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + result = await store.get_result(task.taskId) + assert result is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks() -> None: + """Test InMemoryTaskStore list operation.""" + store = InMemoryTaskStore() + + # Create multiple tasks + for _ in range(3): + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + tasks, next_cursor = await store.list_tasks() + assert len(tasks) == 3 + assert next_cursor is None # Less than page size + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks_pagination() -> None: + """Test InMemoryTaskStore pagination.""" + store = InMemoryTaskStore(page_size=2) + + # Create 5 tasks + for _ in range(5): + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # First page + tasks, next_cursor = await store.list_tasks() + assert len(tasks) == 2 + assert next_cursor is not None + + # Second page + tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + assert len(tasks) == 2 + assert next_cursor is not None + + # Third page (last) + tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + assert len(tasks) == 1 + assert next_cursor is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks_invalid_cursor() -> None: + """Test that invalid cursor raises.""" + store = InMemoryTaskStore() + + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + with pytest.raises(ValueError, match="Invalid cursor"): + await store.list_tasks(cursor="invalid-cursor") + + store.cleanup() + + +@pytest.mark.anyio +async def test_delete_task() -> None: + """Test InMemoryTaskStore delete operation.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + deleted = await store.delete_task(task.taskId) + assert deleted is True + + retrieved = await store.get_task(task.taskId) + assert retrieved is None + + # Delete non-existent + deleted = await store.delete_task(task.taskId) + assert deleted is False + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_all_tasks_helper() -> None: + """Test the get_all_tasks debugging helper.""" + store = InMemoryTaskStore() + + await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + all_tasks = store.get_all_tasks() + assert len(all_tasks) == 2 + + store.cleanup() + + +@pytest.mark.anyio +async def test_store_result_nonexistent_raises() -> None: + """Test that storing result for nonexistent task raises ValueError.""" + store = InMemoryTaskStore() + + result = CallToolResult(content=[TextContent(type="text", text="Result")]) + + with pytest.raises(ValueError, match="not found"): + await store.store_result("nonexistent-id", result) + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_task_with_null_ttl() -> None: + """Test creating task with null TTL (never expires).""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=None)) + + assert task.ttl is None + + # Task should persist (not expire) + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_expiration_cleanup() -> None: + """Test that expired tasks are cleaned up lazily.""" + store = InMemoryTaskStore() + + # Create a task with very short TTL + task = await store.create_task(metadata=TaskMetadata(ttl=1)) # 1ms TTL + + # Manually force the expiry to be in the past + stored = store._tasks.get(task.taskId) + assert stored is not None + stored.expires_at = datetime.now(timezone.utc) - timedelta(seconds=10) + + # Task should still exist in internal dict but be expired + assert task.taskId in store._tasks + + # Any access operation should clean up expired tasks + # list_tasks triggers cleanup + tasks, _ = await store.list_tasks() + + # Expired task should be cleaned up + assert task.taskId not in store._tasks + assert len(tasks) == 0 + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_with_null_ttl_never_expires() -> None: + """Test that tasks with null TTL never expire during cleanup.""" + + store = InMemoryTaskStore() + + # Create task with null TTL + task = await store.create_task(metadata=TaskMetadata(ttl=None)) + + # Verify internal storage has no expiry + stored = store._tasks.get(task.taskId) + assert stored is not None + assert stored.expires_at is None + + # Access operations should NOT remove this task + await store.list_tasks() + await store.get_task(task.taskId) + + # Task should still exist + assert task.taskId in store._tasks + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + + store.cleanup() + + +@pytest.mark.anyio +async def test_terminal_task_ttl_reset() -> None: + """Test that TTL is reset when task enters terminal state.""" + + store = InMemoryTaskStore() + + # Create task with short TTL + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) # 60s + + # Get the initial expiry + stored = store._tasks.get(task.taskId) + assert stored is not None + initial_expiry = stored.expires_at + assert initial_expiry is not None + + # Update to terminal state (completed) + await store.update_task(task.taskId, status="completed") + + # Expiry should be reset to a new time (from now + TTL) + new_expiry = stored.expires_at + assert new_expiry is not None + assert new_expiry >= initial_expiry + + store.cleanup() + + +@pytest.mark.anyio +async def test_terminal_status_transition_rejected() -> None: + """Test that transitions from terminal states are rejected. + + Per spec: Terminal states (completed, failed, cancelled) MUST NOT + transition to any other status. + """ + store = InMemoryTaskStore() + + # Test each terminal status + for terminal_status in ("completed", "failed", "cancelled"): + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Move to terminal state + await store.update_task(task.taskId, status=terminal_status) + + # Attempting to transition to any other status should raise + with pytest.raises(ValueError, match="Cannot transition from terminal status"): + await store.update_task(task.taskId, status="working") + + # Also test transitioning to another terminal state + other_terminal = "failed" if terminal_status != "failed" else "completed" + with pytest.raises(ValueError, match="Cannot transition from terminal status"): + await store.update_task(task.taskId, status=other_terminal) + + store.cleanup() + + +@pytest.mark.anyio +async def test_terminal_status_allows_same_status() -> None: + """Test that setting the same terminal status doesn't raise. + + This is not a transition, so it should be allowed (no-op). + """ + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="completed") + + # Setting the same status should not raise + updated = await store.update_task(task.taskId, status="completed") + assert updated.status == "completed" + + # Updating just the message should also work + updated = await store.update_task(task.taskId, status_message="Updated message") + assert updated.statusMessage == "Updated message" + + store.cleanup() + + +# ============================================================================= +# cancel_task helper function tests +# ============================================================================= + + +@pytest.mark.anyio +async def test_cancel_task_succeeds_for_working_task() -> None: + """Test cancel_task helper succeeds for a working task.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + assert task.status == "working" + + result = await cancel_task(store, task.taskId) + + assert result.taskId == task.taskId + assert result.status == "cancelled" + + # Verify store is updated + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.status == "cancelled" + + store.cleanup() + + +@pytest.mark.anyio +async def test_cancel_task_rejects_nonexistent_task() -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for nonexistent task.""" + store = InMemoryTaskStore() + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, "nonexistent-task-id") + + assert exc_info.value.error.code == INVALID_PARAMS + assert "not found" in exc_info.value.error.message + + store.cleanup() + + +@pytest.mark.anyio +async def test_cancel_task_rejects_completed_task() -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for completed task.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="completed") + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, task.taskId) + + assert exc_info.value.error.code == INVALID_PARAMS + assert "terminal state 'completed'" in exc_info.value.error.message + + store.cleanup() + + +@pytest.mark.anyio +async def test_cancel_task_rejects_failed_task() -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for failed task.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="failed") + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, task.taskId) + + assert exc_info.value.error.code == INVALID_PARAMS + assert "terminal state 'failed'" in exc_info.value.error.message + + store.cleanup() + + +@pytest.mark.anyio +async def test_cancel_task_rejects_already_cancelled_task() -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for already cancelled task.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="cancelled") + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, task.taskId) + + assert exc_info.value.error.code == INVALID_PARAMS + assert "terminal state 'cancelled'" in exc_info.value.error.message + + store.cleanup() + + +@pytest.mark.anyio +async def test_cancel_task_succeeds_for_input_required_task() -> None: + """Test cancel_task helper succeeds for a task in input_required status.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="input_required") + + result = await cancel_task(store, task.taskId) + + assert result.taskId == task.taskId + assert result.status == "cancelled" + + store.cleanup() diff --git a/tests/experimental/tasks/test_interactive_example.py b/tests/experimental/tasks/test_interactive_example.py new file mode 100644 index 0000000000..bfa8df53e7 --- /dev/null +++ b/tests/experimental/tasks/test_interactive_example.py @@ -0,0 +1,610 @@ +""" +Unit test that demonstrates the correct interactive task pattern. + +This test serves as the reference implementation for the simple-task-interactive +examples. It demonstrates: + +1. A server with two tools: + - confirm_delete: Uses elicitation to ask for user confirmation + - write_haiku: Uses sampling to request LLM completion + +2. A client that: + - Calls tools as tasks using session.experimental.call_tool_as_task() + - Handles elicitation via callback + - Handles sampling via callback + - Retrieves results via session.experimental.get_task_result() + +Key insight: The client must call get_task_result() to receive elicitation/sampling +requests. The server delivers these requests via the tasks/result response stream. +Simply polling get_task() will not trigger the callbacks. +""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.experimental.tasks import ( + InMemoryTaskMessageQueue, + InMemoryTaskStore, + TaskResultHandler, + TaskSession, + task_execution, +) +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestParams, + ElicitResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + SamplingMessage, + TextContent, + Tool, + ToolExecution, +) + + +@dataclass +class AppContext: + """Application context with task infrastructure.""" + + task_group: TaskGroup + store: InMemoryTaskStore + queue: InMemoryTaskMessageQueue + handler: TaskResultHandler + configured_sessions: dict[int, bool] = field(default_factory=lambda: {}) + + +def create_server() -> Server[AppContext, Any]: + """Create the server with confirm_delete and write_haiku tools.""" + server: Server[AppContext, Any] = Server("simple-task-interactive") # type: ignore[assignment] + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="confirm_delete", + description="Asks for confirmation before deleting (demonstrates elicitation)", + inputSchema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + inputSchema={ + "type": "object", + "properties": {"topic": {"type": "string"}}, + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | Any: + ctx = server.request_context + app = ctx.lifespan_context + + # Validate task mode + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + # Ensure handler is configured for response routing + session_id = id(ctx.session) + if session_id not in app.configured_sessions: + ctx.session.add_response_router(app.handler) + app.configured_sessions[session_id] = True + + # Create task + metadata = ctx.experimental.task_metadata + assert metadata is not None + task = await app.store.create_task(metadata) + + if name == "confirm_delete": + filename = arguments.get("filename", "unknown.txt") + + async def do_confirm() -> None: + async with task_execution(task.taskId, app.store) as task_ctx: + task_session = TaskSession( + session=ctx.session, + task_id=task.taskId, + store=app.store, + queue=app.queue, + ) + + result = await task_session.elicit( + message=f"Are you sure you want to delete '{filename}'?", + requestedSchema={ + "type": "object", + "properties": {"confirm": {"type": "boolean"}}, + "required": ["confirm"], + }, + ) + + if result.action == "accept" and result.content: + confirmed = result.content.get("confirm", False) + text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled" + else: + text = "Deletion cancelled" + + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text=text)]), + notify=True, + ) + + app.task_group.start_soon(do_confirm) + + elif name == "write_haiku": + topic = arguments.get("topic", "nature") + + async def do_haiku() -> None: + async with task_execution(task.taskId, app.store) as task_ctx: + task_session = TaskSession( + session=ctx.session, + task_id=task.taskId, + store=app.store, + queue=app.queue, + ) + + result = await task_session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent(type="text", text=f"Write a haiku about {topic}"), + ) + ], + max_tokens=50, + ) + + haiku = "No response" + if isinstance(result.content, TextContent): + haiku = result.content.text + + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text=f"Haiku:\n{haiku}")]), + notify=True, + ) + + app.task_group.start_soon(do_haiku) + + # Import here to avoid circular imports at module level + from mcp.types import CreateTaskResult + + return CreateTaskResult(task=task) + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + ctx = server.request_context + app = ctx.lifespan_context + + # Ensure handler is configured for this session + session_id = id(ctx.session) + if session_id not in app.configured_sessions: + ctx.session.add_response_router(app.handler) + app.configured_sessions[session_id] = True + + return await app.handler.handle(request, ctx.session, ctx.request_id) + + return server + + +@pytest.mark.anyio +async def test_confirm_delete_with_elicitation() -> None: + """ + Test the confirm_delete tool which uses elicitation. + + This demonstrates: + 1. Client calls tool as task + 2. Server asks for confirmation via elicitation + 3. Client receives elicitation via get_task_result() and responds + 4. Server completes task based on response + """ + server = create_server() + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + # Track elicitation requests + elicitation_messages: list[str] = [] + + async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, + ) -> ElicitResult: + """Handle elicitation - simulates user confirming deletion.""" + elicitation_messages.append(params.message) + # User confirms + return ElicitResult(action="accept", content={"confirm": True}) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(app_context: AppContext, server_session: ServerSession) -> None: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext( + task_group=tg, + store=store, + queue=queue, + handler=handler, + ) + + server_session = ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + server_session.add_response_router(handler) + + async with server_session: + tg.start_soon(run_server, app_context, server_session) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + elicitation_callback=elicitation_callback, + ) as client: + await client.initialize() + + # List tools + tools = await client.list_tools() + tool_names = [t.name for t in tools.tools] + assert "confirm_delete" in tool_names + assert "write_haiku" in tool_names + + # Call tool as task + result = await client.experimental.call_tool_as_task( + "confirm_delete", + {"filename": "important.txt"}, + ) + task_id = result.task.taskId + + # KEY PATTERN: Call get_task_result() to receive elicitation and get final result + # This is the critical difference from the broken example which only polled get_task() + final = await client.experimental.get_task_result(task_id, CallToolResult) + + # Verify elicitation was received + assert len(elicitation_messages) == 1 + assert "important.txt" in elicitation_messages[0] + + # Verify result + assert len(final.content) == 1 + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "Deleted 'important.txt'" + + # Verify task is completed + status = await client.experimental.get_task(task_id) + assert status.status == "completed" + + tg.cancel_scope.cancel() + + store.cleanup() + queue.cleanup() + + +@pytest.mark.anyio +async def test_confirm_delete_user_declines() -> None: + """Test confirm_delete when user declines.""" + server = create_server() + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, + ) -> ElicitResult: + # User declines + return ElicitResult(action="accept", content={"confirm": False}) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(app_context: AppContext, server_session: ServerSession) -> None: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext( + task_group=tg, + store=store, + queue=queue, + handler=handler, + ) + + server_session = ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + server_session.add_response_router(handler) + + async with server_session: + tg.start_soon(run_server, app_context, server_session) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + elicitation_callback=elicitation_callback, + ) as client: + await client.initialize() + + result = await client.experimental.call_tool_as_task( + "confirm_delete", + {"filename": "important.txt"}, + ) + task_id = result.task.taskId + + final = await client.experimental.get_task_result(task_id, CallToolResult) + + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "Deletion cancelled" + + tg.cancel_scope.cancel() + + store.cleanup() + queue.cleanup() + + +@pytest.mark.anyio +async def test_write_haiku_with_sampling() -> None: + """ + Test the write_haiku tool which uses sampling. + + This demonstrates: + 1. Client calls tool as task + 2. Server requests LLM completion via sampling + 3. Client receives sampling request via get_task_result() and responds + 4. Server completes task with the haiku + """ + server = create_server() + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + # Track sampling requests + sampling_prompts: list[str] = [] + test_haiku = """Autumn leaves falling +Softly on the quiet stream +Nature whispers peace""" + + async def sampling_callback( + context: RequestContext[ClientSession, Any], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + """Handle sampling - returns a test haiku.""" + if params.messages: + content = params.messages[0].content + if isinstance(content, TextContent): + sampling_prompts.append(content.text) + + return CreateMessageResult( + model="test-model", + role="assistant", + content=TextContent(type="text", text=test_haiku), + ) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(app_context: AppContext, server_session: ServerSession) -> None: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext( + task_group=tg, + store=store, + queue=queue, + handler=handler, + ) + + server_session = ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + server_session.add_response_router(handler) + + async with server_session: + tg.start_soon(run_server, app_context, server_session) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + sampling_callback=sampling_callback, + ) as client: + await client.initialize() + + # Call tool as task + result = await client.experimental.call_tool_as_task( + "write_haiku", + {"topic": "autumn leaves"}, + ) + task_id = result.task.taskId + + # Get result (this delivers the sampling request) + final = await client.experimental.get_task_result(task_id, CallToolResult) + + # Verify sampling was requested + assert len(sampling_prompts) == 1 + assert "autumn leaves" in sampling_prompts[0] + + # Verify result contains the haiku + assert len(final.content) == 1 + assert isinstance(final.content[0], TextContent) + assert "Haiku:" in final.content[0].text + assert "Autumn leaves falling" in final.content[0].text + + # Verify task is completed + status = await client.experimental.get_task(task_id) + assert status.status == "completed" + + tg.cancel_scope.cancel() + + store.cleanup() + queue.cleanup() + + +@pytest.mark.anyio +async def test_both_tools_sequentially() -> None: + """ + Test calling both tools sequentially, similar to how the example works. + + This is the closest match to what the example client does. + """ + server = create_server() + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + elicitation_count = 0 + sampling_count = 0 + + async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, + ) -> ElicitResult: + nonlocal elicitation_count + elicitation_count += 1 + return ElicitResult(action="accept", content={"confirm": True}) + + async def sampling_callback( + context: RequestContext[ClientSession, Any], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + nonlocal sampling_count + sampling_count += 1 + return CreateMessageResult( + model="test-model", + role="assistant", + content=TextContent(type="text", text="Cherry blossoms fall"), + ) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(app_context: AppContext, server_session: ServerSession) -> None: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext( + task_group=tg, + store=store, + queue=queue, + handler=handler, + ) + + server_session = ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + server_session.add_response_router(handler) + + async with server_session: + tg.start_soon(run_server, app_context, server_session) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + elicitation_callback=elicitation_callback, + sampling_callback=sampling_callback, + ) as client: + await client.initialize() + + # === Demo 1: Elicitation (confirm_delete) === + result1 = await client.experimental.call_tool_as_task( + "confirm_delete", + {"filename": "important.txt"}, + ) + task_id1 = result1.task.taskId + + final1 = await client.experimental.get_task_result(task_id1, CallToolResult) + assert isinstance(final1.content[0], TextContent) + assert "Deleted" in final1.content[0].text + + # === Demo 2: Sampling (write_haiku) === + result2 = await client.experimental.call_tool_as_task( + "write_haiku", + {"topic": "autumn leaves"}, + ) + task_id2 = result2.task.taskId + + final2 = await client.experimental.get_task_result(task_id2, CallToolResult) + assert isinstance(final2.content[0], TextContent) + assert "Haiku:" in final2.content[0].text + + # Verify both callbacks were triggered + assert elicitation_count == 1 + assert sampling_count == 1 + + tg.cancel_scope.cancel() + + store.cleanup() + queue.cleanup() diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py new file mode 100644 index 0000000000..0406b6ae5d --- /dev/null +++ b/tests/experimental/tasks/test_message_queue.py @@ -0,0 +1,252 @@ +""" +Tests for TaskMessageQueue and InMemoryTaskMessageQueue. +""" + +from datetime import datetime, timezone + +import anyio +import pytest + +from mcp.shared.experimental.tasks import ( + InMemoryTaskMessageQueue, + QueuedMessage, + Resolver, +) +from mcp.types import JSONRPCNotification, JSONRPCRequest + + +@pytest.fixture +def queue() -> InMemoryTaskMessageQueue: + return InMemoryTaskMessageQueue() + + +def make_request(id: int = 1, method: str = "test/method") -> JSONRPCRequest: + return JSONRPCRequest(jsonrpc="2.0", id=id, method=method) + + +def make_notification(method: str = "test/notify") -> JSONRPCNotification: + return JSONRPCNotification(jsonrpc="2.0", method=method) + + +class TestInMemoryTaskMessageQueue: + @pytest.mark.anyio + async def test_enqueue_and_dequeue(self, queue: InMemoryTaskMessageQueue) -> None: + """Test basic enqueue and dequeue operations.""" + task_id = "task-1" + msg = QueuedMessage(type="request", message=make_request()) + + await queue.enqueue(task_id, msg) + result = await queue.dequeue(task_id) + + assert result is not None + assert result.type == "request" + assert result.message.method == "test/method" + + @pytest.mark.anyio + async def test_dequeue_empty_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: + """Dequeue from empty queue returns None.""" + result = await queue.dequeue("nonexistent-task") + assert result is None + + @pytest.mark.anyio + async def test_fifo_ordering(self, queue: InMemoryTaskMessageQueue) -> None: + """Messages are dequeued in FIFO order.""" + task_id = "task-1" + + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1, "first"))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2, "second"))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3, "third"))) + + msg1 = await queue.dequeue(task_id) + msg2 = await queue.dequeue(task_id) + msg3 = await queue.dequeue(task_id) + + assert msg1 is not None and msg1.message.method == "first" + assert msg2 is not None and msg2.message.method == "second" + assert msg3 is not None and msg3.message.method == "third" + + @pytest.mark.anyio + async def test_separate_queues_per_task(self, queue: InMemoryTaskMessageQueue) -> None: + """Each task has its own queue.""" + await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1, "task1-msg"))) + await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2, "task2-msg"))) + + msg1 = await queue.dequeue("task-1") + msg2 = await queue.dequeue("task-2") + + assert msg1 is not None and msg1.message.method == "task1-msg" + assert msg2 is not None and msg2.message.method == "task2-msg" + + @pytest.mark.anyio + async def test_peek_does_not_remove(self, queue: InMemoryTaskMessageQueue) -> None: + """Peek returns message without removing it.""" + task_id = "task-1" + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) + + peeked = await queue.peek(task_id) + dequeued = await queue.dequeue(task_id) + + assert peeked is not None + assert dequeued is not None + assert isinstance(peeked.message, JSONRPCRequest) + assert isinstance(dequeued.message, JSONRPCRequest) + assert peeked.message.id == dequeued.message.id + + @pytest.mark.anyio + async def test_is_empty(self, queue: InMemoryTaskMessageQueue) -> None: + """Test is_empty method.""" + task_id = "task-1" + + assert await queue.is_empty(task_id) is True + + await queue.enqueue(task_id, QueuedMessage(type="notification", message=make_notification())) + assert await queue.is_empty(task_id) is False + + await queue.dequeue(task_id) + assert await queue.is_empty(task_id) is True + + @pytest.mark.anyio + async def test_clear_returns_all_messages(self, queue: InMemoryTaskMessageQueue) -> None: + """Clear removes and returns all messages.""" + task_id = "task-1" + + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3))) + + messages = await queue.clear(task_id) + + assert len(messages) == 3 + assert await queue.is_empty(task_id) is True + + @pytest.mark.anyio + async def test_clear_empty_queue(self, queue: InMemoryTaskMessageQueue) -> None: + """Clear on empty queue returns empty list.""" + messages = await queue.clear("nonexistent") + assert messages == [] + + @pytest.mark.anyio + async def test_notification_messages(self, queue: InMemoryTaskMessageQueue) -> None: + """Test queuing notification messages.""" + task_id = "task-1" + msg = QueuedMessage(type="notification", message=make_notification("log/message")) + + await queue.enqueue(task_id, msg) + result = await queue.dequeue(task_id) + + assert result is not None + assert result.type == "notification" + assert result.message.method == "log/message" + + @pytest.mark.anyio + async def test_message_timestamp(self, queue: InMemoryTaskMessageQueue) -> None: + """Messages have timestamps.""" + before = datetime.now(timezone.utc) + msg = QueuedMessage(type="request", message=make_request()) + after = datetime.now(timezone.utc) + + assert before <= msg.timestamp <= after + + @pytest.mark.anyio + async def test_message_with_resolver(self, queue: InMemoryTaskMessageQueue) -> None: + """Messages can have resolvers.""" + task_id = "task-1" + resolver: Resolver[dict[str, str]] = Resolver() + + msg = QueuedMessage( + type="request", + message=make_request(), + resolver=resolver, + original_request_id=42, + ) + + await queue.enqueue(task_id, msg) + result = await queue.dequeue(task_id) + + assert result is not None + assert result.resolver is resolver + assert result.original_request_id == 42 + + @pytest.mark.anyio + async def test_cleanup_specific_task(self, queue: InMemoryTaskMessageQueue) -> None: + """Cleanup removes specific task's data.""" + await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) + await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) + + queue.cleanup("task-1") + + assert await queue.is_empty("task-1") is True + assert await queue.is_empty("task-2") is False + + @pytest.mark.anyio + async def test_cleanup_all(self, queue: InMemoryTaskMessageQueue) -> None: + """Cleanup without task_id removes all data.""" + await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) + await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) + + queue.cleanup() + + assert await queue.is_empty("task-1") is True + assert await queue.is_empty("task-2") is True + + @pytest.mark.anyio + async def test_wait_for_message_returns_immediately_if_message_exists( + self, queue: InMemoryTaskMessageQueue + ) -> None: + """wait_for_message returns immediately if queue not empty.""" + task_id = "task-1" + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) + + # Should return immediately, not block + with anyio.fail_after(1): + await queue.wait_for_message(task_id) + + @pytest.mark.anyio + async def test_wait_for_message_blocks_until_message(self, queue: InMemoryTaskMessageQueue) -> None: + """wait_for_message blocks until a message is enqueued.""" + task_id = "task-1" + received = False + waiter_started = anyio.Event() + + async def enqueue_when_ready() -> None: + # Wait until the waiter has started before enqueueing + await waiter_started.wait() + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) + + async def wait_for_msg() -> None: + nonlocal received + # Signal that we're about to start waiting + waiter_started.set() + await queue.wait_for_message(task_id) + received = True + + async with anyio.create_task_group() as tg: + tg.start_soon(wait_for_msg) + tg.start_soon(enqueue_when_ready) + + assert received is True + + @pytest.mark.anyio + async def test_notify_message_available_wakes_waiter(self, queue: InMemoryTaskMessageQueue) -> None: + """notify_message_available wakes up waiting coroutines.""" + task_id = "task-1" + notified = False + waiter_started = anyio.Event() + + async def notify_when_ready() -> None: + # Wait until the waiter has started before notifying + await waiter_started.wait() + await queue.notify_message_available(task_id) + + async def wait_for_notification() -> None: + nonlocal notified + # Signal that we're about to start waiting + waiter_started.set() + await queue.wait_for_message(task_id) + notified = True + + async with anyio.create_task_group() as tg: + tg.start_soon(wait_for_notification) + tg.start_soon(notify_when_ready) + + assert notified is True diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py new file mode 100644 index 0000000000..d8ac806d11 --- /dev/null +++ b/tests/experimental/tasks/test_request_context.py @@ -0,0 +1,180 @@ +"""Tests for the RequestContext.experimental (Experimental class) task validation helpers.""" + +import pytest + +from mcp.shared.context import Experimental +from mcp.shared.exceptions import McpError +from mcp.types import ( + METHOD_NOT_FOUND, + TASK_FORBIDDEN, + TASK_OPTIONAL, + TASK_REQUIRED, + ClientCapabilities, + ClientTasksCapability, + TaskMetadata, + Tool, + ToolExecution, +) + +# --- Experimental.is_task --- + + +def test_is_task_true_when_metadata_present() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + assert exp.is_task is True + + +def test_is_task_false_when_no_metadata() -> None: + exp = Experimental(task_metadata=None) + assert exp.is_task is False + + +# --- Experimental.client_supports_tasks --- + + +def test_client_supports_tasks_true() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) + assert exp.client_supports_tasks is True + + +def test_client_supports_tasks_false_no_tasks() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.client_supports_tasks is False + + +def test_client_supports_tasks_false_no_capabilities() -> None: + exp = Experimental(_client_capabilities=None) + assert exp.client_supports_tasks is False + + +# --- Experimental.validate_task_mode --- + + +def test_validate_task_mode_required_with_task_is_valid() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) + assert error is None + + +def test_validate_task_mode_required_without_task_returns_error() -> None: + exp = Experimental(task_metadata=None) + error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) + assert error is not None + assert error.code == METHOD_NOT_FOUND + assert "requires task-augmented" in error.message + + +def test_validate_task_mode_required_without_task_raises_by_default() -> None: + exp = Experimental(task_metadata=None) + with pytest.raises(McpError) as exc_info: + exp.validate_task_mode(TASK_REQUIRED) + assert exc_info.value.error.code == METHOD_NOT_FOUND + + +def test_validate_task_mode_forbidden_without_task_is_valid() -> None: + exp = Experimental(task_metadata=None) + error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) + assert error is None + + +def test_validate_task_mode_forbidden_with_task_returns_error() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) + assert error is not None + assert error.code == METHOD_NOT_FOUND + assert "does not support task-augmented" in error.message + + +def test_validate_task_mode_forbidden_with_task_raises_by_default() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + with pytest.raises(McpError) as exc_info: + exp.validate_task_mode(TASK_FORBIDDEN) + assert exc_info.value.error.code == METHOD_NOT_FOUND + + +def test_validate_task_mode_none_treated_as_forbidden() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(None, raise_error=False) + assert error is not None + assert "does not support task-augmented" in error.message + + +def test_validate_task_mode_optional_with_task_is_valid() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) + assert error is None + + +def test_validate_task_mode_optional_without_task_is_valid() -> None: + exp = Experimental(task_metadata=None) + error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) + assert error is None + + +# --- Experimental.validate_for_tool --- + + +def test_validate_for_tool_with_execution_required() -> None: + exp = Experimental(task_metadata=None) + tool = Tool( + name="test", + description="test", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + error = exp.validate_for_tool(tool, raise_error=False) + assert error is not None + assert "requires task-augmented" in error.message + + +def test_validate_for_tool_without_execution() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + tool = Tool( + name="test", + description="test", + inputSchema={"type": "object"}, + execution=None, + ) + error = exp.validate_for_tool(tool, raise_error=False) + assert error is not None + assert "does not support task-augmented" in error.message + + +def test_validate_for_tool_optional_with_task() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + tool = Tool( + name="test", + description="test", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_OPTIONAL), + ) + error = exp.validate_for_tool(tool, raise_error=False) + assert error is None + + +# --- Experimental.can_use_tool --- + + +def test_can_use_tool_required_with_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) + assert exp.can_use_tool(TASK_REQUIRED) is True + + +def test_can_use_tool_required_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(TASK_REQUIRED) is False + + +def test_can_use_tool_optional_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(TASK_OPTIONAL) is True + + +def test_can_use_tool_forbidden_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(TASK_FORBIDDEN) is True + + +def test_can_use_tool_none_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(None) is True diff --git a/tests/experimental/tasks/test_response_routing.py b/tests/experimental/tasks/test_response_routing.py new file mode 100644 index 0000000000..5e401accd7 --- /dev/null +++ b/tests/experimental/tasks/test_response_routing.py @@ -0,0 +1,652 @@ +""" +Tests for response routing in task-augmented flows. + +This tests the ResponseRouter protocol and its integration with BaseSession +to route responses for queued task requests back to their resolvers. +""" + +from typing import Any +from unittest.mock import AsyncMock, Mock + +import anyio +import pytest + +from mcp.shared.experimental.tasks import ( + InMemoryTaskMessageQueue, + InMemoryTaskStore, + QueuedMessage, + Resolver, + TaskResultHandler, +) +from mcp.shared.response_router import ResponseRouter +from mcp.types import ErrorData, JSONRPCRequest, RequestId, TaskMetadata + + +class TestResponseRouterProtocol: + """Test the ResponseRouter protocol.""" + + def test_task_result_handler_implements_protocol(self) -> None: + """TaskResultHandler implements ResponseRouter protocol.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + # Verify it has the required methods + assert hasattr(handler, "route_response") + assert hasattr(handler, "route_error") + assert callable(handler.route_response) + assert callable(handler.route_error) + + def test_protocol_type_checking(self) -> None: + """ResponseRouter can be used as a type hint.""" + + def accepts_router(router: ResponseRouter) -> bool: + return router.route_response(1, {}) + + # This should type-check correctly + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + # Should not raise - handler implements the protocol + result = accepts_router(handler) + assert result is False # No pending request + + +class TestTaskResultHandlerRouting: + """Test TaskResultHandler response and error routing.""" + + @pytest.fixture + def handler(self) -> TaskResultHandler: + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + return TaskResultHandler(store, queue) + + def test_route_response_no_pending_request(self, handler: TaskResultHandler) -> None: + """route_response returns False when no pending request.""" + result = handler.route_response(123, {"status": "ok"}) + assert result is False + + def test_route_error_no_pending_request(self, handler: TaskResultHandler) -> None: + """route_error returns False when no pending request.""" + error = ErrorData(code=-32600, message="Invalid Request") + result = handler.route_error(123, error) + assert result is False + + @pytest.mark.anyio + async def test_route_response_with_pending_request(self, handler: TaskResultHandler) -> None: + """route_response delivers to waiting resolver.""" + resolver: Resolver[dict[str, Any]] = Resolver() + request_id: RequestId = "task-abc-12345678" + + # Simulate what happens during _deliver_queued_messages + handler._pending_requests[request_id] = resolver + + # Route the response + result = handler.route_response(request_id, {"action": "accept", "content": {"name": "test"}}) + + assert result is True + assert resolver.done() + assert await resolver.wait() == {"action": "accept", "content": {"name": "test"}} + + @pytest.mark.anyio + async def test_route_error_with_pending_request(self, handler: TaskResultHandler) -> None: + """route_error delivers exception to waiting resolver.""" + resolver: Resolver[dict[str, Any]] = Resolver() + request_id: RequestId = "task-abc-12345678" + + handler._pending_requests[request_id] = resolver + + error = ErrorData(code=-32600, message="User declined") + result = handler.route_error(request_id, error) + + assert result is True + assert resolver.done() + + # Should raise McpError when awaited + with pytest.raises(Exception) as exc_info: + await resolver.wait() + assert "User declined" in str(exc_info.value) + + def test_route_response_removes_from_pending(self, handler: TaskResultHandler) -> None: + """route_response removes request from pending after routing.""" + resolver: Resolver[dict[str, Any]] = Resolver() + request_id: RequestId = 42 + + handler._pending_requests[request_id] = resolver + handler.route_response(request_id, {}) + + assert request_id not in handler._pending_requests + + def test_route_error_removes_from_pending(self, handler: TaskResultHandler) -> None: + """route_error removes request from pending after routing.""" + resolver: Resolver[dict[str, Any]] = Resolver() + request_id: RequestId = 42 + + handler._pending_requests[request_id] = resolver + handler.route_error(request_id, ErrorData(code=0, message="test")) + + assert request_id not in handler._pending_requests + + def test_route_response_ignores_already_done_resolver(self, handler: TaskResultHandler) -> None: + """route_response returns False for already-resolved resolver.""" + resolver: Resolver[dict[str, Any]] = Resolver() + resolver.set_result({"already": "done"}) + request_id: RequestId = 42 + + handler._pending_requests[request_id] = resolver + result = handler.route_response(request_id, {"new": "data"}) + + # Should return False since resolver was already done + assert result is False + + def test_route_with_string_request_id(self, handler: TaskResultHandler) -> None: + """Response routing works with string request IDs.""" + resolver: Resolver[dict[str, Any]] = Resolver() + request_id = "task-abc-12345678" + + handler._pending_requests[request_id] = resolver + result = handler.route_response(request_id, {"status": "ok"}) + + assert result is True + assert resolver.done() + + def test_route_with_int_request_id(self, handler: TaskResultHandler) -> None: + """Response routing works with integer request IDs.""" + resolver: Resolver[dict[str, Any]] = Resolver() + request_id = 999 + + handler._pending_requests[request_id] = resolver + result = handler.route_response(request_id, {"status": "ok"}) + + assert result is True + assert resolver.done() + + +class TestDeliverQueuedMessages: + """Test that _deliver_queued_messages properly sets up response routing.""" + + @pytest.mark.anyio + async def test_request_resolver_stored_for_routing(self) -> None: + """When delivering a request, its resolver is stored for response routing.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + # Create a task + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") + + # Create resolver and queued message + resolver: Resolver[dict[str, Any]] = Resolver() + request_id: RequestId = "task-1-abc12345" + request = JSONRPCRequest(jsonrpc="2.0", id=request_id, method="elicitation/create") + + queued_msg = QueuedMessage( + type="request", + message=request, + resolver=resolver, + original_request_id=request_id, + ) + await queue.enqueue(task.taskId, queued_msg) + + # Create mock session with async send_message + mock_session = Mock() + mock_session.send_message = AsyncMock() + + # Deliver the message + await handler._deliver_queued_messages(task.taskId, mock_session, "outer-request-1") + + # Verify resolver is stored for routing + assert request_id in handler._pending_requests + assert handler._pending_requests[request_id] is resolver + + @pytest.mark.anyio + async def test_notification_not_stored_for_routing(self) -> None: + """Notifications don't create pending request entries.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") + + from mcp.types import JSONRPCNotification + + notification = JSONRPCNotification(jsonrpc="2.0", method="notifications/log") + queued_msg = QueuedMessage(type="notification", message=notification) + await queue.enqueue(task.taskId, queued_msg) + + mock_session = Mock() + mock_session.send_message = AsyncMock() + + await handler._deliver_queued_messages(task.taskId, mock_session, "outer-request-1") + + # No pending requests for notifications + assert len(handler._pending_requests) == 0 + + +class TestTaskSessionRequestIds: + """Test TaskSession generates unique request IDs.""" + + @pytest.mark.anyio + async def test_request_ids_are_strings(self) -> None: + """TaskSession generates string request IDs to avoid collision with BaseSession.""" + from mcp.shared.experimental.tasks import TaskSession + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + mock_session = Mock() + + task_session = TaskSession( + session=mock_session, + task_id="task-abc", + store=store, + queue=queue, + ) + + id1 = task_session._next_request_id() + id2 = task_session._next_request_id() + + # IDs should be strings + assert isinstance(id1, str) + assert isinstance(id2, str) + + # IDs should be unique + assert id1 != id2 + + # IDs should contain task ID for debugging + assert "task-abc" in id1 + assert "task-abc" in id2 + + @pytest.mark.anyio + async def test_request_ids_include_uuid_component(self) -> None: + """Request IDs include a UUID component for uniqueness.""" + from mcp.shared.experimental.tasks import TaskSession + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + mock_session = Mock() + + # Create two task sessions with same task_id + task_session1 = TaskSession(session=mock_session, task_id="task-1", store=store, queue=queue) + task_session2 = TaskSession(session=mock_session, task_id="task-1", store=store, queue=queue) + + id1 = task_session1._next_request_id() + id2 = task_session2._next_request_id() + + # Even with same task_id, IDs should be unique due to UUID + assert id1 != id2 + + +class TestRelatedTaskMetadata: + """Test that TaskSession includes related-task metadata in requests.""" + + @pytest.mark.anyio + async def test_elicit_includes_related_task_metadata(self) -> None: + """TaskSession.elicit() includes io.modelcontextprotocol/related-task metadata.""" + from mcp.shared.experimental.tasks import RELATED_TASK_METADATA_KEY, TaskSession + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + mock_session = Mock() + + # Create a task first + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") + + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + # Start elicitation (will block waiting for response, so we need to cancel) + async def start_elicit() -> None: + try: + await task_session.elicit( + message="What is your name?", + requestedSchema={"type": "object", "properties": {"name": {"type": "string"}}}, + ) + except anyio.get_cancelled_exc_class(): + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(start_elicit) + await queue.wait_for_message(task.taskId) + + # Check the queued message + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.type == "request" + + # Verify related-task metadata + assert hasattr(msg.message, "params") + params = msg.message.params + assert params is not None + assert "_meta" in params + assert RELATED_TASK_METADATA_KEY in params["_meta"] + assert params["_meta"][RELATED_TASK_METADATA_KEY]["taskId"] == task.taskId + + tg.cancel_scope.cancel() + + def test_related_task_metadata_key_value(self) -> None: + """RELATED_TASK_METADATA_KEY has correct value per spec.""" + from mcp.shared.experimental.tasks import RELATED_TASK_METADATA_KEY + + assert RELATED_TASK_METADATA_KEY == "io.modelcontextprotocol/related-task" + + +class TestEndToEndResponseRouting: + """End-to-end tests for response routing flow.""" + + @pytest.mark.anyio + async def test_full_elicitation_response_flow(self) -> None: + """Test complete flow: enqueue -> deliver -> respond -> receive.""" + from mcp.shared.experimental.tasks import TaskSession + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + mock_session = Mock() + + # Create task + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-flow-test") + + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + elicit_result = None + + async def do_elicit() -> None: + nonlocal elicit_result + elicit_result = await task_session.elicit( + message="Enter name", + requestedSchema={"type": "string"}, + ) + + async def simulate_response() -> None: + # Wait for message to be enqueued + await queue.wait_for_message(task.taskId) + + # Simulate TaskResultHandler delivering the message + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + assert msg.original_request_id is not None + original_id = msg.original_request_id + + # Store resolver (as TaskResultHandler would) + handler._pending_requests[original_id] = msg.resolver + + # Simulate client response arriving + response_data = {"action": "accept", "content": {"name": "Alice"}} + routed = handler.route_response(original_id, response_data) + assert routed is True + + async with anyio.create_task_group() as tg: + tg.start_soon(do_elicit) + tg.start_soon(simulate_response) + + # Verify the elicit() call received the response + assert elicit_result is not None + assert elicit_result.action == "accept" + assert elicit_result.content == {"name": "Alice"} + + @pytest.mark.anyio + async def test_multiple_concurrent_elicitations(self) -> None: + """Multiple elicitations can be routed concurrently.""" + from mcp.shared.experimental.tasks import TaskSession + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + mock_session = Mock() + + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-concurrent") + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + results: list[Any] = [] + + async def elicit_and_store(idx: int) -> None: + result = await task_session.elicit( + message=f"Question {idx}", + requestedSchema={"type": "string"}, + ) + results.append((idx, result)) + + async def respond_to_all() -> None: + # Wait for all 3 messages to be enqueued, then respond + for i in range(3): + await queue.wait_for_message(task.taskId) + msg = await queue.dequeue(task.taskId) + if msg and msg.resolver and msg.original_request_id is not None: + request_id = msg.original_request_id + handler._pending_requests[request_id] = msg.resolver + handler.route_response( + request_id, + {"action": "accept", "content": {"answer": f"Response {i}"}}, + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(elicit_and_store, 0) + tg.start_soon(elicit_and_store, 1) + tg.start_soon(elicit_and_store, 2) + tg.start_soon(respond_to_all) + + assert len(results) == 3 + # All should have received responses + for _idx, result in results: + assert result.action == "accept" + + +class TestSamplingResponseRouting: + """Test sampling request/response routing through TaskSession.""" + + @pytest.mark.anyio + async def test_create_message_enqueues_request(self) -> None: + """create_message() enqueues a sampling request.""" + from mcp.shared.experimental.tasks import TaskSession + from mcp.types import SamplingMessage, TextContent + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + mock_session = Mock() + + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-sampling-1") + + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + async def start_sampling() -> None: + try: + await task_session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + max_tokens=100, + ) + except anyio.get_cancelled_exc_class(): + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(start_sampling) + await queue.wait_for_message(task.taskId) + + # Verify message was enqueued + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.type == "request" + assert msg.message.method == "sampling/createMessage" + + tg.cancel_scope.cancel() + + @pytest.mark.anyio + async def test_create_message_includes_related_task_metadata(self) -> None: + """Sampling request includes io.modelcontextprotocol/related-task metadata.""" + from mcp.shared.experimental.tasks import RELATED_TASK_METADATA_KEY, TaskSession + from mcp.types import SamplingMessage, TextContent + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + mock_session = Mock() + + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-sampling-meta") + + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + async def start_sampling() -> None: + try: + await task_session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Test"))], + max_tokens=50, + ) + except anyio.get_cancelled_exc_class(): + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(start_sampling) + await queue.wait_for_message(task.taskId) + + msg = await queue.dequeue(task.taskId) + assert msg is not None + + # Verify related-task metadata + params = msg.message.params + assert params is not None + assert "_meta" in params + assert RELATED_TASK_METADATA_KEY in params["_meta"] + assert params["_meta"][RELATED_TASK_METADATA_KEY]["taskId"] == task.taskId + + tg.cancel_scope.cancel() + + @pytest.mark.anyio + async def test_create_message_response_routing(self) -> None: + """Response to sampling request is routed back to resolver.""" + from mcp.shared.experimental.tasks import TaskSession + from mcp.types import SamplingMessage, TextContent + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + mock_session = Mock() + + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-sampling-route") + + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + sampling_result = None + + async def do_sampling() -> None: + nonlocal sampling_result + sampling_result = await task_session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="What is 2+2?"))], + max_tokens=100, + ) + + async def simulate_response() -> None: + await queue.wait_for_message(task.taskId) + + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + assert msg.original_request_id is not None + original_id = msg.original_request_id + + handler._pending_requests[original_id] = msg.resolver + + # Simulate sampling response + response_data = { + "model": "test-model", + "role": "assistant", + "content": {"type": "text", "text": "4"}, + } + routed = handler.route_response(original_id, response_data) + assert routed is True + + async with anyio.create_task_group() as tg: + tg.start_soon(do_sampling) + tg.start_soon(simulate_response) + + assert sampling_result is not None + assert sampling_result.model == "test-model" + assert sampling_result.role == "assistant" + + @pytest.mark.anyio + async def test_create_message_updates_task_status(self) -> None: + """create_message() updates task status to input_required then back to working.""" + from mcp.shared.experimental.tasks import TaskSession + from mcp.types import SamplingMessage, TextContent + + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + mock_session = Mock() + + task = await store.create_task(TaskMetadata(ttl=60000), task_id="task-sampling-status") + + task_session = TaskSession( + session=mock_session, + task_id=task.taskId, + store=store, + queue=queue, + ) + + status_during_wait: str | None = None + + async def do_sampling() -> None: + await task_session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hi"))], + max_tokens=50, + ) + + async def check_status_and_respond() -> None: + nonlocal status_during_wait + await queue.wait_for_message(task.taskId) + + # Check status while waiting + task_state = await store.get_task(task.taskId) + assert task_state is not None + status_during_wait = task_state.status + + # Respond + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + assert msg.original_request_id is not None + handler._pending_requests[msg.original_request_id] = msg.resolver + handler.route_response( + msg.original_request_id, + {"model": "m", "role": "assistant", "content": {"type": "text", "text": "Hi"}}, + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(do_sampling) + tg.start_soon(check_status_and_respond) + + # Verify status was input_required during wait + assert status_during_wait == "input_required" + + # Verify status is back to working after + final_task = await store.get_task(task.taskId) + assert final_task is not None + assert final_task.status == "working" diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py new file mode 100644 index 0000000000..f6d703c554 --- /dev/null +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -0,0 +1,841 @@ +""" +Tasks Spec Compliance Tests +=========================== + +Test structure mirrors: https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks.md + +Each section contains tests for normative requirements (MUST/SHOULD/MAY). +""" + +from datetime import datetime, timezone + +import pytest + +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.types import ( + CancelTaskRequest, + CancelTaskResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerCapabilities, +) + +# Shared test datetime +TEST_DATETIME = datetime(2025, 1, 1, tzinfo=timezone.utc) + +# ============================================================================= +# CAPABILITIES DECLARATION +# ============================================================================= + +# --- Server Capabilities --- + + +def _get_capabilities(server: Server) -> ServerCapabilities: + """Helper to get capabilities from a server.""" + return server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + +# -- Capability declaration tests -- + + +def test_server_without_task_handlers_has_no_tasks_capability() -> None: + """Server without any task handlers has no tasks capability.""" + server: Server = Server("test") + caps = _get_capabilities(server) + assert caps.tasks is None + + +def test_server_with_list_tasks_handler_declares_list_capability() -> None: + """Server with list_tasks handler declares tasks.list capability.""" + server: Server = Server("test") + + @server.experimental.list_tasks() + async def handle_list(req: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=[]) + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.list is not None + + +def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: + """Server with cancel_task handler declares tasks.cancel capability.""" + server: Server = Server("test") + + @server.experimental.cancel_task() + async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: + return CancelTaskResult( + taskId="test", status="cancelled", createdAt=TEST_DATETIME, lastUpdatedAt=TEST_DATETIME, ttl=None + ) + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.cancel is not None + + +def test_server_with_get_task_handler_declares_requests_tools_call_capability() -> None: + """ + Server with get_task handler declares tasks.requests.tools.call capability. + (get_task is required for task-augmented tools/call support) + """ + server: Server = Server("test") + + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + return GetTaskResult( + taskId="test", status="working", createdAt=TEST_DATETIME, lastUpdatedAt=TEST_DATETIME, ttl=None + ) + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.tools is not None + + +def test_server_without_list_handler_has_no_list_capability() -> None: + """Server without list_tasks handler has no tasks.list capability.""" + server: Server = Server("test") + + # Register only get_task (not list_tasks) + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + return GetTaskResult( + taskId="test", status="working", createdAt=TEST_DATETIME, lastUpdatedAt=TEST_DATETIME, ttl=None + ) + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.list is None + + +def test_server_without_cancel_handler_has_no_cancel_capability() -> None: + """Server without cancel_task handler has no tasks.cancel capability.""" + server: Server = Server("test") + + # Register only get_task (not cancel_task) + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + return GetTaskResult( + taskId="test", status="working", createdAt=TEST_DATETIME, lastUpdatedAt=TEST_DATETIME, ttl=None + ) + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.cancel is None + + +def test_server_with_all_task_handlers_has_full_capability() -> None: + """Server with all task handlers declares complete tasks capability.""" + server: Server = Server("test") + + @server.experimental.list_tasks() + async def handle_list(req: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=[]) + + @server.experimental.cancel_task() + async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: + return CancelTaskResult( + taskId="test", status="cancelled", createdAt=TEST_DATETIME, lastUpdatedAt=TEST_DATETIME, ttl=None + ) + + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + return GetTaskResult( + taskId="test", status="working", createdAt=TEST_DATETIME, lastUpdatedAt=TEST_DATETIME, ttl=None + ) + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.list is not None + assert caps.tasks.cancel is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.tools is not None + + +# --- Client Capabilities --- + + +class TestClientCapabilities: + """ + Clients declare: + - tasks.list — supports listing operations + - tasks.cancel — supports cancellation + - tasks.requests.sampling.createMessage — task-augmented sampling + - tasks.requests.elicitation.create — task-augmented elicitation + """ + + def test_client_declares_tasks_capability(self) -> None: + """Client can declare tasks capability.""" + pytest.skip("TODO") + + +# --- Tool-Level Negotiation --- + + +class TestToolLevelNegotiation: + """ + Tools in tools/list responses include execution.taskSupport with values: + - Not present or "forbidden": No task augmentation allowed + - "optional": Task augmentation allowed at requestor discretion + - "required": Task augmentation is mandatory + """ + + def test_tool_execution_task_forbidden_rejects_task_augmented_call(self) -> None: + """Tool with execution.taskSupport="forbidden" MUST reject task-augmented calls (-32601).""" + pytest.skip("TODO") + + def test_tool_execution_task_absent_rejects_task_augmented_call(self) -> None: + """Tool without execution.taskSupport MUST reject task-augmented calls (-32601).""" + pytest.skip("TODO") + + def test_tool_execution_task_optional_accepts_normal_call(self) -> None: + """Tool with execution.taskSupport="optional" accepts normal calls.""" + pytest.skip("TODO") + + def test_tool_execution_task_optional_accepts_task_augmented_call(self) -> None: + """Tool with execution.taskSupport="optional" accepts task-augmented calls.""" + pytest.skip("TODO") + + def test_tool_execution_task_required_rejects_normal_call(self) -> None: + """Tool with execution.taskSupport="required" MUST reject non-task calls (-32601).""" + pytest.skip("TODO") + + def test_tool_execution_task_required_accepts_task_augmented_call(self) -> None: + """Tool with execution.taskSupport="required" accepts task-augmented calls.""" + pytest.skip("TODO") + + +# --- Capability Negotiation --- + + +class TestCapabilityNegotiation: + """ + Requestors SHOULD only augment requests with a task if the corresponding + capability has been declared by the receiver. + + Receivers that do not declare the task capability for a request type + MUST process requests of that type normally, ignoring any task-augmentation + metadata if present. + """ + + def test_receiver_without_capability_ignores_task_metadata(self) -> None: + """ + Receiver without task capability MUST process request normally, + ignoring task-augmentation metadata. + """ + pytest.skip("TODO") + + def test_receiver_with_capability_may_require_task_augmentation(self) -> None: + """ + Receivers that declare task capability MAY return error (-32600) + for non-task-augmented requests, requiring task augmentation. + """ + pytest.skip("TODO") + + +# ============================================================================= +# TASK STATUS LIFECYCLE +# ============================================================================= + + +class TestTaskStatusLifecycle: + """ + Tasks begin in working status and follow valid transitions: + working → input_required → working → terminal + working → terminal (directly) + input_required → terminal (directly) + + Terminal states (no further transitions allowed): + - completed + - failed + - cancelled + """ + + def test_task_begins_in_working_status(self) -> None: + """Tasks MUST begin in working status.""" + pytest.skip("TODO") + + def test_working_to_completed_transition(self) -> None: + """working → completed is valid.""" + pytest.skip("TODO") + + def test_working_to_failed_transition(self) -> None: + """working → failed is valid.""" + pytest.skip("TODO") + + def test_working_to_cancelled_transition(self) -> None: + """working → cancelled is valid.""" + pytest.skip("TODO") + + def test_working_to_input_required_transition(self) -> None: + """working → input_required is valid.""" + pytest.skip("TODO") + + def test_input_required_to_working_transition(self) -> None: + """input_required → working is valid.""" + pytest.skip("TODO") + + def test_input_required_to_terminal_transition(self) -> None: + """input_required → terminal is valid.""" + pytest.skip("TODO") + + def test_terminal_state_no_further_transitions(self) -> None: + """Terminal states allow no further transitions.""" + pytest.skip("TODO") + + def test_completed_is_terminal(self) -> None: + """completed is a terminal state.""" + pytest.skip("TODO") + + def test_failed_is_terminal(self) -> None: + """failed is a terminal state.""" + pytest.skip("TODO") + + def test_cancelled_is_terminal(self) -> None: + """cancelled is a terminal state.""" + pytest.skip("TODO") + + +# --- Input Required Status --- + + +class TestInputRequiredStatus: + """ + When a receiver needs information to proceed, it moves the task to input_required. + The requestor should call tasks/result to retrieve input requests. + The task must include io.modelcontextprotocol/related-task metadata in associated requests. + """ + + def test_input_required_status_retrievable_via_tasks_get(self) -> None: + """Task in input_required status is retrievable via tasks/get.""" + pytest.skip("TODO") + + def test_input_required_related_task_metadata_in_requests(self) -> None: + """ + Task MUST include io.modelcontextprotocol/related-task metadata + in associated requests. + """ + pytest.skip("TODO") + + +# ============================================================================= +# PROTOCOL MESSAGES +# ============================================================================= + +# --- Creating a Task --- + + +class TestCreatingTask: + """ + Request structure: + {"method": "tools/call", "params": {"name": "...", "arguments": {...}, "task": {"ttl": 60000}}} + + Response (CreateTaskResult): + {"result": {"task": {"taskId": "...", "status": "working", ...}}} + + Receivers may include io.modelcontextprotocol/model-immediate-response in _meta. + """ + + def test_task_augmented_request_returns_create_task_result(self) -> None: + """Task-augmented request MUST return CreateTaskResult immediately.""" + pytest.skip("TODO") + + def test_create_task_result_contains_task_id(self) -> None: + """CreateTaskResult MUST contain taskId.""" + pytest.skip("TODO") + + def test_create_task_result_contains_status_working(self) -> None: + """CreateTaskResult MUST have status=working initially.""" + pytest.skip("TODO") + + def test_create_task_result_contains_created_at(self) -> None: + """CreateTaskResult MUST contain createdAt timestamp.""" + pytest.skip("TODO") + + def test_create_task_result_created_at_is_iso8601(self) -> None: + """createdAt MUST be ISO 8601 formatted.""" + pytest.skip("TODO") + + def test_create_task_result_may_contain_ttl(self) -> None: + """CreateTaskResult MAY contain ttl.""" + pytest.skip("TODO") + + def test_create_task_result_may_contain_poll_interval(self) -> None: + """CreateTaskResult MAY contain pollInterval.""" + pytest.skip("TODO") + + def test_create_task_result_may_contain_status_message(self) -> None: + """CreateTaskResult MAY contain statusMessage.""" + pytest.skip("TODO") + + def test_receiver_may_override_requested_ttl(self) -> None: + """Receiver MAY override requested ttl but MUST return actual value.""" + pytest.skip("TODO") + + def test_model_immediate_response_in_meta(self) -> None: + """ + Receiver MAY include io.modelcontextprotocol/model-immediate-response + in _meta to provide immediate response while task executes. + """ + from mcp.shared.experimental.tasks import MODEL_IMMEDIATE_RESPONSE_KEY + from mcp.types import CreateTaskResult, Task + + # Verify the constant has the correct value per spec + assert MODEL_IMMEDIATE_RESPONSE_KEY == "io.modelcontextprotocol/model-immediate-response" + + # CreateTaskResult can include model-immediate-response in _meta + task = Task( + taskId="test-123", + status="working", + createdAt=TEST_DATETIME, + lastUpdatedAt=TEST_DATETIME, + ttl=60000, + ) + immediate_msg = "Task started, processing your request..." + # Note: Must use _meta= (alias) not meta= due to Pydantic alias handling + result = CreateTaskResult( + task=task, + **{"_meta": {MODEL_IMMEDIATE_RESPONSE_KEY: immediate_msg}}, + ) + + # Verify the metadata is present and correct + assert result.meta is not None + assert MODEL_IMMEDIATE_RESPONSE_KEY in result.meta + assert result.meta[MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg + + # Verify it serializes correctly with _meta alias + serialized = result.model_dump(by_alias=True) + assert "_meta" in serialized + assert MODEL_IMMEDIATE_RESPONSE_KEY in serialized["_meta"] + assert serialized["_meta"][MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg + + +# --- Getting Task Status (tasks/get) --- + + +class TestGettingTaskStatus: + """ + Request: {"method": "tasks/get", "params": {"taskId": "..."}} + Response: Returns full Task object with current status and pollInterval. + """ + + def test_tasks_get_returns_task_object(self) -> None: + """tasks/get MUST return full Task object.""" + pytest.skip("TODO") + + def test_tasks_get_returns_current_status(self) -> None: + """tasks/get MUST return current status.""" + pytest.skip("TODO") + + def test_tasks_get_may_return_poll_interval(self) -> None: + """tasks/get MAY return pollInterval.""" + pytest.skip("TODO") + + def test_tasks_get_invalid_task_id_returns_error(self) -> None: + """tasks/get with invalid taskId MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_get_nonexistent_task_id_returns_error(self) -> None: + """tasks/get with nonexistent taskId MUST return -32602.""" + pytest.skip("TODO") + + +# --- Retrieving Results (tasks/result) --- + + +class TestRetrievingResults: + """ + Request: {"method": "tasks/result", "params": {"taskId": "..."}} + Response: The actual operation result structure (e.g., CallToolResult). + + This call blocks until terminal status. + """ + + def test_tasks_result_returns_underlying_result(self) -> None: + """tasks/result MUST return exactly what underlying request would return.""" + pytest.skip("TODO") + + def test_tasks_result_blocks_until_terminal(self) -> None: + """tasks/result MUST block for non-terminal tasks.""" + pytest.skip("TODO") + + def test_tasks_result_unblocks_on_terminal(self) -> None: + """tasks/result MUST unblock upon reaching terminal status.""" + pytest.skip("TODO") + + def test_tasks_result_includes_related_task_metadata(self) -> None: + """tasks/result MUST include io.modelcontextprotocol/related-task in _meta.""" + pytest.skip("TODO") + + def test_tasks_result_returns_error_for_failed_task(self) -> None: + """ + tasks/result returns the same error the underlying request + would have produced for failed tasks. + """ + pytest.skip("TODO") + + def test_tasks_result_invalid_task_id_returns_error(self) -> None: + """tasks/result with invalid taskId MUST return -32602.""" + pytest.skip("TODO") + + +# --- Listing Tasks (tasks/list) --- + + +class TestListingTasks: + """ + Request: {"method": "tasks/list", "params": {"cursor": "optional"}} + Response: Array of tasks with pagination support via nextCursor. + """ + + def test_tasks_list_returns_array_of_tasks(self) -> None: + """tasks/list MUST return array of tasks.""" + pytest.skip("TODO") + + def test_tasks_list_pagination_with_cursor(self) -> None: + """tasks/list supports pagination via cursor.""" + pytest.skip("TODO") + + def test_tasks_list_returns_next_cursor_when_more_results(self) -> None: + """tasks/list MUST return nextCursor when more results available.""" + pytest.skip("TODO") + + def test_tasks_list_cursors_are_opaque(self) -> None: + """Implementers MUST treat cursors as opaque tokens.""" + pytest.skip("TODO") + + def test_tasks_list_invalid_cursor_returns_error(self) -> None: + """tasks/list with invalid cursor MUST return -32602.""" + pytest.skip("TODO") + + +# --- Cancelling Tasks (tasks/cancel) --- + + +class TestCancellingTasks: + """ + Request: {"method": "tasks/cancel", "params": {"taskId": "..."}} + Response: Returns the task object with status: "cancelled". + """ + + def test_tasks_cancel_returns_cancelled_task(self) -> None: + """tasks/cancel MUST return task with status=cancelled.""" + pytest.skip("TODO") + + def test_tasks_cancel_terminal_task_returns_error(self) -> None: + """Cancelling already-terminal task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_completed_task_returns_error(self) -> None: + """Cancelling completed task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_failed_task_returns_error(self) -> None: + """Cancelling failed task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_already_cancelled_task_returns_error(self) -> None: + """Cancelling already-cancelled task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_invalid_task_id_returns_error(self) -> None: + """tasks/cancel with invalid taskId MUST return -32602.""" + pytest.skip("TODO") + + +# --- Status Notifications --- + + +class TestStatusNotifications: + """ + Receivers MAY send: {"method": "notifications/tasks/status", "params": {...}} + These are optional; requestors MUST NOT rely on them and SHOULD continue polling. + """ + + def test_receiver_may_send_status_notification(self) -> None: + """Receiver MAY send notifications/tasks/status.""" + pytest.skip("TODO") + + def test_status_notification_contains_task_id(self) -> None: + """Status notification MUST contain taskId.""" + pytest.skip("TODO") + + def test_status_notification_contains_status(self) -> None: + """Status notification MUST contain status.""" + pytest.skip("TODO") + + +# ============================================================================= +# BEHAVIORAL REQUIREMENTS +# ============================================================================= + +# --- Task Management --- + + +class TestTaskManagement: + """ + - Receivers generate unique task IDs as strings + - Tasks must begin in working status + - createdAt timestamps must be ISO 8601 formatted + - Receivers may override requested ttl but must return actual value + - Receivers may delete tasks after TTL expires + - All task-related messages must include io.modelcontextprotocol/related-task + in _meta except for tasks/get, tasks/list, tasks/cancel operations + """ + + def test_task_ids_are_unique_strings(self) -> None: + """Receivers MUST generate unique task IDs as strings.""" + pytest.skip("TODO") + + def test_multiple_tasks_have_unique_ids(self) -> None: + """Multiple tasks MUST have unique IDs.""" + pytest.skip("TODO") + + def test_receiver_may_delete_tasks_after_ttl(self) -> None: + """Receivers MAY delete tasks after TTL expires.""" + pytest.skip("TODO") + + def test_related_task_metadata_in_task_messages(self) -> None: + """ + All task-related messages MUST include io.modelcontextprotocol/related-task + in _meta. + """ + pytest.skip("TODO") + + def test_tasks_get_does_not_require_related_task_metadata(self) -> None: + """tasks/get does not require related-task metadata.""" + pytest.skip("TODO") + + def test_tasks_list_does_not_require_related_task_metadata(self) -> None: + """tasks/list does not require related-task metadata.""" + pytest.skip("TODO") + + def test_tasks_cancel_does_not_require_related_task_metadata(self) -> None: + """tasks/cancel does not require related-task metadata.""" + pytest.skip("TODO") + + +# --- Result Handling --- + + +class TestResultHandling: + """ + - Receivers must return CreateTaskResult immediately upon accepting task-augmented requests + - tasks/result must return exactly what the underlying request would return + - tasks/result blocks for non-terminal tasks; must unblock upon reaching terminal status + """ + + def test_create_task_result_returned_immediately(self) -> None: + """Receiver MUST return CreateTaskResult immediately (not after work completes).""" + pytest.skip("TODO") + + def test_tasks_result_matches_underlying_result_structure(self) -> None: + """tasks/result MUST return same structure as underlying request.""" + pytest.skip("TODO") + + def test_tasks_result_for_tool_call_returns_call_tool_result(self) -> None: + """tasks/result for tools/call returns CallToolResult.""" + pytest.skip("TODO") + + +# --- Progress Tracking --- + + +class TestProgressTracking: + """ + Task-augmented requests support progress notifications using the progressToken + mechanism, which remains valid throughout the task lifetime. + """ + + def test_progress_token_valid_throughout_task_lifetime(self) -> None: + """progressToken remains valid throughout task lifetime.""" + pytest.skip("TODO") + + def test_progress_notifications_sent_during_task_execution(self) -> None: + """Progress notifications can be sent during task execution.""" + pytest.skip("TODO") + + +# ============================================================================= +# ERROR HANDLING +# ============================================================================= + + +class TestProtocolErrors: + """ + Protocol Errors (JSON-RPC standard codes): + - -32600 (Invalid request): Non-task requests to endpoint requiring task augmentation + - -32602 (Invalid params): Invalid/nonexistent taskId, invalid cursor, cancel terminal task + - -32603 (Internal error): Server-side execution failures + """ + + def test_invalid_request_for_required_task_augmentation(self) -> None: + """Non-task request to task-required endpoint returns -32600.""" + pytest.skip("TODO") + + def test_invalid_params_for_invalid_task_id(self) -> None: + """Invalid taskId returns -32602.""" + pytest.skip("TODO") + + def test_invalid_params_for_nonexistent_task_id(self) -> None: + """Nonexistent taskId returns -32602.""" + pytest.skip("TODO") + + def test_invalid_params_for_invalid_cursor(self) -> None: + """Invalid cursor in tasks/list returns -32602.""" + pytest.skip("TODO") + + def test_invalid_params_for_cancel_terminal_task(self) -> None: + """Attempt to cancel terminal task returns -32602.""" + pytest.skip("TODO") + + def test_internal_error_for_server_failure(self) -> None: + """Server-side execution failure returns -32603.""" + pytest.skip("TODO") + + +class TestTaskExecutionErrors: + """ + When underlying requests fail, the task moves to failed status. + - tasks/get response should include statusMessage explaining failure + - tasks/result returns same error the underlying request would have produced + - For tool calls, isError: true moves task to failed status + """ + + def test_underlying_failure_moves_task_to_failed(self) -> None: + """Underlying request failure moves task to failed status.""" + pytest.skip("TODO") + + def test_failed_task_has_status_message(self) -> None: + """Failed task SHOULD include statusMessage explaining failure.""" + pytest.skip("TODO") + + def test_tasks_result_returns_underlying_error(self) -> None: + """tasks/result returns same error underlying request would produce.""" + pytest.skip("TODO") + + def test_tool_call_is_error_true_moves_to_failed(self) -> None: + """Tool call with isError: true moves task to failed status.""" + pytest.skip("TODO") + + +# ============================================================================= +# DATA TYPES +# ============================================================================= + + +class TestTaskObject: + """ + Task Object fields: + - taskId: String identifier + - status: Current execution state + - statusMessage: Optional human-readable description + - createdAt: ISO 8601 timestamp of creation + - ttl: Milliseconds before potential deletion + - pollInterval: Suggested milliseconds between polls + """ + + def test_task_has_task_id_string(self) -> None: + """Task MUST have taskId as string.""" + pytest.skip("TODO") + + def test_task_has_status(self) -> None: + """Task MUST have status.""" + pytest.skip("TODO") + + def test_task_status_message_is_optional(self) -> None: + """Task statusMessage is optional.""" + pytest.skip("TODO") + + def test_task_has_created_at(self) -> None: + """Task MUST have createdAt.""" + pytest.skip("TODO") + + def test_task_ttl_is_optional(self) -> None: + """Task ttl is optional.""" + pytest.skip("TODO") + + def test_task_poll_interval_is_optional(self) -> None: + """Task pollInterval is optional.""" + pytest.skip("TODO") + + +class TestRelatedTaskMetadata: + """ + Related Task Metadata structure: + {"_meta": {"io.modelcontextprotocol/related-task": {"taskId": "..."}}} + """ + + def test_related_task_metadata_structure(self) -> None: + """Related task metadata has correct structure.""" + pytest.skip("TODO") + + def test_related_task_metadata_contains_task_id(self) -> None: + """Related task metadata contains taskId.""" + pytest.skip("TODO") + + +# ============================================================================= +# SECURITY CONSIDERATIONS +# ============================================================================= + + +class TestAccessAndIsolation: + """ + - Task IDs enable access to sensitive results + - Authorization context binding is essential where available + - For non-authorized environments: strong entropy IDs, strict TTL limits + """ + + def test_task_bound_to_authorization_context(self) -> None: + """ + Receivers receiving authorization context MUST bind tasks to that context. + """ + pytest.skip("TODO") + + def test_reject_task_operations_outside_authorization_context(self) -> None: + """ + Receivers MUST reject task operations for tasks outside + requestor's authorization context. + """ + pytest.skip("TODO") + + def test_non_authorized_environments_use_secure_ids(self) -> None: + """ + For non-authorized environments, receivers SHOULD use + cryptographically secure IDs. + """ + pytest.skip("TODO") + + def test_non_authorized_environments_use_shorter_ttls(self) -> None: + """ + For non-authorized environments, receivers SHOULD use shorter TTLs. + """ + pytest.skip("TODO") + + +class TestResourceLimits: + """ + Receivers should: + - Enforce concurrent task limits per requestor + - Implement maximum TTL constraints + - Clean up expired tasks promptly + """ + + def test_concurrent_task_limit_enforced(self) -> None: + """Receiver SHOULD enforce concurrent task limits per requestor.""" + pytest.skip("TODO") + + def test_maximum_ttl_constraint_enforced(self) -> None: + """Receiver SHOULD implement maximum TTL constraints.""" + pytest.skip("TODO") + + def test_expired_tasks_cleaned_up(self) -> None: + """Receiver SHOULD clean up expired tasks promptly.""" + pytest.skip("TODO") diff --git a/uv.lock b/uv.lock index d1363aef41..2aec51e51c 100644 --- a/uv.lock +++ b/uv.lock @@ -15,6 +15,10 @@ members = [ "mcp-simple-resource", "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", + "mcp-simple-task", + "mcp-simple-task-client", + "mcp-simple-task-interactive", + "mcp-simple-task-interactive-client", "mcp-simple-tool", "mcp-snippets", "mcp-structured-output-lowlevel", @@ -1196,6 +1200,126 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-task" +version = "0.1.0" +source = { editable = "examples/servers/simple-task" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-task-client" +version = "0.1.0" +source = { editable = "examples/clients/simple-task-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-task-interactive" +version = "0.1.0" +source = { editable = "examples/servers/simple-task-interactive" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-task-interactive-client" +version = "0.1.0" +source = { editable = "examples/clients/simple-task-interactive-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0"