From e17ddf46a713d50f4248b8cb65a5e89be04af988 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Fri, 19 Sep 2025 15:29:07 +0300 Subject: [PATCH] add prettier for full repository --- .github/workflows/main.yml | 92 +- .prettierignore | 10 + .prettierrc.json | 20 + CODE_OF_CONDUCT.md | 113 +- README.md | 1504 +-- eslint.config.mjs | 16 +- jest.config.js | 18 +- package-lock.json | 34 + package.json | 198 +- src/__mocks__/pkce-challenge.ts | 10 +- src/cli.ts | 284 +- src/client/auth.test.ts | 4862 +++++----- src/client/auth.ts | 1593 ++-- src/client/cross-spawn.test.ts | 276 +- src/client/index.test.ts | 2248 +++-- src/client/index.ts | 799 +- src/client/middleware.test.ts | 2044 ++--- src/client/middleware.ts | 383 +- src/client/sse.test.ts | 2661 +++--- src/client/sse.ts | 469 +- src/client/stdio.test.ts | 112 +- src/client/stdio.ts | 394 +- src/client/streamableHttp.test.ts | 1835 ++-- src/client/streamableHttp.ts | 939 +- src/client/websocket.ts | 113 +- src/examples/README.md | 95 +- .../client/multipleClientsParallel.ts | 248 +- .../client/parallelToolCallsClient.ts | 313 +- src/examples/client/simpleOAuthClient.ts | 654 +- src/examples/client/simpleStreamableHttp.ts | 1415 +-- .../streamableHttpWithSseFallbackClient.ts | 279 +- .../server/demoInMemoryOAuthProvider.ts | 374 +- .../server/jsonResponseStreamableHttp.ts | 287 +- src/examples/server/mcpServerOutputSchema.ts | 124 +- src/examples/server/simpleSseServer.ts | 248 +- .../server/simpleStatelessStreamableHttp.ts | 278 +- src/examples/server/simpleStreamableHttp.ts | 1192 +-- .../sseAndStreamableHttpCompatibleServer.ts | 348 +- .../standaloneSseWithGetStreamableHttp.ts | 174 +- src/examples/server/toolWithSampleServer.ts | 83 +- src/examples/shared/inMemoryEventStore.ts | 115 +- src/inMemory.test.ts | 238 +- src/inMemory.ts | 100 +- src/integration-tests/process-cleanup.test.ts | 46 +- .../stateManagementStreamableHttp.test.ts | 655 +- .../taskResumability.test.ts | 462 +- src/server/auth/clients.ts | 30 +- src/server/auth/errors.ts | 134 +- src/server/auth/handlers/authorize.test.ts | 587 +- src/server/auth/handlers/authorize.ts | 298 +- src/server/auth/handlers/metadata.test.ts | 122 +- src/server/auth/handlers/metadata.ts | 24 +- src/server/auth/handlers/register.test.ts | 452 +- src/server/auth/handlers/register.ts | 219 +- src/server/auth/handlers/revoke.test.ts | 412 +- src/server/auth/handlers/revoke.ts | 138 +- src/server/auth/handlers/token.test.ts | 849 +- src/server/auth/handlers/token.ts | 267 +- .../auth/middleware/allowedMethods.test.ts | 126 +- src/server/auth/middleware/allowedMethods.ts | 26 +- src/server/auth/middleware/bearerAuth.test.ts | 835 +- src/server/auth/middleware/bearerAuth.ts | 146 +- src/server/auth/middleware/clientAuth.test.ts | 260 +- src/server/auth/middleware/clientAuth.ts | 116 +- src/server/auth/provider.ts | 123 +- .../auth/providers/proxyProvider.test.ts | 670 +- src/server/auth/providers/proxyProvider.ts | 411 +- src/server/auth/router.test.ts | 831 +- src/server/auth/router.ts | 320 +- src/server/auth/types.ts | 54 +- src/server/completable.test.ts | 66 +- src/server/completable.ts | 133 +- src/server/index.test.ts | 1708 ++-- src/server/index.ts | 656 +- src/server/mcp.test.ts | 8152 ++++++++--------- src/server/mcp.ts | 2169 ++--- src/server/sse.test.ts | 1248 ++- src/server/sse.ts | 370 +- src/server/stdio.test.ts | 158 +- src/server/stdio.ts | 142 +- src/server/streamableHttp.test.ts | 4034 ++++---- src/server/streamableHttp.ts | 1374 +-- src/server/title.test.ts | 453 +- src/shared/auth-utils.test.ts | 123 +- src/shared/auth-utils.ts | 61 +- src/shared/auth.test.ts | 201 +- src/shared/auth.ts | 312 +- src/shared/metadataUtils.ts | 26 +- .../protocol-transport-handling.test.ts | 354 +- src/shared/protocol.test.ts | 1382 +-- src/shared/protocol.ts | 1220 ++- src/shared/stdio.test.ts | 44 +- src/shared/stdio.ts | 44 +- src/shared/transport.ts | 128 +- src/shared/uriTemplate.test.ts | 560 +- src/shared/uriTemplate.ts | 525 +- src/spec.types.test.ts | 913 +- src/types.test.ts | 229 +- src/types.ts | 1924 ++-- tsconfig.cjs.json | 14 +- tsconfig.json | 42 +- tsconfig.prod.json | 10 +- 102 files changed, 31779 insertions(+), 33201 deletions(-) create mode 100644 .prettierignore create mode 100644 .prettierrc.json diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 04ba17c90..5a7a84b35 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,51 +1,51 @@ on: - push: - branches: - - main - pull_request: - release: - types: [published] + push: + branches: + - main + pull_request: + release: + types: [published] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: - build: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-node@v4 - with: - node-version: 18 - cache: npm - - - run: npm ci - - run: npm run build - - run: npm test - - run: npm run lint - - publish: - runs-on: ubuntu-latest - if: github.event_name == 'release' - environment: release - needs: build - - permissions: - contents: read - id-token: write - - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-node@v4 - with: - node-version: 18 - cache: npm - registry-url: 'https://registry.npmjs.org' - - - run: npm ci - - - run: npm publish --provenance --access public - env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: 18 + cache: npm + + - run: npm ci + - run: npm run build + - run: npm test + - run: npm run lint + + publish: + runs-on: ubuntu-latest + if: github.event_name == 'release' + environment: release + needs: build + + permissions: + contents: read + id-token: write + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: 18 + cache: npm + registry-url: 'https://registry.npmjs.org' + + - run: npm ci + + - run: npm publish --provenance --access public + env: + NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 000000000..bacf96e9e --- /dev/null +++ b/.prettierignore @@ -0,0 +1,10 @@ +# Ignore artifacts: +build +dist +coverage +*-lock.* +node_modules +**/build +**/dist +.github/CODEOWNERS +pnpm-lock.yaml \ No newline at end of file diff --git a/.prettierrc.json b/.prettierrc.json new file mode 100644 index 000000000..840a2c6b0 --- /dev/null +++ b/.prettierrc.json @@ -0,0 +1,20 @@ +{ + "printWidth": 140, + "tabWidth": 4, + "useTabs": false, + "semi": true, + "singleQuote": true, + "trailingComma": "none", + "bracketSpacing": true, + "bracketSameLine": false, + "proseWrap": "always", + "arrowParens": "avoid", + "overrides": [ + { + "files": "**/*.md", + "options": { + "printWidth": 280 + } + } + ] +} diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 985a28566..62c701add 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -2,127 +2,82 @@ ## Our Pledge -We as members, contributors, and leaders pledge to make participation in our -community a harassment-free experience for everyone, regardless of age, body -size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, -nationality, personal appearance, race, religion, or sexual identity -and orientation. +We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, +education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. -We pledge to act and interact in ways that contribute to an open, welcoming, -diverse, inclusive, and healthy community. +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards -Examples of behavior that contributes to a positive environment for our -community include: +Examples of behavior that contributes to a positive environment for our community include: -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, - and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting +- The use of sexualized language or imagery, and sexual attention or advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without their explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities -Community leaders are responsible for clarifying and enforcing our standards of -acceptable behavior and will take appropriate and fair corrective action in -response to any behavior that they deem inappropriate, threatening, offensive, -or harmful. +Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. -Community leaders have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions that are -not aligned to this Code of Conduct, and will communicate reasons for moderation -decisions when appropriate. +Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope -This Code of Conduct applies within all community spaces, and also applies when -an individual is officially representing the community in public spaces. -Examples of representing our community include using an official e-mail address, -posting via an official social media account, or acting as an appointed -representative at an online or offline event. +This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, +or acting as an appointed representative at an online or offline event. ## Enforcement -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported to the community leaders responsible for enforcement at -. -All complaints will be reviewed and investigated promptly and fairly. +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at . All complaints will be reviewed and investigated promptly and fairly. -All community leaders are obligated to respect the privacy and security of the -reporter of any incident. +All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines -Community leaders will follow these Community Impact Guidelines in determining -the consequences for any action they deem in violation of this Code of Conduct: +Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction -**Community Impact**: Use of inappropriate language or other behavior deemed -unprofessional or unwelcome in the community. +**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. -**Consequence**: A private, written warning from community leaders, providing -clarity around the nature of the violation and an explanation of why the -behavior was inappropriate. A public apology may be requested. +**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning -**Community Impact**: A violation through a single incident or series -of actions. +**Community Impact**: A violation through a single incident or series of actions. -**Consequence**: A warning with consequences for continued behavior. No -interaction with the people involved, including unsolicited interaction with -those enforcing the Code of Conduct, for a specified period of time. This -includes avoiding interactions in community spaces as well as external channels -like social media. Violating these terms may lead to a temporary or -permanent ban. +**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as +well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban -**Community Impact**: A serious violation of community standards, including -sustained inappropriate behavior. +**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. -**Consequence**: A temporary ban from any sort of interaction or public -communication with the community for a specified period of time. No public or -private interaction with the people involved, including unsolicited interaction -with those enforcing the Code of Conduct, is allowed during this period. -Violating these terms may lead to a permanent ban. +**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is +allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban -**Community Impact**: Demonstrating a pattern of violation of community -standards, including sustained inappropriate behavior, harassment of an -individual, or aggression toward or disparagement of classes of individuals. +**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. -**Consequence**: A permanent ban from any sort of public interaction within -the community. +**Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution -This Code of Conduct is adapted from the [Contributor Covenant][homepage], -version 2.0, available at -. +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at . -Community Impact Guidelines were inspired by [Mozilla's code of conduct -enforcement ladder](https://github.com/mozilla/diversity). +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org -For answers to common questions about this code of conduct, see the FAQ at -. Translations are available at -. +For answers to common questions about this code of conduct, see the FAQ at . Translations are available at . diff --git a/README.md b/README.md index cee7eb855..1fcdc1692 100644 --- a/README.md +++ b/README.md @@ -7,25 +7,25 @@ - [Quickstart](#quick-start) - [What is MCP?](#what-is-mcp) - [Core Concepts](#core-concepts) - - [Server](#server) - - [Resources](#resources) - - [Tools](#tools) - - [Prompts](#prompts) - - [Completions](#completions) - - [Sampling](#sampling) + - [Server](#server) + - [Resources](#resources) + - [Tools](#tools) + - [Prompts](#prompts) + - [Completions](#completions) + - [Sampling](#sampling) - [Running Your Server](#running-your-server) - - [stdio](#stdio) - - [Streamable HTTP](#streamable-http) - - [Testing and Debugging](#testing-and-debugging) + - [stdio](#stdio) + - [Streamable HTTP](#streamable-http) + - [Testing and Debugging](#testing-and-debugging) - [Examples](#examples) - - [Echo Server](#echo-server) - - [SQLite Explorer](#sqlite-explorer) + - [Echo Server](#echo-server) + - [SQLite Explorer](#sqlite-explorer) - [Advanced Usage](#advanced-usage) - - [Dynamic Servers](#dynamic-servers) - - [Low-Level Server](#low-level-server) - - [Writing MCP Clients](#writing-mcp-clients) - - [Proxy Authorization Requests Upstream](#proxy-authorization-requests-upstream) - - [Backwards Compatibility](#backwards-compatibility) + - [Dynamic Servers](#dynamic-servers) + - [Low-Level Server](#low-level-server) + - [Writing MCP Clients](#writing-mcp-clients) + - [Proxy Authorization Requests Upstream](#proxy-authorization-requests-upstream) + - [Backwards Compatibility](#backwards-compatibility) - [Documentation](#documentation) - [Contributing](#contributing) - [License](#license) @@ -52,42 +52,45 @@ npm install @modelcontextprotocol/sdk Let's create a simple MCP server that exposes a calculator tool and some data: ```typescript -import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { z } from "zod"; +import { McpServer, ResourceTemplate } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { z } from 'zod'; // Create an MCP server const server = new McpServer({ - name: "demo-server", - version: "1.0.0" + name: 'demo-server', + version: '1.0.0' }); // Add an addition tool -server.registerTool("add", - { - title: "Addition Tool", - description: "Add two numbers", - inputSchema: { a: z.number(), b: z.number() } - }, - async ({ a, b }) => ({ - content: [{ type: "text", text: String(a + b) }] - }) +server.registerTool( + 'add', + { + title: 'Addition Tool', + description: 'Add two numbers', + inputSchema: { a: z.number(), b: z.number() } + }, + async ({ a, b }) => ({ + content: [{ type: 'text', text: String(a + b) }] + }) ); // Add a dynamic greeting resource server.registerResource( - "greeting", - new ResourceTemplate("greeting://{name}", { list: undefined }), - { - title: "Greeting Resource", // Display name for UI - description: "Dynamic greeting generator" - }, - async (uri, { name }) => ({ - contents: [{ - uri: uri.href, - text: `Hello, ${name}!` - }] - }) + 'greeting', + new ResourceTemplate('greeting://{name}', { list: undefined }), + { + title: 'Greeting Resource', // Display name for UI + description: 'Dynamic greeting generator' + }, + async (uri, { name }) => ({ + contents: [ + { + uri: uri.href, + text: `Hello, ${name}!` + } + ] + }) ); // Start receiving messages on stdin and sending messages on stdout @@ -112,8 +115,8 @@ The McpServer is your core interface to the MCP protocol. It handles connection ```typescript const server = new McpServer({ - name: "my-app", - version: "1.0.0" + name: 'my-app', + version: '1.0.0' }); ``` @@ -124,62 +127,68 @@ Resources are how you expose data to LLMs. They're similar to GET endpoints in a ```typescript // Static resource server.registerResource( - "config", - "config://app", - { - title: "Application Config", - description: "Application configuration data", - mimeType: "text/plain" - }, - async (uri) => ({ - contents: [{ - uri: uri.href, - text: "App configuration here" - }] - }) + 'config', + 'config://app', + { + title: 'Application Config', + description: 'Application configuration data', + mimeType: 'text/plain' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'App configuration here' + } + ] + }) ); // Dynamic resource with parameters server.registerResource( - "user-profile", - new ResourceTemplate("users://{userId}/profile", { list: undefined }), - { - title: "User Profile", - description: "User profile information" - }, - async (uri, { userId }) => ({ - contents: [{ - uri: uri.href, - text: `Profile data for user ${userId}` - }] - }) + 'user-profile', + new ResourceTemplate('users://{userId}/profile', { list: undefined }), + { + title: 'User Profile', + description: 'User profile information' + }, + async (uri, { userId }) => ({ + contents: [ + { + uri: uri.href, + text: `Profile data for user ${userId}` + } + ] + }) ); // Resource with context-aware completion server.registerResource( - "repository", - new ResourceTemplate("github://repos/{owner}/{repo}", { - list: undefined, - complete: { - // Provide intelligent completions based on previously resolved parameters - repo: (value, context) => { - if (context?.arguments?.["owner"] === "org1") { - return ["project1", "project2", "project3"].filter(r => r.startsWith(value)); + 'repository', + new ResourceTemplate('github://repos/{owner}/{repo}', { + list: undefined, + complete: { + // Provide intelligent completions based on previously resolved parameters + repo: (value, context) => { + if (context?.arguments?.['owner'] === 'org1') { + return ['project1', 'project2', 'project3'].filter(r => r.startsWith(value)); + } + return ['default-repo'].filter(r => r.startsWith(value)); + } } - return ["default-repo"].filter(r => r.startsWith(value)); - } - } - }), - { - title: "GitHub Repository", - description: "Repository information" - }, - async (uri, { owner, repo }) => ({ - contents: [{ - uri: uri.href, - text: `Repository: ${owner}/${repo}` - }] - }) + }), + { + title: 'GitHub Repository', + description: 'Repository information' + }, + async (uri, { owner, repo }) => ({ + contents: [ + { + uri: uri.href, + text: `Repository: ${owner}/${repo}` + } + ] + }) ); ``` @@ -190,68 +199,70 @@ Tools let LLMs take actions through your server. Unlike resources, tools are exp ```typescript // Simple tool with parameters server.registerTool( - "calculate-bmi", - { - title: "BMI Calculator", - description: "Calculate Body Mass Index", - inputSchema: { - weightKg: z.number(), - heightM: z.number() - } - }, - async ({ weightKg, heightM }) => ({ - content: [{ - type: "text", - text: String(weightKg / (heightM * heightM)) - }] - }) + 'calculate-bmi', + { + title: 'BMI Calculator', + description: 'Calculate Body Mass Index', + inputSchema: { + weightKg: z.number(), + heightM: z.number() + } + }, + async ({ weightKg, heightM }) => ({ + content: [ + { + type: 'text', + text: String(weightKg / (heightM * heightM)) + } + ] + }) ); // Async tool with external API call server.registerTool( - "fetch-weather", - { - title: "Weather Fetcher", - description: "Get weather data for a city", - inputSchema: { city: z.string() } - }, - async ({ city }) => { - const response = await fetch(`https://api.weather.com/${city}`); - const data = await response.text(); - return { - content: [{ type: "text", text: data }] - }; - } + 'fetch-weather', + { + title: 'Weather Fetcher', + description: 'Get weather data for a city', + inputSchema: { city: z.string() } + }, + async ({ city }) => { + const response = await fetch(`https://api.weather.com/${city}`); + const data = await response.text(); + return { + content: [{ type: 'text', text: data }] + }; + } ); // Tool that returns ResourceLinks server.registerTool( - "list-files", - { - title: "List Files", - description: "List project files", - inputSchema: { pattern: z.string() } - }, - async ({ pattern }) => ({ - content: [ - { type: "text", text: `Found files matching "${pattern}":` }, - // ResourceLinks let tools return references without file content - { - type: "resource_link", - uri: "file:///project/README.md", - name: "README.md", - mimeType: "text/markdown", - description: 'A README file' - }, - { - type: "resource_link", - uri: "file:///project/src/index.ts", - name: "index.ts", - mimeType: "text/typescript", - description: 'An index file' - } - ] - }) + 'list-files', + { + title: 'List Files', + description: 'List project files', + inputSchema: { pattern: z.string() } + }, + async ({ pattern }) => ({ + content: [ + { type: 'text', text: `Found files matching "${pattern}":` }, + // ResourceLinks let tools return references without file content + { + type: 'resource_link', + uri: 'file:///project/README.md', + name: 'README.md', + mimeType: 'text/markdown', + description: 'A README file' + }, + { + type: 'resource_link', + uri: 'file:///project/src/index.ts', + name: 'index.ts', + mimeType: 'text/typescript', + description: 'An index file' + } + ] + }) ); ``` @@ -264,60 +275,64 @@ Tools can return `ResourceLink` objects to reference resources without embedding Prompts are reusable templates that help LLMs interact with your server effectively: ```typescript -import { completable } from "@modelcontextprotocol/sdk/server/completable.js"; +import { completable } from '@modelcontextprotocol/sdk/server/completable.js'; server.registerPrompt( - "review-code", - { - title: "Code Review", - description: "Review code for best practices and potential issues", - argsSchema: { code: z.string() } - }, - ({ code }) => ({ - messages: [{ - role: "user", - content: { - type: "text", - text: `Please review this code:\n\n${code}` - } - }] - }) + 'review-code', + { + title: 'Code Review', + description: 'Review code for best practices and potential issues', + argsSchema: { code: z.string() } + }, + ({ code }) => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please review this code:\n\n${code}` + } + } + ] + }) ); // Prompt with context-aware completion server.registerPrompt( - "team-greeting", - { - title: "Team Greeting", - description: "Generate a greeting for team members", - argsSchema: { - department: completable(z.string(), (value) => { - // Department suggestions - return ["engineering", "sales", "marketing", "support"].filter(d => d.startsWith(value)); - }), - name: completable(z.string(), (value, context) => { - // Name suggestions based on selected department - const department = context?.arguments?.["department"]; - if (department === "engineering") { - return ["Alice", "Bob", "Charlie"].filter(n => n.startsWith(value)); - } else if (department === "sales") { - return ["David", "Eve", "Frank"].filter(n => n.startsWith(value)); - } else if (department === "marketing") { - return ["Grace", "Henry", "Iris"].filter(n => n.startsWith(value)); + 'team-greeting', + { + title: 'Team Greeting', + description: 'Generate a greeting for team members', + argsSchema: { + department: completable(z.string(), value => { + // Department suggestions + return ['engineering', 'sales', 'marketing', 'support'].filter(d => d.startsWith(value)); + }), + name: completable(z.string(), (value, context) => { + // Name suggestions based on selected department + const department = context?.arguments?.['department']; + if (department === 'engineering') { + return ['Alice', 'Bob', 'Charlie'].filter(n => n.startsWith(value)); + } else if (department === 'sales') { + return ['David', 'Eve', 'Frank'].filter(n => n.startsWith(value)); + } else if (department === 'marketing') { + return ['Grace', 'Henry', 'Iris'].filter(n => n.startsWith(value)); + } + return ['Guest'].filter(n => n.startsWith(value)); + }) } - return ["Guest"].filter(n => n.startsWith(value)); - }) - } - }, - ({ department, name }) => ({ - messages: [{ - role: "assistant", - content: { - type: "text", - text: `Hello ${name}, welcome to the ${department} team!` - } - }] - }) + }, + ({ department, name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } + } + ] + }) ); ``` @@ -330,21 +345,21 @@ MCP supports argument completions to help users fill in prompt arguments and res ```typescript // Request completions for any argument const result = await client.complete({ - ref: { - type: "ref/prompt", // or "ref/resource" - name: "example" // or uri: "template://..." - }, - argument: { - name: "argumentName", - value: "partial" // What the user has typed so far - }, - context: { // Optional: Include previously resolved arguments - arguments: { - previousArg: "value" + ref: { + type: 'ref/prompt', // or "ref/resource" + name: 'example' // or uri: "template://..." + }, + argument: { + name: 'argumentName', + value: 'partial' // What the user has typed so far + }, + context: { + // Optional: Include previously resolved arguments + arguments: { + previousArg: 'value' + } } - } }); - ``` ### Display Names and Metadata @@ -356,6 +371,7 @@ All resources, tools, and prompts support an optional `title` field for better U #### Title Precedence for Tools For tools specifically, there are two ways to specify a title: + - `title` field in the tool configuration - `annotations.title` field (when using the older `tool()` method with annotations) @@ -363,23 +379,32 @@ The precedence order is: `title` → `annotations.title` → `name` ```typescript // Using registerTool (recommended) -server.registerTool("my_tool", { - title: "My Tool", // This title takes precedence - annotations: { - title: "Annotation Title" // This is ignored if title is set - } -}, handler); +server.registerTool( + 'my_tool', + { + title: 'My Tool', // This title takes precedence + annotations: { + title: 'Annotation Title' // This is ignored if title is set + } + }, + handler +); // Using tool with annotations (older API) -server.tool("my_tool", "description", { - title: "Annotation Title" // This is used as title -}, handler); +server.tool( + 'my_tool', + 'description', + { + title: 'Annotation Title' // This is used as title + }, + handler +); ``` When building clients, use the provided utility to get the appropriate display name: ```typescript -import { getDisplayName } from "@modelcontextprotocol/sdk/shared/metadataUtils.js"; +import { getDisplayName } from '@modelcontextprotocol/sdk/shared/metadataUtils.js'; // Automatically handles the precedence: title → annotations.title → name const displayName = getDisplayName(tool); @@ -390,63 +415,62 @@ const displayName = getDisplayName(tool); MCP servers can request LLM completions from connected clients that support sampling. ```typescript -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { z } from "zod"; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { z } from 'zod'; const mcpServer = new McpServer({ - name: "tools-with-sample-server", - version: "1.0.0", + name: 'tools-with-sample-server', + version: '1.0.0' }); // Tool that uses LLM sampling to summarize any text mcpServer.registerTool( - "summarize", - { - description: "Summarize any text using an LLM", - inputSchema: { - text: z.string().describe("Text to summarize"), + 'summarize', + { + description: 'Summarize any text using an LLM', + inputSchema: { + text: z.string().describe('Text to summarize') + } }, - }, - async ({ text }) => { - // Call the LLM through MCP sampling - const response = await mcpServer.server.createMessage({ - messages: [ - { - role: "user", - content: { - type: "text", - text: `Please summarize the following text concisely:\n\n${text}`, - }, - }, - ], - maxTokens: 500, - }); + async ({ text }) => { + // Call the LLM through MCP sampling + const response = await mcpServer.server.createMessage({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please summarize the following text concisely:\n\n${text}` + } + } + ], + maxTokens: 500 + }); - return { - content: [ - { - type: "text", - text: response.content.type === "text" ? response.content.text : "Unable to generate summary", - }, - ], - }; - } + return { + content: [ + { + type: 'text', + text: response.content.type === 'text' ? response.content.text : 'Unable to generate summary' + } + ] + }; + } ); async function main() { - const transport = new StdioServerTransport(); - await mcpServer.connect(transport); - console.log("MCP server is running..."); + const transport = new StdioServerTransport(); + await mcpServer.connect(transport); + console.log('MCP server is running...'); } -main().catch((error) => { - console.error("Server error:", error); - process.exit(1); +main().catch(error => { + console.error('Server error:', error); + process.exit(1); }); ``` - ## Running Your Server MCP servers in TypeScript need to be connected to a transport to communicate with clients. How you start the server depends on the choice of transport: @@ -456,12 +480,12 @@ MCP servers in TypeScript need to be connected to a transport to communicate wit For command-line tools and direct integrations: ```typescript -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; const server = new McpServer({ - name: "example-server", - version: "1.0.0" + name: 'example-server', + version: '1.0.0' }); // ... set up server resources, tools, and prompts ... @@ -479,13 +503,11 @@ For remote servers, set up a Streamable HTTP transport that handles both client In some cases, servers need to be stateful. This is achieved by [session management](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#session-management). ```typescript -import express from "express"; -import { randomUUID } from "node:crypto"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js" - - +import express from 'express'; +import { randomUUID } from 'node:crypto'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; const app = express(); app.use(express.json()); @@ -495,69 +517,69 @@ const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; // Handle POST requests for client-to-server communication app.post('/mcp', async (req, res) => { - // Check for existing session ID - const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; - - if (sessionId && transports[sessionId]) { - // Reuse existing transport - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { - // New initialization request - transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: (sessionId) => { - // Store the transport by session ID - transports[sessionId] = transport; - }, - // DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server - // locally, make sure to set: - // enableDnsRebindingProtection: true, - // allowedHosts: ['127.0.0.1'], - }); - - // Clean up transport when closed - transport.onclose = () => { - if (transport.sessionId) { - delete transports[transport.sessionId]; - } - }; - const server = new McpServer({ - name: "example-server", - version: "1.0.0" - }); + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined; + let transport: StreamableHTTPServerTransport; + + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: sessionId => { + // Store the transport by session ID + transports[sessionId] = transport; + } + // DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server + // locally, make sure to set: + // enableDnsRebindingProtection: true, + // allowedHosts: ['127.0.0.1'], + }); + + // Clean up transport when closed + transport.onclose = () => { + if (transport.sessionId) { + delete transports[transport.sessionId]; + } + }; + const server = new McpServer({ + name: 'example-server', + version: '1.0.0' + }); - // ... set up server resources, tools, and prompts ... + // ... set up server resources, tools, and prompts ... - // Connect to the MCP server - await server.connect(transport); - } else { - // Invalid request - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: No valid session ID provided', - }, - id: null, - }); - return; - } + // Connect to the MCP server + await server.connect(transport); + } else { + // Invalid request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + return; + } - // Handle the request - await transport.handleRequest(req, res, req.body); + // Handle the request + await transport.handleRequest(req, res, req.body); }); // Reusable handler for GET and DELETE requests const handleSessionRequest = async (req: express.Request, res: express.Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId || !transports[sessionId]) { - res.status(400).send('Invalid or missing session ID'); - return; - } - - const transport = transports[sessionId]; - await transport.handleRequest(req, res); + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + const transport = transports[sessionId]; + await transport.handleRequest(req, res); }; // Handle GET requests for server-to-client notifications via SSE @@ -569,9 +591,7 @@ app.delete('/mcp', handleSessionRequest); app.listen(3000); ``` -> [!TIP] -> When using this in a remote environment, make sure to allow the header parameter `mcp-session-id` in CORS. Otherwise, it may result in a `Bad Request: No valid session ID provided` error. Read the following section for examples. - +> [!TIP] When using this in a remote environment, make sure to allow the header parameter `mcp-session-id` in CORS. Otherwise, it may result in a `Bad Request: No valid session ID provided` error. Read the following section for examples. #### CORS Configuration for Browser-Based Clients @@ -581,15 +601,18 @@ If you'd like your server to be accessible by browser-based MCP clients, you'll import cors from 'cors'; // Add CORS middleware before your MCP routes -app.use(cors({ - origin: '*', // Configure appropriately for production, for example: - // origin: ['https://your-remote-domain.com', 'https://your-other-remote-domain.com'], - exposedHeaders: ['Mcp-Session-Id'], - allowedHeaders: ['Content-Type', 'mcp-session-id'], -})); +app.use( + cors({ + origin: '*', // Configure appropriately for production, for example: + // origin: ['https://your-remote-domain.com', 'https://your-other-remote-domain.com'], + exposedHeaders: ['Mcp-Session-Id'], + allowedHeaders: ['Content-Type', 'mcp-session-id'] + }) +); ``` This configuration is necessary because: + - The MCP streamable HTTP transport uses the `Mcp-Session-Id` header for session management - Browsers restrict access to response headers unless explicitly exposed via CORS - Without this configuration, browser-based clients won't be able to read the session ID from initialization responses @@ -603,79 +626,83 @@ const app = express(); app.use(express.json()); app.post('/mcp', async (req: Request, res: Response) => { - // In stateless mode, create a new instance of transport and server for each request - // to ensure complete isolation. A single instance would cause request ID collisions - // when multiple clients connect concurrently. - - try { - const server = getServer(); - const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); - res.on('close', () => { - console.log('Request closed'); - transport.close(); - server.close(); - }); - await server.connect(transport); - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); + // In stateless mode, create a new instance of transport and server for each request + // to ensure complete isolation. A single instance would cause request ID collisions + // when multiple clients connect concurrently. + + try { + const server = getServer(); + const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined + }); + res.on('close', () => { + console.log('Request closed'); + transport.close(); + server.close(); + }); + await server.connect(transport); + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error' + }, + id: null + }); + } } - } }); // SSE notifications not supported in stateless mode app.get('/mcp', async (req: Request, res: Response) => { - console.log('Received GET MCP request'); - res.writeHead(405).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Method not allowed." - }, - id: null - })); + console.log('Received GET MCP request'); + res.writeHead(405).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Method not allowed.' + }, + id: null + }) + ); }); // Session termination not needed in stateless mode app.delete('/mcp', async (req: Request, res: Response) => { - console.log('Received DELETE MCP request'); - res.writeHead(405).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Method not allowed." - }, - id: null - })); + console.log('Received DELETE MCP request'); + res.writeHead(405).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Method not allowed.' + }, + id: null + }) + ); }); - // Start the server const PORT = 3000; -setupServer().then(() => { - app.listen(PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); - }); -}).catch(error => { - console.error('Failed to set up the server:', error); - process.exit(1); -}); - +setupServer() + .then(() => { + app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); + }); + }) + .catch(error => { + console.error('Failed to set up the server:', error); + process.exit(1); + }); ``` This stateless approach is useful for: @@ -711,57 +738,61 @@ To test your server, you can use the [MCP Inspector](https://github.com/modelcon A simple server demonstrating resources, tools, and prompts: ```typescript -import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { z } from "zod"; +import { McpServer, ResourceTemplate } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { z } from 'zod'; const server = new McpServer({ - name: "echo-server", - version: "1.0.0" + name: 'echo-server', + version: '1.0.0' }); server.registerResource( - "echo", - new ResourceTemplate("echo://{message}", { list: undefined }), - { - title: "Echo Resource", - description: "Echoes back messages as resources" - }, - async (uri, { message }) => ({ - contents: [{ - uri: uri.href, - text: `Resource echo: ${message}` - }] - }) + 'echo', + new ResourceTemplate('echo://{message}', { list: undefined }), + { + title: 'Echo Resource', + description: 'Echoes back messages as resources' + }, + async (uri, { message }) => ({ + contents: [ + { + uri: uri.href, + text: `Resource echo: ${message}` + } + ] + }) ); server.registerTool( - "echo", - { - title: "Echo Tool", - description: "Echoes back the provided message", - inputSchema: { message: z.string() } - }, - async ({ message }) => ({ - content: [{ type: "text", text: `Tool echo: ${message}` }] - }) + 'echo', + { + title: 'Echo Tool', + description: 'Echoes back the provided message', + inputSchema: { message: z.string() } + }, + async ({ message }) => ({ + content: [{ type: 'text', text: `Tool echo: ${message}` }] + }) ); server.registerPrompt( - "echo", - { - title: "Echo Prompt", - description: "Creates a prompt to process a message", - argsSchema: { message: z.string() } - }, - ({ message }) => ({ - messages: [{ - role: "user", - content: { - type: "text", - text: `Please process this message: ${message}` - } - }] - }) + 'echo', + { + title: 'Echo Prompt', + description: 'Creates a prompt to process a message', + argsSchema: { message: z.string() } + }, + ({ message }) => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please process this message: ${message}` + } + } + ] + }) ); ``` @@ -770,81 +801,85 @@ server.registerPrompt( A more complex example showing database integration: ```typescript -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import sqlite3 from "sqlite3"; -import { promisify } from "util"; -import { z } from "zod"; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import sqlite3 from 'sqlite3'; +import { promisify } from 'util'; +import { z } from 'zod'; const server = new McpServer({ - name: "sqlite-explorer", - version: "1.0.0" + name: 'sqlite-explorer', + version: '1.0.0' }); // Helper to create DB connection const getDb = () => { - const db = new sqlite3.Database("database.db"); - return { - all: promisify(db.all.bind(db)), - close: promisify(db.close.bind(db)) - }; + const db = new sqlite3.Database('database.db'); + return { + all: promisify(db.all.bind(db)), + close: promisify(db.close.bind(db)) + }; }; server.registerResource( - "schema", - "schema://main", - { - title: "Database Schema", - description: "SQLite database schema", - mimeType: "text/plain" - }, - async (uri) => { - const db = getDb(); - try { - const tables = await db.all( - "SELECT sql FROM sqlite_master WHERE type='table'" - ); - return { - contents: [{ - uri: uri.href, - text: tables.map((t: {sql: string}) => t.sql).join("\n") - }] - }; - } finally { - await db.close(); + 'schema', + 'schema://main', + { + title: 'Database Schema', + description: 'SQLite database schema', + mimeType: 'text/plain' + }, + async uri => { + const db = getDb(); + try { + const tables = await db.all("SELECT sql FROM sqlite_master WHERE type='table'"); + return { + contents: [ + { + uri: uri.href, + text: tables.map((t: { sql: string }) => t.sql).join('\n') + } + ] + }; + } finally { + await db.close(); + } } - } ); server.registerTool( - "query", - { - title: "SQL Query", - description: "Execute SQL queries on the database", - inputSchema: { sql: z.string() } - }, - async ({ sql }) => { - const db = getDb(); - try { - const results = await db.all(sql); - return { - content: [{ - type: "text", - text: JSON.stringify(results, null, 2) - }] - }; - } catch (err: unknown) { - const error = err as Error; - return { - content: [{ - type: "text", - text: `Error: ${error.message}` - }], - isError: true - }; - } finally { - await db.close(); + 'query', + { + title: 'SQL Query', + description: 'Execute SQL queries on the database', + inputSchema: { sql: z.string() } + }, + async ({ sql }) => { + const db = getDb(); + try { + const results = await db.all(sql); + return { + content: [ + { + type: 'text', + text: JSON.stringify(results, null, 2) + } + ] + }; + } catch (err: unknown) { + const error = err as Error; + return { + content: [ + { + type: 'text', + text: `Error: ${error.message}` + } + ], + isError: true + }; + } finally { + await db.close(); + } } - } ); ``` @@ -855,57 +890,49 @@ server.registerTool( If you want to offer an initial set of tools/prompts/resources, but later add additional ones based on user action or external state change, you can add/update/remove them _after_ the Server is connected. This will automatically emit the corresponding `listChanged` notifications: ```ts -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { z } from "zod"; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { z } from 'zod'; const server = new McpServer({ - name: "Dynamic Example", - version: "1.0.0" + name: 'Dynamic Example', + version: '1.0.0' }); -const listMessageTool = server.tool( - "listMessages", - { channel: z.string() }, - async ({ channel }) => ({ - content: [{ type: "text", text: await listMessages(channel) }] - }) -); +const listMessageTool = server.tool('listMessages', { channel: z.string() }, async ({ channel }) => ({ + content: [{ type: 'text', text: await listMessages(channel) }] +})); -const putMessageTool = server.tool( - "putMessage", - { channel: z.string(), message: z.string() }, - async ({ channel, message }) => ({ - content: [{ type: "text", text: await putMessage(channel, message) }] - }) -); +const putMessageTool = server.tool('putMessage', { channel: z.string(), message: z.string() }, async ({ channel, message }) => ({ + content: [{ type: 'text', text: await putMessage(channel, message) }] +})); // Until we upgrade auth, `putMessage` is disabled (won't show up in listTools) -putMessageTool.disable() +putMessageTool.disable(); const upgradeAuthTool = server.tool( - "upgradeAuth", - { permission: z.enum(["write", "admin"])}, - // Any mutations here will automatically emit `listChanged` notifications - async ({ permission }) => { - const { ok, err, previous } = await upgradeAuthAndStoreToken(permission) - if (!ok) return {content: [{ type: "text", text: `Error: ${err}` }]} - - // If we previously had read-only access, 'putMessage' is now available - if (previous === "read") { - putMessageTool.enable() - } + 'upgradeAuth', + { permission: z.enum(['write', 'admin']) }, + // Any mutations here will automatically emit `listChanged` notifications + async ({ permission }) => { + const { ok, err, previous } = await upgradeAuthAndStoreToken(permission); + if (!ok) return { content: [{ type: 'text', text: `Error: ${err}` }] }; + + // If we previously had read-only access, 'putMessage' is now available + if (previous === 'read') { + putMessageTool.enable(); + } - if (permission === 'write') { - // If we've just upgraded to 'write' permissions, we can still call 'upgradeAuth' - // but can only upgrade to 'admin'. - upgradeAuthTool.update({ - paramsSchema: { permission: z.enum(["admin"]) }, // change validation rules - }) - } else { - // If we're now an admin, we no longer have anywhere to upgrade to, so fully remove that tool - upgradeAuthTool.remove() + if (permission === 'write') { + // If we've just upgraded to 'write' permissions, we can still call 'upgradeAuth' + // but can only upgrade to 'admin'. + upgradeAuthTool.update({ + paramsSchema: { permission: z.enum(['admin']) } // change validation rules + }); + } else { + // If we're now an admin, we no longer have anywhere to upgrade to, so fully remove that tool + upgradeAuthTool.remove(); + } } - } -) +); // Connect as normal const transport = new StdioServerTransport(); @@ -918,8 +945,8 @@ When performing bulk updates that trigger notifications (e.g., enabling or disab This feature coalesces multiple, rapid calls for the same notification type into a single message. For example, if you disable five tools in a row, only one `notifications/tools/list_changed` message will be sent instead of five. -> [!IMPORTANT] -> This feature is designed for "simple" notifications that do not carry unique data in their parameters. To prevent silent data loss, debouncing is **automatically bypassed** for any notification that contains a `params` object or a `relatedRequestId`. Such notifications will always be sent immediately. +> [!IMPORTANT] This feature is designed for "simple" notifications that do not carry unique data in their parameters. To prevent silent data loss, debouncing is **automatically bypassed** for any notification that contains a `params` object or a `relatedRequestId`. Such +> notifications will always be sent immediately. This is an opt-in feature configured during server initialization. @@ -954,53 +981,56 @@ server.registerTool("tool3", ...).disable(); For more control, you can use the low-level Server class directly: ```typescript -import { Server } from "@modelcontextprotocol/sdk/server/index.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { - ListPromptsRequestSchema, - GetPromptRequestSchema -} from "@modelcontextprotocol/sdk/types.js"; +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { ListPromptsRequestSchema, GetPromptRequestSchema } from '@modelcontextprotocol/sdk/types.js'; const server = new Server( - { - name: "example-server", - version: "1.0.0" - }, - { - capabilities: { - prompts: {} + { + name: 'example-server', + version: '1.0.0' + }, + { + capabilities: { + prompts: {} + } } - } ); server.setRequestHandler(ListPromptsRequestSchema, async () => { - return { - prompts: [{ - name: "example-prompt", - description: "An example prompt template", - arguments: [{ - name: "arg1", - description: "Example argument", - required: true - }] - }] - }; + return { + prompts: [ + { + name: 'example-prompt', + description: 'An example prompt template', + arguments: [ + { + name: 'arg1', + description: 'Example argument', + required: true + } + ] + } + ] + }; }); -server.setRequestHandler(GetPromptRequestSchema, async (request) => { - if (request.params.name !== "example-prompt") { - throw new Error("Unknown prompt"); - } - return { - description: "Example prompt", - messages: [{ - role: "user", - content: { - type: "text", - text: "Example prompt text" - } - }] - }; +server.setRequestHandler(GetPromptRequestSchema, async request => { + if (request.params.name !== 'example-prompt') { + throw new Error('Unknown prompt'); + } + return { + description: 'Example prompt', + messages: [ + { + role: 'user', + content: { + type: 'text', + text: 'Example prompt text' + } + } + ] + }; }); const transport = new StdioServerTransport(); @@ -1014,72 +1044,73 @@ MCP servers can request additional information from users through the elicitatio ```typescript // Server-side: Restaurant booking tool that asks for alternatives server.tool( - "book-restaurant", - { - restaurant: z.string(), - date: z.string(), - partySize: z.number() - }, - async ({ restaurant, date, partySize }) => { - // Check availability - const available = await checkAvailability(restaurant, date, partySize); - - if (!available) { - // Ask user if they want to try alternative dates - const result = await server.server.elicitInput({ - message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, - requestedSchema: { - type: "object", - properties: { - checkAlternatives: { - type: "boolean", - title: "Check alternative dates", - description: "Would you like me to check other dates?" - }, - flexibleDates: { - type: "string", - title: "Date flexibility", - description: "How flexible are your dates?", - enum: ["next_day", "same_week", "next_week"], - enumNames: ["Next day", "Same week", "Next week"] + 'book-restaurant', + { + restaurant: z.string(), + date: z.string(), + partySize: z.number() + }, + async ({ restaurant, date, partySize }) => { + // Check availability + const available = await checkAvailability(restaurant, date, partySize); + + if (!available) { + // Ask user if they want to try alternative dates + const result = await server.server.elicitInput({ + message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, + requestedSchema: { + type: 'object', + properties: { + checkAlternatives: { + type: 'boolean', + title: 'Check alternative dates', + description: 'Would you like me to check other dates?' + }, + flexibleDates: { + type: 'string', + title: 'Date flexibility', + description: 'How flexible are your dates?', + enum: ['next_day', 'same_week', 'next_week'], + enumNames: ['Next day', 'Same week', 'Next week'] + } + }, + required: ['checkAlternatives'] + } + }); + + if (result.action === 'accept' && result.content?.checkAlternatives) { + const alternatives = await findAlternatives(restaurant, date, partySize, result.content.flexibleDates as string); + return { + content: [ + { + type: 'text', + text: `Found these alternatives: ${alternatives.join(', ')}` + } + ] + }; } - }, - required: ["checkAlternatives"] + + return { + content: [ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ] + }; } - }); - - if (result.action === "accept" && result.content?.checkAlternatives) { - const alternatives = await findAlternatives( - restaurant, - date, - partySize, - result.content.flexibleDates as string - ); + + // Book the table + await makeBooking(restaurant, date, partySize); return { - content: [{ - type: "text", - text: `Found these alternatives: ${alternatives.join(", ")}` - }] + content: [ + { + type: 'text', + text: `Booked table for ${partySize} at ${restaurant} on ${date}` + } + ] }; - } - - return { - content: [{ - type: "text", - text: "No booking made. Original date not available." - }] - }; } - - // Book the table - await makeBooking(restaurant, date, partySize); - return { - content: [{ - type: "text", - text: `Booked table for ${partySize} at ${restaurant} on ${date}` - }] - }; - } ); ``` @@ -1087,24 +1118,24 @@ Client-side: Handle elicitation requests ```typescript // This is a placeholder - implement based on your UI framework -async function getInputFromUser(message: string, schema: any): Promise<{ - action: "accept" | "decline" | "cancel"; - data?: Record; +async function getInputFromUser( + message: string, + schema: any +): Promise<{ + action: 'accept' | 'decline' | 'cancel'; + data?: Record; }> { - // This should be implemented depending on the app - throw new Error("getInputFromUser must be implemented for your platform"); + // This should be implemented depending on the app + throw new Error('getInputFromUser must be implemented for your platform'); } -client.setRequestHandler(ElicitRequestSchema, async (request) => { - const userResponse = await getInputFromUser( - request.params.message, - request.params.requestedSchema - ); - - return { - action: userResponse.action, - content: userResponse.action === "accept" ? userResponse.data : undefined - }; +client.setRequestHandler(ElicitRequestSchema, async request => { + const userResponse = await getInputFromUser(request.params.message, request.params.requestedSchema); + + return { + action: userResponse.action, + content: userResponse.action === 'accept' ? userResponse.data : undefined + }; }); ``` @@ -1115,20 +1146,18 @@ client.setRequestHandler(ElicitRequestSchema, async (request) => { The SDK provides a high-level client interface: ```typescript -import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; const transport = new StdioClientTransport({ - command: "node", - args: ["server.js"] + command: 'node', + args: ['server.js'] }); -const client = new Client( - { - name: "example-client", - version: "1.0.0" - } -); +const client = new Client({ + name: 'example-client', + version: '1.0.0' +}); await client.connect(transport); @@ -1137,10 +1166,10 @@ const prompts = await client.listPrompts(); // Get a prompt const prompt = await client.getPrompt({ - name: "example-prompt", - arguments: { - arg1: "value" - } + name: 'example-prompt', + arguments: { + arg1: 'value' + } }); // List resources @@ -1148,17 +1177,16 @@ const resources = await client.listResources(); // Read a resource const resource = await client.readResource({ - uri: "file:///example.txt" + uri: 'file:///example.txt' }); // Call a tool const result = await client.callTool({ - name: "example-tool", - arguments: { - arg1: "value" - } + name: 'example-tool', + arguments: { + arg1: 'value' + } }); - ``` ### Proxy Authorization Requests Upstream @@ -1174,31 +1202,33 @@ const app = express(); const proxyProvider = new ProxyOAuthServerProvider({ endpoints: { - authorizationUrl: "https://auth.external.com/oauth2/v1/authorize", - tokenUrl: "https://auth.external.com/oauth2/v1/token", - revocationUrl: "https://auth.external.com/oauth2/v1/revoke", + authorizationUrl: 'https://auth.external.com/oauth2/v1/authorize', + tokenUrl: 'https://auth.external.com/oauth2/v1/token', + revocationUrl: 'https://auth.external.com/oauth2/v1/revoke' }, - verifyAccessToken: async (token) => { + verifyAccessToken: async token => { return { token, - clientId: "123", - scopes: ["openid", "email", "profile"], - } + clientId: '123', + scopes: ['openid', 'email', 'profile'] + }; }, - getClient: async (client_id) => { + getClient: async client_id => { return { client_id, - redirect_uris: ["http://localhost:3000/callback"], - } + redirect_uris: ['http://localhost:3000/callback'] + }; } -}) - -app.use(mcpAuthRouter({ - provider: proxyProvider, - issuerUrl: new URL("http://auth.external.com"), - baseUrl: new URL("http://mcp.example.com"), - serviceDocumentationUrl: new URL("https://docs.example.com/"), -})) +}); + +app.use( + mcpAuthRouter({ + provider: proxyProvider, + issuerUrl: new URL('http://auth.external.com'), + baseUrl: new URL('http://mcp.example.com'), + serviceDocumentationUrl: new URL('https://docs.example.com/') + }) +); ``` This setup allows you to: @@ -1218,31 +1248,29 @@ Clients and servers with StreamableHttp transport can maintain [backwards compat For clients that need to work with both Streamable HTTP and older SSE servers: ```typescript -import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; -let client: Client|undefined = undefined +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +let client: Client | undefined = undefined; const baseUrl = new URL(url); try { - client = new Client({ - name: 'streamable-http-client', - version: '1.0.0' - }); - const transport = new StreamableHTTPClientTransport( - new URL(baseUrl) - ); - await client.connect(transport); - console.log("Connected using Streamable HTTP transport"); + client = new Client({ + name: 'streamable-http-client', + version: '1.0.0' + }); + const transport = new StreamableHTTPClientTransport(new URL(baseUrl)); + await client.connect(transport); + console.log('Connected using Streamable HTTP transport'); } catch (error) { - // If that fails with a 4xx error, try the older SSE transport - console.log("Streamable HTTP connection failed, falling back to SSE transport"); - client = new Client({ - name: 'sse-client', - version: '1.0.0' - }); - const sseTransport = new SSEClientTransport(baseUrl); - await client.connect(sseTransport); - console.log("Connected using SSE transport"); + // If that fails with a 4xx error, try the older SSE transport + console.log('Streamable HTTP connection failed, falling back to SSE transport'); + client = new Client({ + name: 'sse-client', + version: '1.0.0' + }); + const sseTransport = new SSEClientTransport(baseUrl); + await client.connect(sseTransport); + console.log('Connected using SSE transport'); } ``` @@ -1251,14 +1279,14 @@ try { For servers that need to support both Streamable HTTP and older clients: ```typescript -import express from "express"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js"; +import express from 'express'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; const server = new McpServer({ - name: "backwards-compatible-server", - version: "1.0.0" + name: 'backwards-compatible-server', + version: '1.0.0' }); // ... set up server resources, tools, and prompts ... @@ -1268,39 +1296,39 @@ app.use(express.json()); // Store transports for each session type const transports = { - streamable: {} as Record, - sse: {} as Record + streamable: {} as Record, + sse: {} as Record }; // Modern Streamable HTTP endpoint app.all('/mcp', async (req, res) => { - // Handle Streamable HTTP transport for modern clients - // Implementation as shown in the "With Session Management" example - // ... + // Handle Streamable HTTP transport for modern clients + // Implementation as shown in the "With Session Management" example + // ... }); // Legacy SSE endpoint for older clients app.get('/sse', async (req, res) => { - // Create SSE transport for legacy clients - const transport = new SSEServerTransport('/messages', res); - transports.sse[transport.sessionId] = transport; - - res.on("close", () => { - delete transports.sse[transport.sessionId]; - }); - - await server.connect(transport); + // Create SSE transport for legacy clients + const transport = new SSEServerTransport('/messages', res); + transports.sse[transport.sessionId] = transport; + + res.on('close', () => { + delete transports.sse[transport.sessionId]; + }); + + await server.connect(transport); }); // Legacy message endpoint for older clients app.post('/messages', async (req, res) => { - const sessionId = req.query.sessionId as string; - const transport = transports.sse[sessionId]; - if (transport) { - await transport.handlePostMessage(req, res, req.body); - } else { - res.status(400).send('No transport found for sessionId'); - } + const sessionId = req.query.sessionId as string; + const transport = transports.sse[sessionId]; + if (transport) { + await transport.handlePostMessage(req, res, req.body); + } else { + res.status(400).send('No transport found for sessionId'); + } }); app.listen(3000); diff --git a/eslint.config.mjs b/eslint.config.mjs index d792f015f..5849013f3 100644 --- a/eslint.config.mjs +++ b/eslint.config.mjs @@ -2,25 +2,25 @@ import eslint from '@eslint/js'; import tseslint from 'typescript-eslint'; +import eslintConfigPrettier from 'eslint-config-prettier/flat'; export default tseslint.config( eslint.configs.recommended, ...tseslint.configs.recommended, { linterOptions: { - reportUnusedDisableDirectives: false, + reportUnusedDisableDirectives: false }, rules: { - "@typescript-eslint/no-unused-vars": ["error", - { "argsIgnorePattern": "^_" } - ] + '@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_' }] } }, { - files: ["src/client/**/*.ts", "src/server/**/*.ts"], - ignores: ["**/*.test.ts"], + files: ['src/client/**/*.ts', 'src/server/**/*.ts'], + ignores: ['**/*.test.ts'], rules: { - "no-console": "error" + 'no-console': 'error' } - } + }, + eslintConfigPrettier ); diff --git a/jest.config.js b/jest.config.js index f8f621c8b..d15de5a17 100644 --- a/jest.config.js +++ b/jest.config.js @@ -1,16 +1,14 @@ -import { createDefaultEsmPreset } from "ts-jest"; +import { createDefaultEsmPreset } from 'ts-jest'; const defaultEsmPreset = createDefaultEsmPreset(); /** @type {import('ts-jest').JestConfigWithTsJest} **/ export default { - ...defaultEsmPreset, - moduleNameMapper: { - "^(\\.{1,2}/.*)\\.js$": "$1", - "^pkce-challenge$": "/src/__mocks__/pkce-challenge.ts" - }, - transformIgnorePatterns: [ - "/node_modules/(?!eventsource)/" - ], - testPathIgnorePatterns: ["/node_modules/", "/dist/"], + ...defaultEsmPreset, + moduleNameMapper: { + '^(\\.{1,2}/.*)\\.js$': '$1', + '^pkce-challenge$': '/src/__mocks__/pkce-challenge.ts' + }, + transformIgnorePatterns: ['/node_modules/(?!eventsource)/'], + testPathIgnorePatterns: ['/node_modules/', '/dist/'] }; diff --git a/package-lock.json b/package-lock.json index 002fe2088..4a7ffc06f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -36,7 +36,9 @@ "@types/supertest": "^6.0.2", "@types/ws": "^8.5.12", "eslint": "^9.8.0", + "eslint-config-prettier": "^10.1.8", "jest": "^29.7.0", + "prettier": "3.6.2", "supertest": "^7.0.0", "ts-jest": "^29.2.4", "tsx": "^4.16.5", @@ -3213,6 +3215,22 @@ } } }, + "node_modules/eslint-config-prettier": { + "version": "10.1.8", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-10.1.8.tgz", + "integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==", + "dev": true, + "license": "MIT", + "bin": { + "eslint-config-prettier": "bin/cli.js" + }, + "funding": { + "url": "https://opencollective.com/eslint-config-prettier" + }, + "peerDependencies": { + "eslint": ">=7.0.0" + } + }, "node_modules/eslint-scope": { "version": "8.1.0", "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.1.0.tgz", @@ -5519,6 +5537,22 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.6.2.tgz", + "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, "node_modules/pretty-format": { "version": "29.7.0", "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", diff --git a/package.json b/package.json index bb36a6d98..2cd28556d 100644 --- a/package.json +++ b/package.json @@ -1,103 +1,105 @@ { - "name": "@modelcontextprotocol/sdk", - "version": "1.18.1", - "description": "Model Context Protocol implementation for TypeScript", - "license": "MIT", - "author": "Anthropic, PBC (https://anthropic.com)", - "homepage": "https://modelcontextprotocol.io", - "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", - "type": "module", - "repository": { - "type": "git", - "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" - }, - "engines": { - "node": ">=18" - }, - "keywords": [ - "modelcontextprotocol", - "mcp" - ], - "exports": { - ".": { - "import": "./dist/esm/index.js", - "require": "./dist/cjs/index.js" + "name": "@modelcontextprotocol/sdk", + "version": "1.18.1", + "description": "Model Context Protocol implementation for TypeScript", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" }, - "./client": { - "import": "./dist/esm/client/index.js", - "require": "./dist/cjs/client/index.js" + "engines": { + "node": ">=18" }, - "./server": { - "import": "./dist/esm/server/index.js", - "require": "./dist/cjs/server/index.js" + "keywords": [ + "modelcontextprotocol", + "mcp" + ], + "exports": { + ".": { + "import": "./dist/esm/index.js", + "require": "./dist/cjs/index.js" + }, + "./client": { + "import": "./dist/esm/client/index.js", + "require": "./dist/cjs/client/index.js" + }, + "./server": { + "import": "./dist/esm/server/index.js", + "require": "./dist/cjs/server/index.js" + }, + "./*": { + "import": "./dist/esm/*", + "require": "./dist/cjs/*" + } }, - "./*": { - "import": "./dist/esm/*", - "require": "./dist/cjs/*" - } - }, - "typesVersions": { - "*": { - "*": [ - "./dist/esm/*" - ] + "typesVersions": { + "*": { + "*": [ + "./dist/esm/*" + ] + } + }, + "files": [ + "dist" + ], + "scripts": { + "fetch:spec-types": "curl -o spec.types.ts https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema/draft/schema.ts", + "build": "npm run build:esm && npm run build:cjs", + "build:esm": "mkdir -p dist/esm && echo '{\"type\": \"module\"}' > dist/esm/package.json && tsc -p tsconfig.prod.json", + "build:esm:w": "npm run build:esm -- -w", + "build:cjs": "mkdir -p dist/cjs && echo '{\"type\": \"commonjs\"}' > dist/cjs/package.json && tsc -p tsconfig.cjs.json", + "build:cjs:w": "npm run build:cjs -- -w", + "examples:simple-server:w": "tsx --watch src/examples/server/simpleStreamableHttp.ts --oauth", + "prepack": "npm run build:esm && npm run build:cjs", + "lint": "eslint src/ && prettier --check .", + "test": "npm run fetch:spec-types && jest", + "start": "npm run server", + "server": "tsx watch --clear-screen=false src/cli.ts server", + "client": "tsx src/cli.ts client" + }, + "dependencies": { + "ajv": "^6.12.6", + "content-type": "^1.0.5", + "cors": "^2.8.5", + "cross-spawn": "^7.0.5", + "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", + "express": "^5.0.1", + "express-rate-limit": "^7.5.0", + "pkce-challenge": "^5.0.0", + "raw-body": "^3.0.0", + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" + }, + "devDependencies": { + "@eslint/js": "^9.8.0", + "@jest-mock/express": "^3.0.0", + "@types/content-type": "^1.1.8", + "@types/cors": "^2.8.17", + "@types/cross-spawn": "^6.0.6", + "@types/eslint__js": "^8.42.3", + "@types/eventsource": "^1.1.15", + "@types/express": "^5.0.0", + "@types/jest": "^29.5.12", + "@types/node": "^22.0.2", + "@types/supertest": "^6.0.2", + "@types/ws": "^8.5.12", + "eslint": "^9.8.0", + "jest": "^29.7.0", + "prettier": "3.6.2", + "eslint-config-prettier": "^10.1.8", + "supertest": "^7.0.0", + "ts-jest": "^29.2.4", + "tsx": "^4.16.5", + "typescript": "^5.5.4", + "typescript-eslint": "^8.0.0", + "ws": "^8.18.0" + }, + "resolutions": { + "strip-ansi": "6.0.1" } - }, - "files": [ - "dist" - ], - "scripts": { - "fetch:spec-types": "curl -o spec.types.ts https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema/draft/schema.ts", - "build": "npm run build:esm && npm run build:cjs", - "build:esm": "mkdir -p dist/esm && echo '{\"type\": \"module\"}' > dist/esm/package.json && tsc -p tsconfig.prod.json", - "build:esm:w": "npm run build:esm -- -w", - "build:cjs": "mkdir -p dist/cjs && echo '{\"type\": \"commonjs\"}' > dist/cjs/package.json && tsc -p tsconfig.cjs.json", - "build:cjs:w": "npm run build:cjs -- -w", - "examples:simple-server:w": "tsx --watch src/examples/server/simpleStreamableHttp.ts --oauth", - "prepack": "npm run build:esm && npm run build:cjs", - "lint": "eslint src/", - "test": "npm run fetch:spec-types && jest", - "start": "npm run server", - "server": "tsx watch --clear-screen=false src/cli.ts server", - "client": "tsx src/cli.ts client" - }, - "dependencies": { - "ajv": "^6.12.6", - "content-type": "^1.0.5", - "cors": "^2.8.5", - "cross-spawn": "^7.0.5", - "eventsource": "^3.0.2", - "eventsource-parser": "^3.0.0", - "express": "^5.0.1", - "express-rate-limit": "^7.5.0", - "pkce-challenge": "^5.0.0", - "raw-body": "^3.0.0", - "zod": "^3.23.8", - "zod-to-json-schema": "^3.24.1" - }, - "devDependencies": { - "@eslint/js": "^9.8.0", - "@jest-mock/express": "^3.0.0", - "@types/content-type": "^1.1.8", - "@types/cors": "^2.8.17", - "@types/cross-spawn": "^6.0.6", - "@types/eslint__js": "^8.42.3", - "@types/eventsource": "^1.1.15", - "@types/express": "^5.0.0", - "@types/jest": "^29.5.12", - "@types/node": "^22.0.2", - "@types/supertest": "^6.0.2", - "@types/ws": "^8.5.12", - "eslint": "^9.8.0", - "jest": "^29.7.0", - "supertest": "^7.0.0", - "ts-jest": "^29.2.4", - "tsx": "^4.16.5", - "typescript": "^5.5.4", - "typescript-eslint": "^8.0.0", - "ws": "^8.18.0" - }, - "resolutions": { - "strip-ansi": "6.0.1" - } -} \ No newline at end of file +} diff --git a/src/__mocks__/pkce-challenge.ts b/src/__mocks__/pkce-challenge.ts index 10e13054a..3dfec41f9 100644 --- a/src/__mocks__/pkce-challenge.ts +++ b/src/__mocks__/pkce-challenge.ts @@ -1,6 +1,6 @@ export default function pkceChallenge() { - return { - code_verifier: "test_verifier", - code_challenge: "test_challenge", - }; -} \ No newline at end of file + return { + code_verifier: 'test_verifier', + code_challenge: 'test_challenge' + }; +} diff --git a/src/cli.ts b/src/cli.ts index f580a624f..96764803f 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -1,163 +1,161 @@ -import WebSocket from "ws"; +import WebSocket from 'ws'; // eslint-disable-next-line @typescript-eslint/no-explicit-any (global as any).WebSocket = WebSocket; -import express from "express"; -import { Client } from "./client/index.js"; -import { SSEClientTransport } from "./client/sse.js"; -import { StdioClientTransport } from "./client/stdio.js"; -import { WebSocketClientTransport } from "./client/websocket.js"; -import { Server } from "./server/index.js"; -import { SSEServerTransport } from "./server/sse.js"; -import { StdioServerTransport } from "./server/stdio.js"; -import { ListResourcesResultSchema } from "./types.js"; +import express from 'express'; +import { Client } from './client/index.js'; +import { SSEClientTransport } from './client/sse.js'; +import { StdioClientTransport } from './client/stdio.js'; +import { WebSocketClientTransport } from './client/websocket.js'; +import { Server } from './server/index.js'; +import { SSEServerTransport } from './server/sse.js'; +import { StdioServerTransport } from './server/stdio.js'; +import { ListResourcesResultSchema } from './types.js'; async function runClient(url_or_command: string, args: string[]) { - const client = new Client( - { - name: "mcp-typescript test client", - version: "0.1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - let clientTransport; - - let url: URL | undefined = undefined; - try { - url = new URL(url_or_command); - } catch { - // Ignore - } - - if (url?.protocol === "http:" || url?.protocol === "https:") { - clientTransport = new SSEClientTransport(new URL(url_or_command)); - } else if (url?.protocol === "ws:" || url?.protocol === "wss:") { - clientTransport = new WebSocketClientTransport(new URL(url_or_command)); - } else { - clientTransport = new StdioClientTransport({ - command: url_or_command, - args, - }); - } - - console.log("Connected to server."); - - await client.connect(clientTransport); - console.log("Initialized."); - - await client.request({ method: "resources/list" }, ListResourcesResultSchema); - - await client.close(); - console.log("Closed."); -} - -async function runServer(port: number | null) { - if (port !== null) { - const app = express(); - - let servers: Server[] = []; - - app.get("/sse", async (req, res) => { - console.log("Got new SSE connection"); - - const transport = new SSEServerTransport("/message", res); - const server = new Server( + const client = new Client( { - name: "mcp-typescript test server", - version: "0.1.0", + name: 'mcp-typescript test client', + version: '0.1.0' }, { - capabilities: {}, - }, - ); - - servers.push(server); - - server.onclose = () => { - console.log("SSE connection closed"); - servers = servers.filter((s) => s !== server); - }; - - await server.connect(transport); - }); - - app.post("/message", async (req, res) => { - console.log("Received message"); - - const sessionId = req.query.sessionId as string; - const transport = servers - .map((s) => s.transport as SSEServerTransport) - .find((t) => t.sessionId === sessionId); - if (!transport) { - res.status(404).send("Session not found"); - return; - } - - await transport.handlePostMessage(req, res); - }); - - app.listen(port, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`Server running on http://localhost:${port}/sse`); - }); - } else { - const server = new Server( - { - name: "mcp-typescript test server", - version: "0.1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - }, + capabilities: { + sampling: {} + } + } ); - const transport = new StdioServerTransport(); - await server.connect(transport); + let clientTransport; - console.log("Server running on stdio"); - } -} + let url: URL | undefined = undefined; + try { + url = new URL(url_or_command); + } catch { + // Ignore + } -const args = process.argv.slice(2); -const command = args[0]; -switch (command) { - case "client": - if (args.length < 2) { - console.error("Usage: client [args...]"); - process.exit(1); + if (url?.protocol === 'http:' || url?.protocol === 'https:') { + clientTransport = new SSEClientTransport(new URL(url_or_command)); + } else if (url?.protocol === 'ws:' || url?.protocol === 'wss:') { + clientTransport = new WebSocketClientTransport(new URL(url_or_command)); + } else { + clientTransport = new StdioClientTransport({ + command: url_or_command, + args + }); } - runClient(args[1], args.slice(2)).catch((error) => { - console.error(error); - process.exit(1); - }); + console.log('Connected to server.'); + + await client.connect(clientTransport); + console.log('Initialized.'); - break; + await client.request({ method: 'resources/list' }, ListResourcesResultSchema); - case "server": { - const port = args[1] ? parseInt(args[1]) : null; - runServer(port).catch((error) => { - console.error(error); - process.exit(1); - }); + await client.close(); + console.log('Closed.'); +} - break; - } +async function runServer(port: number | null) { + if (port !== null) { + const app = express(); + + let servers: Server[] = []; + + app.get('/sse', async (req, res) => { + console.log('Got new SSE connection'); + + const transport = new SSEServerTransport('/message', res); + const server = new Server( + { + name: 'mcp-typescript test server', + version: '0.1.0' + }, + { + capabilities: {} + } + ); + + servers.push(server); + + server.onclose = () => { + console.log('SSE connection closed'); + servers = servers.filter(s => s !== server); + }; + + await server.connect(transport); + }); + + app.post('/message', async (req, res) => { + console.log('Received message'); + + const sessionId = req.query.sessionId as string; + const transport = servers.map(s => s.transport as SSEServerTransport).find(t => t.sessionId === sessionId); + if (!transport) { + res.status(404).send('Session not found'); + return; + } + + await transport.handlePostMessage(req, res); + }); + + app.listen(port, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`Server running on http://localhost:${port}/sse`); + }); + } else { + const server = new Server( + { + name: 'mcp-typescript test server', + version: '0.1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + const transport = new StdioServerTransport(); + await server.connect(transport); + + console.log('Server running on stdio'); + } +} + +const args = process.argv.slice(2); +const command = args[0]; +switch (command) { + case 'client': + if (args.length < 2) { + console.error('Usage: client [args...]'); + process.exit(1); + } + + runClient(args[1], args.slice(2)).catch(error => { + console.error(error); + process.exit(1); + }); + + break; + + case 'server': { + const port = args[1] ? parseInt(args[1]) : null; + runServer(port).catch(error => { + console.error(error); + process.exit(1); + }); + + break; + } - default: - console.error("Unrecognized command:", command); + default: + console.error('Unrecognized command:', command); } diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index f28163d14..846ba35c2 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -1,2499 +1,2435 @@ import { LATEST_PROTOCOL_VERSION } from '../types.js'; import { - discoverOAuthMetadata, - discoverAuthorizationServerMetadata, - buildDiscoveryUrls, - startAuthorization, - exchangeAuthorization, - refreshAuthorization, - registerClient, - discoverOAuthProtectedResourceMetadata, - extractResourceMetadataUrl, - auth, - type OAuthClientProvider, -} from "./auth.js"; -import {ServerError} from "../server/auth/errors.js"; + discoverOAuthMetadata, + discoverAuthorizationServerMetadata, + buildDiscoveryUrls, + startAuthorization, + exchangeAuthorization, + refreshAuthorization, + registerClient, + discoverOAuthProtectedResourceMetadata, + extractResourceMetadataUrl, + auth, + type OAuthClientProvider +} from './auth.js'; +import { ServerError } from '../server/auth/errors.js'; import { AuthorizationServerMetadata } from '../shared/auth.js'; // Mock fetch globally const mockFetch = jest.fn(); global.fetch = mockFetch; -describe("OAuth Authorization", () => { - beforeEach(() => { - mockFetch.mockReset(); - }); - - describe("extractResourceMetadataUrl", () => { - it("returns resource metadata url when present", async () => { - const resourceUrl = "https://resource.example.com/.well-known/oauth-protected-resource" - const mockResponse = { - headers: { - get: jest.fn((name) => name === "WWW-Authenticate" ? `Bearer realm="mcp", resource_metadata="${resourceUrl}"` : null), - } - } as unknown as Response - - expect(extractResourceMetadataUrl(mockResponse)).toEqual(new URL(resourceUrl)); - }); - - it("returns undefined if not bearer", async () => { - const resourceUrl = "https://resource.example.com/.well-known/oauth-protected-resource" - const mockResponse = { - headers: { - get: jest.fn((name) => name === "WWW-Authenticate" ? `Basic realm="mcp", resource_metadata="${resourceUrl}"` : null), - } - } as unknown as Response - - expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined(); - }); - - it("returns undefined if resource_metadata not present", async () => { - const mockResponse = { - headers: { - get: jest.fn((name) => name === "WWW-Authenticate" ? `Basic realm="mcp"` : null), - } - } as unknown as Response - - expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined(); - }); - - it("returns undefined on invalid url", async () => { - const resourceUrl = "invalid-url" - const mockResponse = { - headers: { - get: jest.fn((name) => name === "WWW-Authenticate" ? `Basic realm="mcp", resource_metadata="${resourceUrl}"` : null), - } - } as unknown as Response - - expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined(); - }); - }); - - describe("discoverOAuthProtectedResourceMetadata", () => { - const validMetadata = { - resource: "https://resource.example.com", - authorization_servers: ["https://auth.example.com"], - }; - - it("returns metadata when discovery succeeds", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com"); - expect(metadata).toEqual(validMetadata); - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); - const [url] = calls[0]; - expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - }); - - it("returns metadata when first fetch fails but second without MCP header succeeds", async () => { - // Set up a counter to control behavior - let callCount = 0; - - // Mock implementation that changes behavior based on call count - mockFetch.mockImplementation((_url, _options) => { - callCount++; - - if (callCount === 1) { - // First call with MCP header - fail with TypeError (simulating CORS error) - // We need to use TypeError specifically because that's what the implementation checks for - return Promise.reject(new TypeError("Network error")); - } else { - // Second call without header - succeed - return Promise.resolve({ - ok: true, - status: 200, - json: async () => validMetadata - }); - } - }); - - // Should succeed with the second call - const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com"); - expect(metadata).toEqual(validMetadata); - - // Verify both calls were made - expect(mockFetch).toHaveBeenCalledTimes(2); - - // Verify first call had MCP header - expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version"); - }); - - it("throws an error when all fetch attempts fail", async () => { - // Set up a counter to control behavior - let callCount = 0; - - // Mock implementation that changes behavior based on call count - mockFetch.mockImplementation((_url, _options) => { - callCount++; - - if (callCount === 1) { - // First call - fail with TypeError - return Promise.reject(new TypeError("First failure")); - } else { - // Second call - fail with different error - return Promise.reject(new Error("Second failure")); - } - }); - - // Should fail with the second error - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) - .rejects.toThrow("Second failure"); - - // Verify both calls were made - expect(mockFetch).toHaveBeenCalledTimes(2); - }); - - it("throws on 404 errors", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) - .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); - }); - - it("throws on non-404 errors", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 500, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) - .rejects.toThrow("HTTP 500"); - }); - - it("validates metadata schema", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - // Missing required fields - scopes_supported: ["email", "mcp"], - }), - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) - .rejects.toThrow(); - }); - - it("returns metadata when discovery succeeds with path", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name"); - expect(metadata).toEqual(validMetadata); - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); - const [url] = calls[0]; - expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path/name"); - }); - - it("preserves query parameters in path-aware discovery", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/path?param=value"); - expect(metadata).toEqual(validMetadata); - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); - const [url] = calls[0]; - expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path?param=value"); - }); - - it.each([400, 401, 403, 404, 410, 422, 429])("falls back to root discovery when path-aware discovery returns %d", async (statusCode) => { - // First call (path-aware) returns 4xx - mockFetch.mockResolvedValueOnce({ - ok: false, - status: statusCode, - }); - - // Second call (root fallback) succeeds - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name"); - expect(metadata).toEqual(validMetadata); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(2); - - // First call should be path-aware - const [firstUrl, firstOptions] = calls[0]; - expect(firstUrl.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path/name"); - expect(firstOptions.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - - // Second call should be root fallback - const [secondUrl, secondOptions] = calls[1]; - expect(secondUrl.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - expect(secondOptions.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - - it("throws error when both path-aware and root discovery return 404", async () => { - // First call (path-aware) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // Second call (root fallback) also returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name")) - .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(2); - }); - - it("throws error on 500 status and does not fallback", async () => { - // First call (path-aware) returns 500 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 500, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name")) - .rejects.toThrow(); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); // Should not attempt fallback - }); - - it("does not fallback when the original URL is already at root path", async () => { - // First call (path-aware for root) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/")) - .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); // Should not attempt fallback - - const [url] = calls[0]; - expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - }); - - it("does not fallback when the original URL has no path", async () => { - // First call (path-aware for no path) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com")) - .rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); // Should not attempt fallback - - const [url] = calls[0]; - expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - }); - - it("falls back when path-aware discovery encounters CORS error", async () => { - // First call (path-aware) fails with TypeError (CORS) - mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error"))); - - // Retry path-aware without headers (simulating CORS retry) - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // Second call (root fallback) succeeds - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthProtectedResourceMetadata("https://resource.example.com/deep/path"); - expect(metadata).toEqual(validMetadata); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(3); - - // Final call should be root fallback - const [lastUrl, lastOptions] = calls[2]; - expect(lastUrl.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - expect(lastOptions.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - - it("does not fallback when resourceMetadataUrl is provided", async () => { - // Call with explicit URL returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path", { - resourceMetadataUrl: "https://custom.example.com/metadata" - })).rejects.toThrow("Resource server does not implement OAuth 2.0 Protected Resource Metadata."); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); // Should not attempt fallback when explicit URL is provided - - const [url] = calls[0]; - expect(url.toString()).toBe("https://custom.example.com/metadata"); - }); - - it("supports overriding the fetch function used for requests", async () => { - const validMetadata = { - resource: "https://resource.example.com", - authorization_servers: ["https://auth.example.com"], - }; - - const customFetch = jest.fn().mockResolvedValue({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthProtectedResourceMetadata( - "https://resource.example.com", - undefined, - customFetch - ); - - expect(metadata).toEqual(validMetadata); - expect(customFetch).toHaveBeenCalledTimes(1); - expect(mockFetch).not.toHaveBeenCalled(); - - const [url, options] = customFetch.mock.calls[0]; - expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - expect(options.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - }); - - describe("discoverOAuthMetadata", () => { - const validMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - registration_endpoint: "https://auth.example.com/register", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }; - - it("returns metadata when discovery succeeds", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com"); - expect(metadata).toEqual(validMetadata); - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); - const [url, options] = calls[0]; - expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - expect(options.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - - it("returns metadata when discovery succeeds with path", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); - expect(metadata).toEqual(validMetadata); - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); - const [url, options] = calls[0]; - expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/path/name"); - expect(options.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - - it("falls back to root discovery when path-aware discovery returns 404", async () => { - // First call (path-aware) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // Second call (root fallback) succeeds - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); - expect(metadata).toEqual(validMetadata); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(2); - - // First call should be path-aware - const [firstUrl, firstOptions] = calls[0]; - expect(firstUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/path/name"); - expect(firstOptions.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - - // Second call should be root fallback - const [secondUrl, secondOptions] = calls[1]; - expect(secondUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - expect(secondOptions.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - - it("returns undefined when both path-aware and root discovery return 404", async () => { - // First call (path-aware) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // Second call (root fallback) also returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); - expect(metadata).toBeUndefined(); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(2); - }); - - it("does not fallback when the original URL is already at root path", async () => { - // First call (path-aware for root) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com/"); - expect(metadata).toBeUndefined(); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); // Should not attempt fallback - - const [url] = calls[0]; - expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - }); - - it("does not fallback when the original URL has no path", async () => { - // First call (path-aware for no path) returns 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com"); - expect(metadata).toBeUndefined(); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(1); // Should not attempt fallback - - const [url] = calls[0]; - expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - }); - - it("falls back when path-aware discovery encounters CORS error", async () => { - // First call (path-aware) fails with TypeError (CORS) - mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error"))); - - // Retry path-aware without headers (simulating CORS retry) - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // Second call (root fallback) succeeds - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com/deep/path"); - expect(metadata).toEqual(validMetadata); - - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(3); - - // Final call should be root fallback - const [lastUrl, lastOptions] = calls[2]; - expect(lastUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - expect(lastOptions.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - - it("returns metadata when first fetch fails but second without MCP header succeeds", async () => { - // Set up a counter to control behavior - let callCount = 0; - - // Mock implementation that changes behavior based on call count - mockFetch.mockImplementation((_url, _options) => { - callCount++; - - if (callCount === 1) { - // First call with MCP header - fail with TypeError (simulating CORS error) - // We need to use TypeError specifically because that's what the implementation checks for - return Promise.reject(new TypeError("Network error")); - } else { - // Second call without header - succeed - return Promise.resolve({ - ok: true, - status: 200, - json: async () => validMetadata - }); - } - }); - - // Should succeed with the second call - const metadata = await discoverOAuthMetadata("https://auth.example.com"); - expect(metadata).toEqual(validMetadata); - - // Verify both calls were made - expect(mockFetch).toHaveBeenCalledTimes(2); - - // Verify first call had MCP header - expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version"); - }); - - it("throws an error when all fetch attempts fail", async () => { - // Set up a counter to control behavior - let callCount = 0; - - // Mock implementation that changes behavior based on call count - mockFetch.mockImplementation((_url, _options) => { - callCount++; - - if (callCount === 1) { - // First call - fail with TypeError - return Promise.reject(new TypeError("First failure")); - } else { - // Second call - fail with different error - return Promise.reject(new Error("Second failure")); - } - }); - - // Should fail with the second error - await expect(discoverOAuthMetadata("https://auth.example.com")) - .rejects.toThrow("Second failure"); - - // Verify both calls were made - expect(mockFetch).toHaveBeenCalledTimes(2); - }); - - it("returns undefined when both CORS requests fail in fetchWithCorsRetry", async () => { - // fetchWithCorsRetry tries with headers (fails with CORS), then retries without headers (also fails with CORS) - // simulating a 404 w/o headers set. We want this to return undefined, not throw TypeError - mockFetch.mockImplementation(() => { - // Both the initial request with headers and retry without headers fail with CORS TypeError - return Promise.reject(new TypeError("Failed to fetch")); - }); - - // This should return undefined (the desired behavior after the fix) - const metadata = await discoverOAuthMetadata("https://auth.example.com/path"); - expect(metadata).toBeUndefined(); - }); - - it("returns undefined when discovery endpoint returns 404", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - const metadata = await discoverOAuthMetadata("https://auth.example.com"); - expect(metadata).toBeUndefined(); - }); - - it("throws on non-404 errors", async () => { - mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 })); - - await expect( - discoverOAuthMetadata("https://auth.example.com") - ).rejects.toThrow("HTTP 500"); - }); - - it("validates metadata schema", async () => { - mockFetch.mockResolvedValueOnce( - Response.json( - { - // Missing required fields - issuer: "https://auth.example.com", - }, - { status: 200 } - ) - ); - - await expect( - discoverOAuthMetadata("https://auth.example.com") - ).rejects.toThrow(); - }); - - it("supports overriding the fetch function used for requests", async () => { - const validMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - registration_endpoint: "https://auth.example.com/register", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }; - - const customFetch = jest.fn().mockResolvedValue({ - ok: true, - status: 200, - json: async () => validMetadata, - }); - - const metadata = await discoverOAuthMetadata( - "https://auth.example.com", - {}, - customFetch - ); - - expect(metadata).toEqual(validMetadata); - expect(customFetch).toHaveBeenCalledTimes(1); - expect(mockFetch).not.toHaveBeenCalled(); - - const [url, options] = customFetch.mock.calls[0]; - expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - expect(options.headers).toEqual({ - "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION - }); - }); - }); - - describe("buildDiscoveryUrls", () => { - it("generates correct URLs for server without path", () => { - const urls = buildDiscoveryUrls("https://auth.example.com"); - - expect(urls).toHaveLength(2); - expect(urls.map(u => ({ url: u.url.toString(), type: u.type }))).toEqual([ - { - url: "https://auth.example.com/.well-known/oauth-authorization-server", - type: "oauth" - }, - { - url: "https://auth.example.com/.well-known/openid-configuration", - type: "oidc" - } - ]); - }); - - it("generates correct URLs for server with path", () => { - const urls = buildDiscoveryUrls("https://auth.example.com/tenant1"); - - expect(urls).toHaveLength(4); - expect(urls.map(u => ({ url: u.url.toString(), type: u.type }))).toEqual([ - { - url: "https://auth.example.com/.well-known/oauth-authorization-server/tenant1", - type: "oauth" - }, - { - url: "https://auth.example.com/.well-known/oauth-authorization-server", - type: "oauth" - }, - { - url: "https://auth.example.com/.well-known/openid-configuration/tenant1", - type: "oidc" - }, - { - url: "https://auth.example.com/tenant1/.well-known/openid-configuration", - type: "oidc" - } - ]); - }); - - it("handles URL object input", () => { - const urls = buildDiscoveryUrls(new URL("https://auth.example.com/tenant1")); - - expect(urls).toHaveLength(4); - expect(urls[0].url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/tenant1"); - }); - }); - - describe("discoverAuthorizationServerMetadata", () => { - const validOAuthMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - registration_endpoint: "https://auth.example.com/register", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }; - - const validOpenIdMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - jwks_uri: "https://auth.example.com/jwks", - subject_types_supported: ["public"], - id_token_signing_alg_values_supported: ["RS256"], - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }; - - it("tries URLs in order and returns first successful metadata", async () => { - // First OAuth URL fails with 404 - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // Second OAuth URL (root) succeeds - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validOAuthMetadata, - }); - - const metadata = await discoverAuthorizationServerMetadata( - "https://auth.example.com/tenant1" - ); - - expect(metadata).toEqual(validOAuthMetadata); - - // Verify it tried the URLs in the correct order - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(2); - expect(calls[0][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/tenant1"); - expect(calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - }); - - it("throws error when OIDC provider does not support S256 PKCE", async () => { - // OAuth discovery fails - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 404, - }); - - // OpenID Connect discovery succeeds but without S256 support - const invalidOpenIdMetadata = { - ...validOpenIdMetadata, - code_challenge_methods_supported: ["plain"], // Missing S256 - }; - - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => invalidOpenIdMetadata, - }); - - await expect( - discoverAuthorizationServerMetadata( - "https://auth.example.com" - ) - ).rejects.toThrow("does not support S256 code challenge method required by MCP specification"); - }); - - it("continues on 4xx errors", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); - - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validOpenIdMetadata, - }); - - const metadata = await discoverAuthorizationServerMetadata("https://mcp.example.com"); - - expect(metadata).toEqual(validOpenIdMetadata); - - }); - - it("throws on non-4xx errors", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 500, - }); - - await expect( - discoverAuthorizationServerMetadata("https://mcp.example.com") - ).rejects.toThrow("HTTP 500"); - }); - - it("handles CORS errors with retry", async () => { - // First call fails with CORS - mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error"))); - - // Retry without headers succeeds - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validOAuthMetadata, - }); - - const metadata = await discoverAuthorizationServerMetadata( - "https://auth.example.com" - ); - - expect(metadata).toEqual(validOAuthMetadata); - const calls = mockFetch.mock.calls; - expect(calls.length).toBe(2); - - // First call should have headers - expect(calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version"); - - // Second call should not have headers (CORS retry) - expect(calls[1][1]?.headers).toBeUndefined(); - }); - - it("supports custom fetch function", async () => { - const customFetch = jest.fn().mockResolvedValue({ - ok: true, - status: 200, - json: async () => validOAuthMetadata, - }); - - const metadata = await discoverAuthorizationServerMetadata( - "https://auth.example.com", - { fetchFn: customFetch } - ); - - expect(metadata).toEqual(validOAuthMetadata); - expect(customFetch).toHaveBeenCalledTimes(1); - expect(mockFetch).not.toHaveBeenCalled(); - }); - - it("supports custom protocol version", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validOAuthMetadata, - }); - - const metadata = await discoverAuthorizationServerMetadata( - "https://auth.example.com", - { protocolVersion: "2025-01-01" } - ); - - expect(metadata).toEqual(validOAuthMetadata); - const calls = mockFetch.mock.calls; - const [, options] = calls[0]; - expect(options.headers).toEqual({ - "MCP-Protocol-Version": "2025-01-01" - }); - }); - - it("returns undefined when all URLs fail with CORS errors", async () => { - // All fetch attempts fail with CORS errors (TypeError) - mockFetch.mockImplementation(() => Promise.reject(new TypeError("CORS error"))); - - const metadata = await discoverAuthorizationServerMetadata("https://auth.example.com/tenant1"); - - expect(metadata).toBeUndefined(); - - // Verify that all discovery URLs were attempted - expect(mockFetch).toHaveBeenCalledTimes(8); // 4 URLs × 2 attempts each (with and without headers) - }); - }); - - describe("startAuthorization", () => { - const validMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/auth", - token_endpoint: "https://auth.example.com/tkn", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - it("generates authorization URL with PKCE challenge", async () => { - const { authorizationUrl, codeVerifier } = await startAuthorization( - "https://auth.example.com", - { - metadata: undefined, - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - resource: new URL("https://api.example.com/mcp-server"), - } - ); - - expect(authorizationUrl.toString()).toMatch( - /^https:\/\/auth\.example\.com\/authorize\?/ - ); - expect(authorizationUrl.searchParams.get("response_type")).toBe("code"); - expect(authorizationUrl.searchParams.get("code_challenge")).toBe("test_challenge"); - expect(authorizationUrl.searchParams.get("code_challenge_method")).toBe( - "S256" - ); - expect(authorizationUrl.searchParams.get("redirect_uri")).toBe( - "http://localhost:3000/callback" - ); - expect(authorizationUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server"); - expect(codeVerifier).toBe("test_verifier"); - }); - - it("includes scope parameter when provided", async () => { - const { authorizationUrl } = await startAuthorization( - "https://auth.example.com", - { - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - scope: "read write profile", - } - ); - - expect(authorizationUrl.searchParams.get("scope")).toBe("read write profile"); - }); - - it("excludes scope parameter when not provided", async () => { - const { authorizationUrl } = await startAuthorization( - "https://auth.example.com", - { - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - } - ); - - expect(authorizationUrl.searchParams.has("scope")).toBe(false); - }); - - it("includes state parameter when provided", async () => { - const { authorizationUrl } = await startAuthorization( - "https://auth.example.com", - { - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - state: "foobar", - } - ); - - expect(authorizationUrl.searchParams.get("state")).toBe("foobar"); - }); - - it("excludes state parameter when not provided", async () => { - const { authorizationUrl } = await startAuthorization( - "https://auth.example.com", - { - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - } - ); - - expect(authorizationUrl.searchParams.has("state")).toBe(false); - }); - - // OpenID Connect requires that the user is prompted for consent if the scope includes 'offline_access' - it("includes consent prompt parameter if scope includes 'offline_access'", async () => { - const { authorizationUrl } = await startAuthorization( - "https://auth.example.com", - { - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - scope: "read write profile offline_access", - } - ); - - expect(authorizationUrl.searchParams.get("prompt")).toBe("consent"); - }); - - it("uses metadata authorization_endpoint when provided", async () => { - const { authorizationUrl } = await startAuthorization( - "https://auth.example.com", - { - metadata: validMetadata, - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - } - ); - - expect(authorizationUrl.toString()).toMatch( - /^https:\/\/auth\.example\.com\/auth\?/ - ); - }); - - it("validates response type support", async () => { - const metadata = { - ...validMetadata, - response_types_supported: ["token"], // Does not support 'code' - }; - - await expect( - startAuthorization("https://auth.example.com", { - metadata, - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - }) - ).rejects.toThrow(/does not support response type/); - }); - - it("validates PKCE support", async () => { - const metadata = { - ...validMetadata, - response_types_supported: ["code"], - code_challenge_methods_supported: ["plain"], // Does not support 'S256' - }; - - await expect( - startAuthorization("https://auth.example.com", { - metadata, - clientInformation: validClientInfo, - redirectUrl: "http://localhost:3000/callback", - }) - ).rejects.toThrow(/does not support code challenge method/); - }); - }); - - describe("exchangeAuthorization", () => { - const validTokens = { - access_token: "access123", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "refresh123", - }; - - const validMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"] - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - it("exchanges code for tokens", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - authorizationCode: "code123", - codeVerifier: "verifier123", - redirectUri: "http://localhost:3000/callback", - resource: new URL("https://api.example.com/mcp-server"), - }); - - expect(tokens).toEqual(validTokens); - expect(mockFetch).toHaveBeenCalledWith( - expect.objectContaining({ - href: "https://auth.example.com/token", - }), - expect.objectContaining({ - method: "POST", - headers: new Headers({ - "Content-Type": "application/x-www-form-urlencoded", - }), - }) - ); - - const body = mockFetch.mock.calls[0][1].body as URLSearchParams; - expect(body.get("grant_type")).toBe("authorization_code"); - expect(body.get("code")).toBe("code123"); - expect(body.get("code_verifier")).toBe("verifier123"); - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); - expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); - }); - - it("exchanges code for tokens with auth", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: validMetadata, - clientInformation: validClientInfo, - authorizationCode: "code123", - codeVerifier: "verifier123", - redirectUri: "http://localhost:3000/callback", - addClientAuthentication: (headers: Headers, params: URLSearchParams, url: string | URL, metadata: AuthorizationServerMetadata) => { - headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret)); - params.set("example_url", typeof url === 'string' ? url : url.toString()); - params.set("example_metadata", metadata.authorization_endpoint); - params.set("example_param", "example_value"); - }, - }); - - expect(tokens).toEqual(validTokens); - expect(mockFetch).toHaveBeenCalledWith( - expect.objectContaining({ - href: "https://auth.example.com/token", - }), - expect.objectContaining({ - method: "POST", - }) - ); - - const headers = mockFetch.mock.calls[0][1].headers as Headers; - expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); - expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw=="); - const body = mockFetch.mock.calls[0][1].body as URLSearchParams; - expect(body.get("grant_type")).toBe("authorization_code"); - expect(body.get("code")).toBe("code123"); - expect(body.get("code_verifier")).toBe("verifier123"); - expect(body.get("client_id")).toBeNull(); - expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); - expect(body.get("example_url")).toBe("https://auth.example.com"); - expect(body.get("example_metadata")).toBe("https://auth.example.com/authorize"); - expect(body.get("example_param")).toBe("example_value"); - expect(body.get("client_secret")).toBeNull(); - }); - - it("validates token response schema", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - // Missing required fields - access_token: "access123", - }), - }); - - await expect( - exchangeAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - authorizationCode: "code123", - codeVerifier: "verifier123", - redirectUri: "http://localhost:3000/callback", - }) - ).rejects.toThrow(); - }); - - it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce( - Response.json( - new ServerError("Token exchange failed").toResponseObject(), - { status: 400 } - ) - ); - - await expect( - exchangeAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - authorizationCode: "code123", - codeVerifier: "verifier123", - redirectUri: "http://localhost:3000/callback", - }) - ).rejects.toThrow("Token exchange failed"); - }); - - it("supports overriding the fetch function used for requests", async () => { - const customFetch = jest.fn().mockResolvedValue({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - authorizationCode: "code123", - codeVerifier: "verifier123", - redirectUri: "http://localhost:3000/callback", - resource: new URL("https://api.example.com/mcp-server"), - fetchFn: customFetch, - }); - - expect(tokens).toEqual(validTokens); - expect(customFetch).toHaveBeenCalledTimes(1); - expect(mockFetch).not.toHaveBeenCalled(); - - const [url, options] = customFetch.mock.calls[0]; - expect(url.toString()).toBe("https://auth.example.com/token"); - expect(options).toEqual( - expect.objectContaining({ - method: "POST", - headers: expect.any(Headers), - body: expect.any(URLSearchParams), - }) - ); - - const body = options.body as URLSearchParams; - expect(body.get("grant_type")).toBe("authorization_code"); - expect(body.get("code")).toBe("code123"); - expect(body.get("code_verifier")).toBe("verifier123"); - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); - expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); - }); - }); - - describe("refreshAuthorization", () => { - const validTokens = { - access_token: "newaccess123", - token_type: "Bearer", - expires_in: 3600, - } - const validTokensWithNewRefreshToken = { - ...validTokens, - refresh_token: "newrefresh123", - }; - - const validMetadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"] - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - it("exchanges refresh token for new tokens", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokensWithNewRefreshToken, - }); - - const tokens = await refreshAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - refreshToken: "refresh123", - resource: new URL("https://api.example.com/mcp-server"), - }); - - expect(tokens).toEqual(validTokensWithNewRefreshToken); - expect(mockFetch).toHaveBeenCalledWith( - expect.objectContaining({ - href: "https://auth.example.com/token", - }), - expect.objectContaining({ - method: "POST", - headers: new Headers({ - "Content-Type": "application/x-www-form-urlencoded", - }), - }) - ); - - const body = mockFetch.mock.calls[0][1].body as URLSearchParams; - expect(body.get("grant_type")).toBe("refresh_token"); - expect(body.get("refresh_token")).toBe("refresh123"); - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); - }); - - it("exchanges refresh token for new tokens with auth", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokensWithNewRefreshToken, - }); - - const tokens = await refreshAuthorization("https://auth.example.com", { - metadata: validMetadata, - clientInformation: validClientInfo, - refreshToken: "refresh123", - addClientAuthentication: (headers: Headers, params: URLSearchParams, url: string | URL, metadata?: AuthorizationServerMetadata) => { - headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret)); - params.set("example_url", typeof url === 'string' ? url : url.toString()); - params.set("example_metadata", metadata?.authorization_endpoint ?? '?'); - params.set("example_param", "example_value"); - }, - }); - - expect(tokens).toEqual(validTokensWithNewRefreshToken); - expect(mockFetch).toHaveBeenCalledWith( - expect.objectContaining({ - href: "https://auth.example.com/token", - }), - expect.objectContaining({ - method: "POST", - }) - ); - - const headers = mockFetch.mock.calls[0][1].headers as Headers; - expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); - expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw=="); - const body = mockFetch.mock.calls[0][1].body as URLSearchParams; - expect(body.get("grant_type")).toBe("refresh_token"); - expect(body.get("refresh_token")).toBe("refresh123"); - expect(body.get("client_id")).toBeNull(); - expect(body.get("example_url")).toBe("https://auth.example.com"); - expect(body.get("example_metadata")).toBe("https://auth.example.com/authorize"); - expect(body.get("example_param")).toBe("example_value"); - expect(body.get("client_secret")).toBeNull(); - }); +describe('OAuth Authorization', () => { + beforeEach(() => { + mockFetch.mockReset(); + }); + + describe('extractResourceMetadataUrl', () => { + it('returns resource metadata url when present', async () => { + const resourceUrl = 'https://resource.example.com/.well-known/oauth-protected-resource'; + const mockResponse = { + headers: { + get: jest.fn(name => (name === 'WWW-Authenticate' ? `Bearer realm="mcp", resource_metadata="${resourceUrl}"` : null)) + } + } as unknown as Response; + + expect(extractResourceMetadataUrl(mockResponse)).toEqual(new URL(resourceUrl)); + }); + + it('returns undefined if not bearer', async () => { + const resourceUrl = 'https://resource.example.com/.well-known/oauth-protected-resource'; + const mockResponse = { + headers: { + get: jest.fn(name => (name === 'WWW-Authenticate' ? `Basic realm="mcp", resource_metadata="${resourceUrl}"` : null)) + } + } as unknown as Response; + + expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined(); + }); + + it('returns undefined if resource_metadata not present', async () => { + const mockResponse = { + headers: { + get: jest.fn(name => (name === 'WWW-Authenticate' ? `Basic realm="mcp"` : null)) + } + } as unknown as Response; + + expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined(); + }); + + it('returns undefined on invalid url', async () => { + const resourceUrl = 'invalid-url'; + const mockResponse = { + headers: { + get: jest.fn(name => (name === 'WWW-Authenticate' ? `Basic realm="mcp", resource_metadata="${resourceUrl}"` : null)) + } + } as unknown as Response; + + expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined(); + }); + }); + + describe('discoverOAuthProtectedResourceMetadata', () => { + const validMetadata = { + resource: 'https://resource.example.com', + authorization_servers: ['https://auth.example.com'] + }; - it("exchanges refresh token for new tokens and keep existing refresh token if none is returned", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); + it('returns metadata when discovery succeeds', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com'); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url] = calls[0]; + expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + }); + + it('returns metadata when first fetch fails but second without MCP header succeeds', async () => { + // Set up a counter to control behavior + let callCount = 0; + + // Mock implementation that changes behavior based on call count + mockFetch.mockImplementation((_url, _options) => { + callCount++; + + if (callCount === 1) { + // First call with MCP header - fail with TypeError (simulating CORS error) + // We need to use TypeError specifically because that's what the implementation checks for + return Promise.reject(new TypeError('Network error')); + } else { + // Second call without header - succeed + return Promise.resolve({ + ok: true, + status: 200, + json: async () => validMetadata + }); + } + }); + + // Should succeed with the second call + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com'); + expect(metadata).toEqual(validMetadata); + + // Verify both calls were made + expect(mockFetch).toHaveBeenCalledTimes(2); + + // Verify first call had MCP header + expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty('MCP-Protocol-Version'); + }); + + it('throws an error when all fetch attempts fail', async () => { + // Set up a counter to control behavior + let callCount = 0; + + // Mock implementation that changes behavior based on call count + mockFetch.mockImplementation((_url, _options) => { + callCount++; + + if (callCount === 1) { + // First call - fail with TypeError + return Promise.reject(new TypeError('First failure')); + } else { + // Second call - fail with different error + return Promise.reject(new Error('Second failure')); + } + }); + + // Should fail with the second error + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com')).rejects.toThrow('Second failure'); + + // Verify both calls were made + expect(mockFetch).toHaveBeenCalledTimes(2); + }); + + it('throws on 404 errors', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com')).rejects.toThrow( + 'Resource server does not implement OAuth 2.0 Protected Resource Metadata.' + ); + }); + + it('throws on non-404 errors', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500 + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com')).rejects.toThrow('HTTP 500'); + }); + + it('validates metadata schema', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + scopes_supported: ['email', 'mcp'] + }) + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com')).rejects.toThrow(); + }); + + it('returns metadata when discovery succeeds with path', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com/path/name'); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url] = calls[0]; + expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource/path/name'); + }); + + it('preserves query parameters in path-aware discovery', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com/path?param=value'); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url] = calls[0]; + expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource/path?param=value'); + }); + + it.each([400, 401, 403, 404, 410, 422, 429])( + 'falls back to root discovery when path-aware discovery returns %d', + async statusCode => { + // First call (path-aware) returns 4xx + mockFetch.mockResolvedValueOnce({ + ok: false, + status: statusCode + }); + + // Second call (root fallback) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com/path/name'); + expect(metadata).toEqual(validMetadata); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + + // First call should be path-aware + const [firstUrl, firstOptions] = calls[0]; + expect(firstUrl.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource/path/name'); + expect(firstOptions.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + + // Second call should be root fallback + const [secondUrl, secondOptions] = calls[1]; + expect(secondUrl.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + expect(secondOptions.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + } + ); + + it('throws error when both path-aware and root discovery return 404', async () => { + // First call (path-aware) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // Second call (root fallback) also returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com/path/name')).rejects.toThrow( + 'Resource server does not implement OAuth 2.0 Protected Resource Metadata.' + ); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + }); + + it('throws error on 500 status and does not fallback', async () => { + // First call (path-aware) returns 500 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500 + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com/path/name')).rejects.toThrow(); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + }); + + it('does not fallback when the original URL is already at root path', async () => { + // First call (path-aware for root) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com/')).rejects.toThrow( + 'Resource server does not implement OAuth 2.0 Protected Resource Metadata.' + ); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + + const [url] = calls[0]; + expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + }); + + it('does not fallback when the original URL has no path', async () => { + // First call (path-aware for no path) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + await expect(discoverOAuthProtectedResourceMetadata('https://resource.example.com')).rejects.toThrow( + 'Resource server does not implement OAuth 2.0 Protected Resource Metadata.' + ); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + + const [url] = calls[0]; + expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + }); + + it('falls back when path-aware discovery encounters CORS error', async () => { + // First call (path-aware) fails with TypeError (CORS) + mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError('CORS error'))); + + // Retry path-aware without headers (simulating CORS retry) + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // Second call (root fallback) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com/deep/path'); + expect(metadata).toEqual(validMetadata); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(3); + + // Final call should be root fallback + const [lastUrl, lastOptions] = calls[2]; + expect(lastUrl.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + expect(lastOptions.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + + it('does not fallback when resourceMetadataUrl is provided', async () => { + // Call with explicit URL returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + await expect( + discoverOAuthProtectedResourceMetadata('https://resource.example.com/path', { + resourceMetadataUrl: 'https://custom.example.com/metadata' + }) + ).rejects.toThrow('Resource server does not implement OAuth 2.0 Protected Resource Metadata.'); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback when explicit URL is provided + + const [url] = calls[0]; + expect(url.toString()).toBe('https://custom.example.com/metadata'); + }); + + it('supports overriding the fetch function used for requests', async () => { + const validMetadata = { + resource: 'https://resource.example.com', + authorization_servers: ['https://auth.example.com'] + }; + + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthProtectedResourceMetadata('https://resource.example.com', undefined, customFetch); + + expect(metadata).toEqual(validMetadata); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + + const [url, options] = customFetch.mock.calls[0]; + expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + expect(options.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + }); + + describe('discoverOAuthMetadata', () => { + const validMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }; - const refreshToken = "refresh123"; - const tokens = await refreshAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - refreshToken, - }); + it('returns metadata when discovery succeeds', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com'); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url, options] = calls[0]; + expect(url.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + expect(options.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + + it('returns metadata when discovery succeeds with path', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com/path/name'); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url, options] = calls[0]; + expect(url.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server/path/name'); + expect(options.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + + it('falls back to root discovery when path-aware discovery returns 404', async () => { + // First call (path-aware) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // Second call (root fallback) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com/path/name'); + expect(metadata).toEqual(validMetadata); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + + // First call should be path-aware + const [firstUrl, firstOptions] = calls[0]; + expect(firstUrl.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server/path/name'); + expect(firstOptions.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + + // Second call should be root fallback + const [secondUrl, secondOptions] = calls[1]; + expect(secondUrl.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + expect(secondOptions.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + + it('returns undefined when both path-aware and root discovery return 404', async () => { + // First call (path-aware) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // Second call (root fallback) also returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com/path/name'); + expect(metadata).toBeUndefined(); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + }); + + it('does not fallback when the original URL is already at root path', async () => { + // First call (path-aware for root) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com/'); + expect(metadata).toBeUndefined(); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + + const [url] = calls[0]; + expect(url.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + }); + + it('does not fallback when the original URL has no path', async () => { + // First call (path-aware for no path) returns 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com'); + expect(metadata).toBeUndefined(); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + + const [url] = calls[0]; + expect(url.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + }); + + it('falls back when path-aware discovery encounters CORS error', async () => { + // First call (path-aware) fails with TypeError (CORS) + mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError('CORS error'))); + + // Retry path-aware without headers (simulating CORS retry) + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // Second call (root fallback) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com/deep/path'); + expect(metadata).toEqual(validMetadata); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(3); + + // Final call should be root fallback + const [lastUrl, lastOptions] = calls[2]; + expect(lastUrl.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + expect(lastOptions.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + + it('returns metadata when first fetch fails but second without MCP header succeeds', async () => { + // Set up a counter to control behavior + let callCount = 0; + + // Mock implementation that changes behavior based on call count + mockFetch.mockImplementation((_url, _options) => { + callCount++; + + if (callCount === 1) { + // First call with MCP header - fail with TypeError (simulating CORS error) + // We need to use TypeError specifically because that's what the implementation checks for + return Promise.reject(new TypeError('Network error')); + } else { + // Second call without header - succeed + return Promise.resolve({ + ok: true, + status: 200, + json: async () => validMetadata + }); + } + }); + + // Should succeed with the second call + const metadata = await discoverOAuthMetadata('https://auth.example.com'); + expect(metadata).toEqual(validMetadata); + + // Verify both calls were made + expect(mockFetch).toHaveBeenCalledTimes(2); + + // Verify first call had MCP header + expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty('MCP-Protocol-Version'); + }); + + it('throws an error when all fetch attempts fail', async () => { + // Set up a counter to control behavior + let callCount = 0; + + // Mock implementation that changes behavior based on call count + mockFetch.mockImplementation((_url, _options) => { + callCount++; + + if (callCount === 1) { + // First call - fail with TypeError + return Promise.reject(new TypeError('First failure')); + } else { + // Second call - fail with different error + return Promise.reject(new Error('Second failure')); + } + }); + + // Should fail with the second error + await expect(discoverOAuthMetadata('https://auth.example.com')).rejects.toThrow('Second failure'); + + // Verify both calls were made + expect(mockFetch).toHaveBeenCalledTimes(2); + }); + + it('returns undefined when both CORS requests fail in fetchWithCorsRetry', async () => { + // fetchWithCorsRetry tries with headers (fails with CORS), then retries without headers (also fails with CORS) + // simulating a 404 w/o headers set. We want this to return undefined, not throw TypeError + mockFetch.mockImplementation(() => { + // Both the initial request with headers and retry without headers fail with CORS TypeError + return Promise.reject(new TypeError('Failed to fetch')); + }); + + // This should return undefined (the desired behavior after the fix) + const metadata = await discoverOAuthMetadata('https://auth.example.com/path'); + expect(metadata).toBeUndefined(); + }); + + it('returns undefined when discovery endpoint returns 404', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com'); + expect(metadata).toBeUndefined(); + }); + + it('throws on non-404 errors', async () => { + mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 })); + + await expect(discoverOAuthMetadata('https://auth.example.com')).rejects.toThrow('HTTP 500'); + }); + + it('validates metadata schema', async () => { + mockFetch.mockResolvedValueOnce( + Response.json( + { + // Missing required fields + issuer: 'https://auth.example.com' + }, + { status: 200 } + ) + ); + + await expect(discoverOAuthMetadata('https://auth.example.com')).rejects.toThrow(); + }); + + it('supports overriding the fetch function used for requests', async () => { + const validMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }; + + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validMetadata + }); + + const metadata = await discoverOAuthMetadata('https://auth.example.com', {}, customFetch); + + expect(metadata).toEqual(validMetadata); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + + const [url, options] = customFetch.mock.calls[0]; + expect(url.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + expect(options.headers).toEqual({ + 'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION + }); + }); + }); + + describe('buildDiscoveryUrls', () => { + it('generates correct URLs for server without path', () => { + const urls = buildDiscoveryUrls('https://auth.example.com'); + + expect(urls).toHaveLength(2); + expect(urls.map(u => ({ url: u.url.toString(), type: u.type }))).toEqual([ + { + url: 'https://auth.example.com/.well-known/oauth-authorization-server', + type: 'oauth' + }, + { + url: 'https://auth.example.com/.well-known/openid-configuration', + type: 'oidc' + } + ]); + }); + + it('generates correct URLs for server with path', () => { + const urls = buildDiscoveryUrls('https://auth.example.com/tenant1'); + + expect(urls).toHaveLength(4); + expect(urls.map(u => ({ url: u.url.toString(), type: u.type }))).toEqual([ + { + url: 'https://auth.example.com/.well-known/oauth-authorization-server/tenant1', + type: 'oauth' + }, + { + url: 'https://auth.example.com/.well-known/oauth-authorization-server', + type: 'oauth' + }, + { + url: 'https://auth.example.com/.well-known/openid-configuration/tenant1', + type: 'oidc' + }, + { + url: 'https://auth.example.com/tenant1/.well-known/openid-configuration', + type: 'oidc' + } + ]); + }); + + it('handles URL object input', () => { + const urls = buildDiscoveryUrls(new URL('https://auth.example.com/tenant1')); + + expect(urls).toHaveLength(4); + expect(urls[0].url.toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server/tenant1'); + }); + }); + + describe('discoverAuthorizationServerMetadata', () => { + const validOAuthMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }; - expect(tokens).toEqual({ refresh_token: refreshToken, ...validTokens }); - }); + const validOpenIdMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + jwks_uri: 'https://auth.example.com/jwks', + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }; - it("validates token response schema", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - // Missing required fields - access_token: "newaccess123", - }), - }); - - await expect( - refreshAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - refreshToken: "refresh123", - }) - ).rejects.toThrow(); - }); + it('tries URLs in order and returns first successful metadata', async () => { + // First OAuth URL fails with 404 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // Second OAuth URL (root) succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validOAuthMetadata + }); + + const metadata = await discoverAuthorizationServerMetadata('https://auth.example.com/tenant1'); + + expect(metadata).toEqual(validOAuthMetadata); + + // Verify it tried the URLs in the correct order + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + expect(calls[0][0].toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server/tenant1'); + expect(calls[1][0].toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + }); + + it('throws error when OIDC provider does not support S256 PKCE', async () => { + // OAuth discovery fails + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404 + }); + + // OpenID Connect discovery succeeds but without S256 support + const invalidOpenIdMetadata = { + ...validOpenIdMetadata, + code_challenge_methods_supported: ['plain'] // Missing S256 + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => invalidOpenIdMetadata + }); + + await expect(discoverAuthorizationServerMetadata('https://auth.example.com')).rejects.toThrow( + 'does not support S256 code challenge method required by MCP specification' + ); + }); + + it('continues on 4xx errors', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400 + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validOpenIdMetadata + }); + + const metadata = await discoverAuthorizationServerMetadata('https://mcp.example.com'); + + expect(metadata).toEqual(validOpenIdMetadata); + }); + + it('throws on non-4xx errors', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500 + }); + + await expect(discoverAuthorizationServerMetadata('https://mcp.example.com')).rejects.toThrow('HTTP 500'); + }); + + it('handles CORS errors with retry', async () => { + // First call fails with CORS + mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError('CORS error'))); + + // Retry without headers succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validOAuthMetadata + }); + + const metadata = await discoverAuthorizationServerMetadata('https://auth.example.com'); + + expect(metadata).toEqual(validOAuthMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(2); + + // First call should have headers + expect(calls[0][1]?.headers).toHaveProperty('MCP-Protocol-Version'); + + // Second call should not have headers (CORS retry) + expect(calls[1][1]?.headers).toBeUndefined(); + }); + + it('supports custom fetch function', async () => { + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validOAuthMetadata + }); + + const metadata = await discoverAuthorizationServerMetadata('https://auth.example.com', { fetchFn: customFetch }); + + expect(metadata).toEqual(validOAuthMetadata); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + }); + + it('supports custom protocol version', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validOAuthMetadata + }); + + const metadata = await discoverAuthorizationServerMetadata('https://auth.example.com', { protocolVersion: '2025-01-01' }); + + expect(metadata).toEqual(validOAuthMetadata); + const calls = mockFetch.mock.calls; + const [, options] = calls[0]; + expect(options.headers).toEqual({ + 'MCP-Protocol-Version': '2025-01-01' + }); + }); + + it('returns undefined when all URLs fail with CORS errors', async () => { + // All fetch attempts fail with CORS errors (TypeError) + mockFetch.mockImplementation(() => Promise.reject(new TypeError('CORS error'))); + + const metadata = await discoverAuthorizationServerMetadata('https://auth.example.com/tenant1'); + + expect(metadata).toBeUndefined(); + + // Verify that all discovery URLs were attempted + expect(mockFetch).toHaveBeenCalledTimes(8); // 4 URLs × 2 attempts each (with and without headers) + }); + }); + + describe('startAuthorization', () => { + const validMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/auth', + token_endpoint: 'https://auth.example.com/tkn', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }; - it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce( - Response.json( - new ServerError("Token refresh failed").toResponseObject(), - { status: 400 } - ) - ); - - await expect( - refreshAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - refreshToken: "refresh123", - }) - ).rejects.toThrow("Token refresh failed"); - }); - }); - - describe("registerClient", () => { - const validClientMetadata = { - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - client_id_issued_at: 1612137600, - client_secret_expires_at: 1612224000, - ...validClientMetadata, - }; - - it("registers client and returns client information", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validClientInfo, - }); - - const clientInfo = await registerClient("https://auth.example.com", { - clientMetadata: validClientMetadata, - }); - - expect(clientInfo).toEqual(validClientInfo); - expect(mockFetch).toHaveBeenCalledWith( - expect.objectContaining({ - href: "https://auth.example.com/register", - }), - expect.objectContaining({ - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(validClientMetadata), - }) - ); - }); + const validClientInfo = { + client_id: 'client123', + client_secret: 'secret123', + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; - it("validates client information response schema", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - // Missing required fields - client_secret: "secret123", - }), - }); - - await expect( - registerClient("https://auth.example.com", { - clientMetadata: validClientMetadata, - }) - ).rejects.toThrow(); - }); + it('generates authorization URL with PKCE challenge', async () => { + const { authorizationUrl, codeVerifier } = await startAuthorization('https://auth.example.com', { + metadata: undefined, + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback', + resource: new URL('https://api.example.com/mcp-server') + }); + + expect(authorizationUrl.toString()).toMatch(/^https:\/\/auth\.example\.com\/authorize\?/); + expect(authorizationUrl.searchParams.get('response_type')).toBe('code'); + expect(authorizationUrl.searchParams.get('code_challenge')).toBe('test_challenge'); + expect(authorizationUrl.searchParams.get('code_challenge_method')).toBe('S256'); + expect(authorizationUrl.searchParams.get('redirect_uri')).toBe('http://localhost:3000/callback'); + expect(authorizationUrl.searchParams.get('resource')).toBe('https://api.example.com/mcp-server'); + expect(codeVerifier).toBe('test_verifier'); + }); + + it('includes scope parameter when provided', async () => { + const { authorizationUrl } = await startAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback', + scope: 'read write profile' + }); + + expect(authorizationUrl.searchParams.get('scope')).toBe('read write profile'); + }); + + it('excludes scope parameter when not provided', async () => { + const { authorizationUrl } = await startAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback' + }); + + expect(authorizationUrl.searchParams.has('scope')).toBe(false); + }); + + it('includes state parameter when provided', async () => { + const { authorizationUrl } = await startAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback', + state: 'foobar' + }); + + expect(authorizationUrl.searchParams.get('state')).toBe('foobar'); + }); + + it('excludes state parameter when not provided', async () => { + const { authorizationUrl } = await startAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback' + }); + + expect(authorizationUrl.searchParams.has('state')).toBe(false); + }); + + // OpenID Connect requires that the user is prompted for consent if the scope includes 'offline_access' + it("includes consent prompt parameter if scope includes 'offline_access'", async () => { + const { authorizationUrl } = await startAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback', + scope: 'read write profile offline_access' + }); + + expect(authorizationUrl.searchParams.get('prompt')).toBe('consent'); + }); + + it('uses metadata authorization_endpoint when provided', async () => { + const { authorizationUrl } = await startAuthorization('https://auth.example.com', { + metadata: validMetadata, + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback' + }); + + expect(authorizationUrl.toString()).toMatch(/^https:\/\/auth\.example\.com\/auth\?/); + }); + + it('validates response type support', async () => { + const metadata = { + ...validMetadata, + response_types_supported: ['token'] // Does not support 'code' + }; + + await expect( + startAuthorization('https://auth.example.com', { + metadata, + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback' + }) + ).rejects.toThrow(/does not support response type/); + }); + + it('validates PKCE support', async () => { + const metadata = { + ...validMetadata, + response_types_supported: ['code'], + code_challenge_methods_supported: ['plain'] // Does not support 'S256' + }; + + await expect( + startAuthorization('https://auth.example.com', { + metadata, + clientInformation: validClientInfo, + redirectUrl: 'http://localhost:3000/callback' + }) + ).rejects.toThrow(/does not support code challenge method/); + }); + }); + + describe('exchangeAuthorization', () => { + const validTokens = { + access_token: 'access123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh123' + }; - it("throws when registration endpoint not available in metadata", async () => { - const metadata = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - }; - - await expect( - registerClient("https://auth.example.com", { - metadata, - clientMetadata: validClientMetadata, - }) - ).rejects.toThrow(/does not support dynamic client registration/); - }); + const validMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'] + }; - it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce( - Response.json( - new ServerError("Dynamic client registration failed").toResponseObject(), - { status: 400 } - ) - ); - - await expect( - registerClient("https://auth.example.com", { - clientMetadata: validClientMetadata, - }) - ).rejects.toThrow("Dynamic client registration failed"); - }); - }); - - describe("auth function", () => { - const mockProvider: OAuthClientProvider = { - get redirectUrl() { return "http://localhost:3000/callback"; }, - get clientMetadata() { - return { - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", + const validClientInfo = { + client_id: 'client123', + client_secret: 'secret123', + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' }; - }, - clientInformation: jest.fn(), - tokens: jest.fn(), - saveTokens: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn(), - }; - beforeEach(() => { - jest.clearAllMocks(); - }); + it('exchanges code for tokens', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + authorizationCode: 'code123', + codeVerifier: 'verifier123', + redirectUri: 'http://localhost:3000/callback', + resource: new URL('https://api.example.com/mcp-server') + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: 'https://auth.example.com/token' + }), + expect.objectContaining({ + method: 'POST', + headers: new Headers({ + 'Content-Type': 'application/x-www-form-urlencoded' + }) + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get('grant_type')).toBe('authorization_code'); + expect(body.get('code')).toBe('code123'); + expect(body.get('code_verifier')).toBe('verifier123'); + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBe('secret123'); + expect(body.get('redirect_uri')).toBe('http://localhost:3000/callback'); + expect(body.get('resource')).toBe('https://api.example.com/mcp-server'); + }); + + it('exchanges code for tokens with auth', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + metadata: validMetadata, + clientInformation: validClientInfo, + authorizationCode: 'code123', + codeVerifier: 'verifier123', + redirectUri: 'http://localhost:3000/callback', + addClientAuthentication: ( + headers: Headers, + params: URLSearchParams, + url: string | URL, + metadata: AuthorizationServerMetadata + ) => { + headers.set('Authorization', 'Basic ' + btoa(validClientInfo.client_id + ':' + validClientInfo.client_secret)); + params.set('example_url', typeof url === 'string' ? url : url.toString()); + params.set('example_metadata', metadata.authorization_endpoint); + params.set('example_param', 'example_value'); + } + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: 'https://auth.example.com/token' + }), + expect.objectContaining({ + method: 'POST' + }) + ); + + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get('Content-Type')).toBe('application/x-www-form-urlencoded'); + expect(headers.get('Authorization')).toBe('Basic Y2xpZW50MTIzOnNlY3JldDEyMw=='); + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get('grant_type')).toBe('authorization_code'); + expect(body.get('code')).toBe('code123'); + expect(body.get('code_verifier')).toBe('verifier123'); + expect(body.get('client_id')).toBeNull(); + expect(body.get('redirect_uri')).toBe('http://localhost:3000/callback'); + expect(body.get('example_url')).toBe('https://auth.example.com'); + expect(body.get('example_metadata')).toBe('https://auth.example.com/authorize'); + expect(body.get('example_param')).toBe('example_value'); + expect(body.get('client_secret')).toBeNull(); + }); + + it('validates token response schema', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: 'access123' + }) + }); + + await expect( + exchangeAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + authorizationCode: 'code123', + codeVerifier: 'verifier123', + redirectUri: 'http://localhost:3000/callback' + }) + ).rejects.toThrow(); + }); + + it('throws on error response', async () => { + mockFetch.mockResolvedValueOnce(Response.json(new ServerError('Token exchange failed').toResponseObject(), { status: 400 })); + + await expect( + exchangeAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + authorizationCode: 'code123', + codeVerifier: 'verifier123', + redirectUri: 'http://localhost:3000/callback' + }) + ).rejects.toThrow('Token exchange failed'); + }); + + it('supports overriding the fetch function used for requests', async () => { + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + authorizationCode: 'code123', + codeVerifier: 'verifier123', + redirectUri: 'http://localhost:3000/callback', + resource: new URL('https://api.example.com/mcp-server'), + fetchFn: customFetch + }); + + expect(tokens).toEqual(validTokens); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + + const [url, options] = customFetch.mock.calls[0]; + expect(url.toString()).toBe('https://auth.example.com/token'); + expect(options).toEqual( + expect.objectContaining({ + method: 'POST', + headers: expect.any(Headers), + body: expect.any(URLSearchParams) + }) + ); + + const body = options.body as URLSearchParams; + expect(body.get('grant_type')).toBe('authorization_code'); + expect(body.get('code')).toBe('code123'); + expect(body.get('code_verifier')).toBe('verifier123'); + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBe('secret123'); + expect(body.get('redirect_uri')).toBe('http://localhost:3000/callback'); + expect(body.get('resource')).toBe('https://api.example.com/mcp-server'); + }); + }); + + describe('refreshAuthorization', () => { + const validTokens = { + access_token: 'newaccess123', + token_type: 'Bearer', + expires_in: 3600 + }; + const validTokensWithNewRefreshToken = { + ...validTokens, + refresh_token: 'newrefresh123' + }; - it("falls back to /.well-known/oauth-authorization-server when no protected-resource-metadata", async () => { - // Setup: First call to protected resource metadata fails (404) - // Second call to auth server metadata succeeds - let callCount = 0; - mockFetch.mockImplementation((url) => { - callCount++; - - const urlString = url.toString(); - - if (callCount === 1 && urlString.includes("/.well-known/oauth-protected-resource")) { - // First call - protected resource metadata fails with 404 - return Promise.resolve({ - ok: false, - status: 404, - }); - } else if (callCount === 2 && urlString.includes("/.well-known/oauth-authorization-server")) { - // Second call - auth server metadata succeeds - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - registration_endpoint: "https://auth.example.com/register", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } else if (callCount === 3 && urlString.includes("/register")) { - // Third call - client registration succeeds - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - client_id: "test-client-id", - client_secret: "test-client-secret", - client_id_issued_at: 1612137600, - client_secret_expires_at: 1612224000, - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }), - }); - } - - return Promise.reject(new Error(`Unexpected fetch call: ${urlString}`)); - }); - - // Mock provider methods - (mockProvider.clientInformation as jest.Mock).mockResolvedValue(undefined); - (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); - mockProvider.saveClientInformation = jest.fn(); - - // Call the auth function - const result = await auth(mockProvider, { - serverUrl: "https://resource.example.com", - }); - - // Verify the result - expect(result).toBe("REDIRECT"); - - // Verify the sequence of calls - expect(mockFetch).toHaveBeenCalledTimes(3); - - // First call should be to protected resource metadata - expect(mockFetch.mock.calls[0][0].toString()).toBe( - "https://resource.example.com/.well-known/oauth-protected-resource" - ); - - // Second call should be to oauth metadata - expect(mockFetch.mock.calls[1][0].toString()).toBe( - "https://resource.example.com/.well-known/oauth-authorization-server" - ); - }); + const validMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'] + }; - it("passes resource parameter through authorization flow", async () => { - // Mock successful metadata discovery - need to include protected resource metadata - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - resource: "https://api.example.com/mcp-server", - authorization_servers: ["https://auth.example.com"], - }), - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods for authorization flow - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); - (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); - (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); - - // Call auth without authorization code (should trigger redirect) - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server", - }); - - expect(result).toBe("REDIRECT"); - - // Verify the authorization URL includes the resource parameter - expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( - expect.objectContaining({ - searchParams: expect.any(URLSearchParams), - }) - ); - - const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; - const authUrl: URL = redirectCall[0]; - expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server"); - }); + const validClientInfo = { + client_id: 'client123', + client_secret: 'secret123', + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; - it("includes resource in token exchange when authorization code is provided", async () => { - // Mock successful metadata discovery and token exchange - need protected resource metadata - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - resource: "https://api.example.com/mcp-server", - authorization_servers: ["https://auth.example.com"], - }), - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } else if (urlString.includes("/token")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - access_token: "access123", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "refresh123", - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods for token exchange - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier"); - (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); - - // Call auth with authorization code - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server", - authorizationCode: "auth-code-123", - }); - - expect(result).toBe("AUTHORIZED"); - - // Find the token exchange call - const tokenCall = mockFetch.mock.calls.find(call => - call[0].toString().includes("/token") - ); - expect(tokenCall).toBeDefined(); - - const body = tokenCall![1].body as URLSearchParams; - expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); - expect(body.get("code")).toBe("auth-code-123"); - }); + it('exchanges refresh token for new tokens', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokensWithNewRefreshToken + }); + + const tokens = await refreshAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + refreshToken: 'refresh123', + resource: new URL('https://api.example.com/mcp-server') + }); + + expect(tokens).toEqual(validTokensWithNewRefreshToken); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: 'https://auth.example.com/token' + }), + expect.objectContaining({ + method: 'POST', + headers: new Headers({ + 'Content-Type': 'application/x-www-form-urlencoded' + }) + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get('grant_type')).toBe('refresh_token'); + expect(body.get('refresh_token')).toBe('refresh123'); + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBe('secret123'); + expect(body.get('resource')).toBe('https://api.example.com/mcp-server'); + }); + + it('exchanges refresh token for new tokens with auth', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokensWithNewRefreshToken + }); + + const tokens = await refreshAuthorization('https://auth.example.com', { + metadata: validMetadata, + clientInformation: validClientInfo, + refreshToken: 'refresh123', + addClientAuthentication: ( + headers: Headers, + params: URLSearchParams, + url: string | URL, + metadata?: AuthorizationServerMetadata + ) => { + headers.set('Authorization', 'Basic ' + btoa(validClientInfo.client_id + ':' + validClientInfo.client_secret)); + params.set('example_url', typeof url === 'string' ? url : url.toString()); + params.set('example_metadata', metadata?.authorization_endpoint ?? '?'); + params.set('example_param', 'example_value'); + } + }); + + expect(tokens).toEqual(validTokensWithNewRefreshToken); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: 'https://auth.example.com/token' + }), + expect.objectContaining({ + method: 'POST' + }) + ); + + const headers = mockFetch.mock.calls[0][1].headers as Headers; + expect(headers.get('Content-Type')).toBe('application/x-www-form-urlencoded'); + expect(headers.get('Authorization')).toBe('Basic Y2xpZW50MTIzOnNlY3JldDEyMw=='); + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get('grant_type')).toBe('refresh_token'); + expect(body.get('refresh_token')).toBe('refresh123'); + expect(body.get('client_id')).toBeNull(); + expect(body.get('example_url')).toBe('https://auth.example.com'); + expect(body.get('example_metadata')).toBe('https://auth.example.com/authorize'); + expect(body.get('example_param')).toBe('example_value'); + expect(body.get('client_secret')).toBeNull(); + }); + + it('exchanges refresh token for new tokens and keep existing refresh token if none is returned', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const refreshToken = 'refresh123'; + const tokens = await refreshAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + refreshToken + }); + + expect(tokens).toEqual({ refresh_token: refreshToken, ...validTokens }); + }); + + it('validates token response schema', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: 'newaccess123' + }) + }); + + await expect( + refreshAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + refreshToken: 'refresh123' + }) + ).rejects.toThrow(); + }); + + it('throws on error response', async () => { + mockFetch.mockResolvedValueOnce(Response.json(new ServerError('Token refresh failed').toResponseObject(), { status: 400 })); + + await expect( + refreshAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + refreshToken: 'refresh123' + }) + ).rejects.toThrow('Token refresh failed'); + }); + }); + + describe('registerClient', () => { + const validClientMetadata = { + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; - it("includes resource in token refresh", async () => { - // Mock successful metadata discovery and token refresh - need protected resource metadata - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - resource: "https://api.example.com/mcp-server", - authorization_servers: ["https://auth.example.com"], - }), - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } else if (urlString.includes("/token")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - access_token: "new-access123", - token_type: "Bearer", - expires_in: 3600, - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods for token refresh - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.tokens as jest.Mock).mockResolvedValue({ - access_token: "old-access", - refresh_token: "refresh123", - }); - (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); - - // Call auth with existing tokens (should trigger refresh) - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server", - }); - - expect(result).toBe("AUTHORIZED"); - - // Find the token refresh call - const tokenCall = mockFetch.mock.calls.find(call => - call[0].toString().includes("/token") - ); - expect(tokenCall).toBeDefined(); - - const body = tokenCall![1].body as URLSearchParams; - expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); - expect(body.get("grant_type")).toBe("refresh_token"); - expect(body.get("refresh_token")).toBe("refresh123"); - }); + const validClientInfo = { + client_id: 'client123', + client_secret: 'secret123', + client_id_issued_at: 1612137600, + client_secret_expires_at: 1612224000, + ...validClientMetadata + }; - it("skips default PRM resource validation when custom validateResourceURL is provided", async () => { - const mockValidateResourceURL = jest.fn().mockResolvedValue(undefined); - const providerWithCustomValidation = { - ...mockProvider, - validateResourceURL: mockValidateResourceURL, - }; - - // Mock protected resource metadata with mismatched resource URL - // This would normally throw an error in default validation, but should be skipped - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - resource: "https://different-resource.example.com/mcp-server", // Mismatched resource - authorization_servers: ["https://auth.example.com"], - }), - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods - (providerWithCustomValidation.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (providerWithCustomValidation.tokens as jest.Mock).mockResolvedValue(undefined); - (providerWithCustomValidation.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); - (providerWithCustomValidation.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); - - // Call auth - should succeed despite resource mismatch because custom validation overrides default - const result = await auth(providerWithCustomValidation, { - serverUrl: "https://api.example.com/mcp-server", - }); - - expect(result).toBe("REDIRECT"); - - // Verify custom validation method was called - expect(mockValidateResourceURL).toHaveBeenCalledWith( - new URL("https://api.example.com/mcp-server"), - "https://different-resource.example.com/mcp-server" - ); - }); + it('registers client and returns client information', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validClientInfo + }); + + const clientInfo = await registerClient('https://auth.example.com', { + clientMetadata: validClientMetadata + }); + + expect(clientInfo).toEqual(validClientInfo); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: 'https://auth.example.com/register' + }), + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(validClientMetadata) + }) + ); + }); + + it('validates client information response schema', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + client_secret: 'secret123' + }) + }); + + await expect( + registerClient('https://auth.example.com', { + clientMetadata: validClientMetadata + }) + ).rejects.toThrow(); + }); + + it('throws when registration endpoint not available in metadata', async () => { + const metadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'] + }; + + await expect( + registerClient('https://auth.example.com', { + metadata, + clientMetadata: validClientMetadata + }) + ).rejects.toThrow(/does not support dynamic client registration/); + }); + + it('throws on error response', async () => { + mockFetch.mockResolvedValueOnce( + Response.json(new ServerError('Dynamic client registration failed').toResponseObject(), { status: 400 }) + ); + + await expect( + registerClient('https://auth.example.com', { + clientMetadata: validClientMetadata + }) + ).rejects.toThrow('Dynamic client registration failed'); + }); + }); + + describe('auth function', () => { + const mockProvider: OAuthClientProvider = { + get redirectUrl() { + return 'http://localhost:3000/callback'; + }, + get clientMetadata() { + return { + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; + }, + clientInformation: jest.fn(), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn() + }; - it("uses prefix of server URL from PRM resource as resource parameter", async () => { - // Mock successful metadata discovery with resource URL that is a prefix of requested URL - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - // Resource is a prefix of the requested server URL - resource: "https://api.example.com/", - authorization_servers: ["https://auth.example.com"], - }), - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); - (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); - (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); - - // Call auth with a URL that has the resource as prefix - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server/endpoint", - }); - - expect(result).toBe("REDIRECT"); - - // Verify the authorization URL includes the resource parameter from PRM - expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( - expect.objectContaining({ - searchParams: expect.any(URLSearchParams), - }) - ); - - const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; - const authUrl: URL = redirectCall[0]; - // Should use the PRM's resource value, not the full requested URL - expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/"); - }); + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('falls back to /.well-known/oauth-authorization-server when no protected-resource-metadata', async () => { + // Setup: First call to protected resource metadata fails (404) + // Second call to auth server metadata succeeds + let callCount = 0; + mockFetch.mockImplementation(url => { + callCount++; + + const urlString = url.toString(); + + if (callCount === 1 && urlString.includes('/.well-known/oauth-protected-resource')) { + // First call - protected resource metadata fails with 404 + return Promise.resolve({ + ok: false, + status: 404 + }); + } else if (callCount === 2 && urlString.includes('/.well-known/oauth-authorization-server')) { + // Second call - auth server metadata succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } else if (callCount === 3 && urlString.includes('/register')) { + // Third call - client registration succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + client_id: 'test-client-id', + client_secret: 'test-client-secret', + client_id_issued_at: 1612137600, + client_secret_expires_at: 1612224000, + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }) + }); + } + + return Promise.reject(new Error(`Unexpected fetch call: ${urlString}`)); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue(undefined); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + mockProvider.saveClientInformation = jest.fn(); + + // Call the auth function + const result = await auth(mockProvider, { + serverUrl: 'https://resource.example.com' + }); + + // Verify the result + expect(result).toBe('REDIRECT'); + + // Verify the sequence of calls + expect(mockFetch).toHaveBeenCalledTimes(3); + + // First call should be to protected resource metadata + expect(mockFetch.mock.calls[0][0].toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + + // Second call should be to oauth metadata + expect(mockFetch.mock.calls[1][0].toString()).toBe('https://resource.example.com/.well-known/oauth-authorization-server'); + }); + + it('passes resource parameter through authorization flow', async () => { + // Mock successful metadata discovery - need to include protected resource metadata + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://api.example.com/mcp-server', + authorization_servers: ['https://auth.example.com'] + }) + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for authorization flow + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth without authorization code (should trigger redirect) + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server' + }); + + expect(result).toBe('REDIRECT'); + + // Verify the authorization URL includes the resource parameter + expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( + expect.objectContaining({ + searchParams: expect.any(URLSearchParams) + }) + ); + + const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; + const authUrl: URL = redirectCall[0]; + expect(authUrl.searchParams.get('resource')).toBe('https://api.example.com/mcp-server'); + }); + + it('includes resource in token exchange when authorization code is provided', async () => { + // Mock successful metadata discovery and token exchange - need protected resource metadata + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://api.example.com/mcp-server', + authorization_servers: ['https://auth.example.com'] + }) + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } else if (urlString.includes('/token')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: 'access123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh123' + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token exchange + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.codeVerifier as jest.Mock).mockResolvedValue('test-verifier'); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with authorization code + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server', + authorizationCode: 'auth-code-123' + }); + + expect(result).toBe('AUTHORIZED'); + + // Find the token exchange call + const tokenCall = mockFetch.mock.calls.find(call => call[0].toString().includes('/token')); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + expect(body.get('resource')).toBe('https://api.example.com/mcp-server'); + expect(body.get('code')).toBe('auth-code-123'); + }); + + it('includes resource in token refresh', async () => { + // Mock successful metadata discovery and token refresh - need protected resource metadata + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://api.example.com/mcp-server', + authorization_servers: ['https://auth.example.com'] + }) + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } else if (urlString.includes('/token')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: 'new-access123', + token_type: 'Bearer', + expires_in: 3600 + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token refresh + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue({ + access_token: 'old-access', + refresh_token: 'refresh123' + }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with existing tokens (should trigger refresh) + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server' + }); + + expect(result).toBe('AUTHORIZED'); + + // Find the token refresh call + const tokenCall = mockFetch.mock.calls.find(call => call[0].toString().includes('/token')); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + expect(body.get('resource')).toBe('https://api.example.com/mcp-server'); + expect(body.get('grant_type')).toBe('refresh_token'); + expect(body.get('refresh_token')).toBe('refresh123'); + }); + + it('skips default PRM resource validation when custom validateResourceURL is provided', async () => { + const mockValidateResourceURL = jest.fn().mockResolvedValue(undefined); + const providerWithCustomValidation = { + ...mockProvider, + validateResourceURL: mockValidateResourceURL + }; + + // Mock protected resource metadata with mismatched resource URL + // This would normally throw an error in default validation, but should be skipped + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://different-resource.example.com/mcp-server', // Mismatched resource + authorization_servers: ['https://auth.example.com'] + }) + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (providerWithCustomValidation.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (providerWithCustomValidation.tokens as jest.Mock).mockResolvedValue(undefined); + (providerWithCustomValidation.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (providerWithCustomValidation.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth - should succeed despite resource mismatch because custom validation overrides default + const result = await auth(providerWithCustomValidation, { + serverUrl: 'https://api.example.com/mcp-server' + }); + + expect(result).toBe('REDIRECT'); + + // Verify custom validation method was called + expect(mockValidateResourceURL).toHaveBeenCalledWith( + new URL('https://api.example.com/mcp-server'), + 'https://different-resource.example.com/mcp-server' + ); + }); + + it('uses prefix of server URL from PRM resource as resource parameter', async () => { + // Mock successful metadata discovery with resource URL that is a prefix of requested URL + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + // Resource is a prefix of the requested server URL + resource: 'https://api.example.com/', + authorization_servers: ['https://auth.example.com'] + }) + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth with a URL that has the resource as prefix + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server/endpoint' + }); + + expect(result).toBe('REDIRECT'); + + // Verify the authorization URL includes the resource parameter from PRM + expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( + expect.objectContaining({ + searchParams: expect.any(URLSearchParams) + }) + ); + + const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; + const authUrl: URL = redirectCall[0]; + // Should use the PRM's resource value, not the full requested URL + expect(authUrl.searchParams.get('resource')).toBe('https://api.example.com/'); + }); + + it('excludes resource parameter when Protected Resource Metadata is not present', async () => { + // Mock metadata discovery where protected resource metadata is not available (404) + // but authorization server metadata is available + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + // Protected resource metadata not available + return Promise.resolve({ + ok: false, + status: 404 + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth - should not include resource parameter + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server' + }); + + expect(result).toBe('REDIRECT'); + + // Verify the authorization URL does NOT include the resource parameter + expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( + expect.objectContaining({ + searchParams: expect.any(URLSearchParams) + }) + ); + + const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; + const authUrl: URL = redirectCall[0]; + // Resource parameter should not be present when PRM is not available + expect(authUrl.searchParams.has('resource')).toBe(false); + }); + + it('excludes resource parameter in token exchange when Protected Resource Metadata is not present', async () => { + // Mock metadata discovery - no protected resource metadata, but auth server metadata available + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: false, + status: 404 + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } else if (urlString.includes('/token')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: 'access123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh123' + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token exchange + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.codeVerifier as jest.Mock).mockResolvedValue('test-verifier'); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with authorization code + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server', + authorizationCode: 'auth-code-123' + }); + + expect(result).toBe('AUTHORIZED'); + + // Find the token exchange call + const tokenCall = mockFetch.mock.calls.find(call => call[0].toString().includes('/token')); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + // Resource parameter should not be present when PRM is not available + expect(body.has('resource')).toBe(false); + expect(body.get('code')).toBe('auth-code-123'); + }); + + it('excludes resource parameter in token refresh when Protected Resource Metadata is not present', async () => { + // Mock metadata discovery - no protected resource metadata, but auth server metadata available + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: false, + status: 404 + }); + } else if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } else if (urlString.includes('/token')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: 'new-access123', + token_type: 'Bearer', + expires_in: 3600 + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token refresh + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue({ + access_token: 'old-access', + refresh_token: 'refresh123' + }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with existing tokens (should trigger refresh) + const result = await auth(mockProvider, { + serverUrl: 'https://api.example.com/mcp-server' + }); + + expect(result).toBe('AUTHORIZED'); + + // Find the token refresh call + const tokenCall = mockFetch.mock.calls.find(call => call[0].toString().includes('/token')); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + // Resource parameter should not be present when PRM is not available + expect(body.has('resource')).toBe(false); + expect(body.get('grant_type')).toBe('refresh_token'); + expect(body.get('refresh_token')).toBe('refresh123'); + }); + + it('fetches AS metadata with path from serverUrl when PRM returns external AS', async () => { + // Mock PRM discovery that returns an external AS + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString === 'https://my.resource.com/.well-known/oauth-protected-resource/path/name') { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://my.resource.com/', + authorization_servers: ['https://auth.example.com/oauth'] + }) + }); + } else if (urlString === 'https://auth.example.com/.well-known/oauth-authorization-server/path/name') { + // Path-aware discovery on AS with path from serverUrl + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth with serverUrl that has a path + const result = await auth(mockProvider, { + serverUrl: 'https://my.resource.com/path/name' + }); + + expect(result).toBe('REDIRECT'); + + // Verify the correct URLs were fetched + const calls = mockFetch.mock.calls; + + // First call should be to PRM + expect(calls[0][0].toString()).toBe('https://my.resource.com/.well-known/oauth-protected-resource/path/name'); + + // Second call should be to AS metadata with the path from authorization server + expect(calls[1][0].toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server/oauth'); + }); + + it('supports overriding the fetch function used for requests', async () => { + const customFetch = jest.fn(); + + // Mock PRM discovery + customFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://resource.example.com', + authorization_servers: ['https://auth.example.com'] + }) + }); + + // Mock AS metadata discovery + customFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + + const mockProvider: OAuthClientProvider = { + get redirectUrl() { + return 'http://localhost:3000/callback'; + }, + get clientMetadata() { + return { + client_name: 'Test Client', + redirect_uris: ['http://localhost:3000/callback'] + }; + }, + clientInformation: jest.fn().mockResolvedValue({ + client_id: 'client123', + client_secret: 'secret123' + }), + tokens: jest.fn().mockResolvedValue(undefined), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn().mockResolvedValue('verifier123') + }; + + const result = await auth(mockProvider, { + serverUrl: 'https://resource.example.com', + fetchFn: customFetch + }); + + expect(result).toBe('REDIRECT'); + expect(customFetch).toHaveBeenCalledTimes(2); + expect(mockFetch).not.toHaveBeenCalled(); + + // Verify custom fetch was called for PRM discovery + expect(customFetch.mock.calls[0][0].toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource'); + + // Verify custom fetch was called for AS metadata discovery + expect(customFetch.mock.calls[1][0].toString()).toBe('https://auth.example.com/.well-known/oauth-authorization-server'); + }); + }); + + describe('exchangeAuthorization with multiple client authentication methods', () => { + const validTokens = { + access_token: 'access123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh123' + }; - it("excludes resource parameter when Protected Resource Metadata is not present", async () => { - // Mock metadata discovery where protected resource metadata is not available (404) - // but authorization server metadata is available - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - // Protected resource metadata not available - return Promise.resolve({ - ok: false, - status: 404, - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); - (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); - (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); - - // Call auth - should not include resource parameter - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server", - }); - - expect(result).toBe("REDIRECT"); - - // Verify the authorization URL does NOT include the resource parameter - expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( - expect.objectContaining({ - searchParams: expect.any(URLSearchParams), - }) - ); - - const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; - const authUrl: URL = redirectCall[0]; - // Resource parameter should not be present when PRM is not available - expect(authUrl.searchParams.has("resource")).toBe(false); - }); + const validClientInfo = { + client_id: 'client123', + client_secret: 'secret123', + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; - it("excludes resource parameter in token exchange when Protected Resource Metadata is not present", async () => { - // Mock metadata discovery - no protected resource metadata, but auth server metadata available - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: false, - status: 404, - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } else if (urlString.includes("/token")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - access_token: "access123", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "refresh123", - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods for token exchange - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier"); - (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); - - // Call auth with authorization code - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server", - authorizationCode: "auth-code-123", - }); - - expect(result).toBe("AUTHORIZED"); - - // Find the token exchange call - const tokenCall = mockFetch.mock.calls.find(call => - call[0].toString().includes("/token") - ); - expect(tokenCall).toBeDefined(); - - const body = tokenCall![1].body as URLSearchParams; - // Resource parameter should not be present when PRM is not available - expect(body.has("resource")).toBe(false); - expect(body.get("code")).toBe("auth-code-123"); - }); + const metadataWithBasicOnly = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/auth', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'], + token_endpoint_auth_methods_supported: ['client_secret_basic'] + }; - it("excludes resource parameter in token refresh when Protected Resource Metadata is not present", async () => { - // Mock metadata discovery - no protected resource metadata, but auth server metadata available - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString.includes("/.well-known/oauth-protected-resource")) { - return Promise.resolve({ - ok: false, - status: 404, - }); - } else if (urlString.includes("/.well-known/oauth-authorization-server")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } else if (urlString.includes("/token")) { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - access_token: "new-access123", - token_type: "Bearer", - expires_in: 3600, - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods for token refresh - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.tokens as jest.Mock).mockResolvedValue({ - access_token: "old-access", - refresh_token: "refresh123", - }); - (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); - - // Call auth with existing tokens (should trigger refresh) - const result = await auth(mockProvider, { - serverUrl: "https://api.example.com/mcp-server", - }); - - expect(result).toBe("AUTHORIZED"); - - // Find the token refresh call - const tokenCall = mockFetch.mock.calls.find(call => - call[0].toString().includes("/token") - ); - expect(tokenCall).toBeDefined(); - - const body = tokenCall![1].body as URLSearchParams; - // Resource parameter should not be present when PRM is not available - expect(body.has("resource")).toBe(false); - expect(body.get("grant_type")).toBe("refresh_token"); - expect(body.get("refresh_token")).toBe("refresh123"); - }); + const metadataWithPostOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ['client_secret_post'] + }; - it("fetches AS metadata with path from serverUrl when PRM returns external AS", async () => { - // Mock PRM discovery that returns an external AS - mockFetch.mockImplementation((url) => { - const urlString = url.toString(); - - if (urlString === "https://my.resource.com/.well-known/oauth-protected-resource/path/name") { - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - resource: "https://my.resource.com/", - authorization_servers: ["https://auth.example.com/oauth"], - }), - }); - } else if (urlString === "https://auth.example.com/.well-known/oauth-authorization-server/path/name") { - // Path-aware discovery on AS with path from serverUrl - return Promise.resolve({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - } - - return Promise.resolve({ ok: false, status: 404 }); - }); - - // Mock provider methods - (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ - client_id: "test-client", - client_secret: "test-secret", - }); - (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); - (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); - (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); - - // Call auth with serverUrl that has a path - const result = await auth(mockProvider, { - serverUrl: "https://my.resource.com/path/name", - }); - - expect(result).toBe("REDIRECT"); - - // Verify the correct URLs were fetched - const calls = mockFetch.mock.calls; - - // First call should be to PRM - expect(calls[0][0].toString()).toBe("https://my.resource.com/.well-known/oauth-protected-resource/path/name"); - - // Second call should be to AS metadata with the path from authorization server - expect(calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/oauth"); - }); + const metadataWithNoneOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ['none'] + }; - it("supports overriding the fetch function used for requests", async () => { - const customFetch = jest.fn(); - - // Mock PRM discovery - customFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - resource: "https://resource.example.com", - authorization_servers: ["https://auth.example.com"], - }), - }); - - // Mock AS metadata discovery - customFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - registration_endpoint: "https://auth.example.com/register", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }); - - const mockProvider: OAuthClientProvider = { - get redirectUrl() { return "http://localhost:3000/callback"; }, - get clientMetadata() { - return { - client_name: "Test Client", - redirect_uris: ["http://localhost:3000/callback"], - }; - }, - clientInformation: jest.fn().mockResolvedValue({ - client_id: "client123", - client_secret: "secret123", - }), - tokens: jest.fn().mockResolvedValue(undefined), - saveTokens: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn().mockResolvedValue("verifier123"), - }; - - const result = await auth(mockProvider, { - serverUrl: "https://resource.example.com", - fetchFn: customFetch, - }); - - expect(result).toBe("REDIRECT"); - expect(customFetch).toHaveBeenCalledTimes(2); - expect(mockFetch).not.toHaveBeenCalled(); - - // Verify custom fetch was called for PRM discovery - expect(customFetch.mock.calls[0][0].toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); - - // Verify custom fetch was called for AS metadata discovery - expect(customFetch.mock.calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); - }); - }); - - describe("exchangeAuthorization with multiple client authentication methods", () => { - const validTokens = { - access_token: "access123", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "refresh123", - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - const metadataWithBasicOnly = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/auth", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - token_endpoint_auth_methods_supported: ["client_secret_basic"], - }; - - const metadataWithPostOnly = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["client_secret_post"], - }; - - const metadataWithNoneOnly = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["none"], - }; - - const metadataWithAllBuiltinMethods = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["client_secret_basic", "client_secret_post", "none"], - }; - - it("uses HTTP Basic authentication when client_secret_basic is supported", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithBasicOnly, - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check Authorization header - const authHeader = request.headers.get("Authorization"); - const expected = "Basic " + btoa("client123:secret123"); - expect(authHeader).toBe(expected); - - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBeNull(); - expect(body.get("client_secret")).toBeNull(); - }); + const metadataWithAllBuiltinMethods = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ['client_secret_basic', 'client_secret_post', 'none'] + }; - it("includes credentials in request body when client_secret_post is supported", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithPostOnly, - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check no Authorization header - expect(request.headers.get("Authorization")).toBeNull(); - - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - }); + it('uses HTTP Basic authentication when client_secret_basic is supported', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + metadata: metadataWithBasicOnly, + clientInformation: validClientInfo, + authorizationCode: 'code123', + redirectUri: 'http://localhost:3000/callback', + codeVerifier: 'verifier123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header + const authHeader = request.headers.get('Authorization'); + const expected = 'Basic ' + btoa('client123:secret123'); + expect(authHeader).toBe(expected); + + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBeNull(); + expect(body.get('client_secret')).toBeNull(); + }); + + it('includes credentials in request body when client_secret_post is supported', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + metadata: metadataWithPostOnly, + clientInformation: validClientInfo, + authorizationCode: 'code123', + redirectUri: 'http://localhost:3000/callback', + codeVerifier: 'verifier123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get('Authorization')).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBe('secret123'); + }); + + it('it picks client_secret_basic when all builtin methods are supported', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + metadata: metadataWithAllBuiltinMethods, + clientInformation: validClientInfo, + authorizationCode: 'code123', + redirectUri: 'http://localhost:3000/callback', + codeVerifier: 'verifier123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header - should use Basic auth as it's the most secure + const authHeader = request.headers.get('Authorization'); + const expected = 'Basic ' + btoa('client123:secret123'); + expect(authHeader).toBe(expected); + + // Credentials should not be in body when using Basic auth + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBeNull(); + expect(body.get('client_secret')).toBeNull(); + }); + + it('uses public client authentication when none method is specified', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const clientInfoWithoutSecret = { + client_id: 'client123', + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; + + const tokens = await exchangeAuthorization('https://auth.example.com', { + metadata: metadataWithNoneOnly, + clientInformation: clientInfoWithoutSecret, + authorizationCode: 'code123', + redirectUri: 'http://localhost:3000/callback', + codeVerifier: 'verifier123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get('Authorization')).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBeNull(); + }); + + it('defaults to client_secret_post when no auth methods specified', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await exchangeAuthorization('https://auth.example.com', { + clientInformation: validClientInfo, + authorizationCode: 'code123', + redirectUri: 'http://localhost:3000/callback', + codeVerifier: 'verifier123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check headers + expect(request.headers.get('Content-Type')).toBe('application/x-www-form-urlencoded'); + expect(request.headers.get('Authorization')).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBe('secret123'); + }); + }); + + describe('refreshAuthorization with multiple client authentication methods', () => { + const validTokens = { + access_token: 'newaccess123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'newrefresh123' + }; - it("it picks client_secret_basic when all builtin methods are supported", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithAllBuiltinMethods, - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check Authorization header - should use Basic auth as it's the most secure - const authHeader = request.headers.get("Authorization"); - const expected = "Basic " + btoa("client123:secret123"); - expect(authHeader).toBe(expected); - - // Credentials should not be in body when using Basic auth - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBeNull(); - expect(body.get("client_secret")).toBeNull(); - }); + const validClientInfo = { + client_id: 'client123', + client_secret: 'secret123', + redirect_uris: ['http://localhost:3000/callback'], + client_name: 'Test Client' + }; - it("uses public client authentication when none method is specified", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const clientInfoWithoutSecret = { - client_id: "client123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithNoneOnly, - clientInformation: clientInfoWithoutSecret, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check no Authorization header - expect(request.headers.get("Authorization")).toBeNull(); - - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBeNull(); - }); + const metadataWithBasicOnly = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/auth', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + token_endpoint_auth_methods_supported: ['client_secret_basic'] + }; - it("defaults to client_secret_post when no auth methods specified", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await exchangeAuthorization("https://auth.example.com", { - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check headers - expect(request.headers.get("Content-Type")).toBe("application/x-www-form-urlencoded"); - expect(request.headers.get("Authorization")).toBeNull(); - - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - }); - }); - - describe("refreshAuthorization with multiple client authentication methods", () => { - const validTokens = { - access_token: "newaccess123", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "newrefresh123", - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", - }; - - const metadataWithBasicOnly = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/auth", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - token_endpoint_auth_methods_supported: ["client_secret_basic"], - }; - - const metadataWithPostOnly = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["client_secret_post"], - }; - - it("uses client_secret_basic for refresh token", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await refreshAuthorization("https://auth.example.com", { - metadata: metadataWithBasicOnly, - clientInformation: validClientInfo, - refreshToken: "refresh123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check Authorization header - const authHeader = request.headers.get("Authorization"); - const expected = "Basic " + btoa("client123:secret123"); - expect(authHeader).toBe(expected); - - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBeNull(); // should not be in body - expect(body.get("client_secret")).toBeNull(); // should not be in body - expect(body.get("refresh_token")).toBe("refresh123"); - }); + const metadataWithPostOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ['client_secret_post'] + }; - it("uses client_secret_post for refresh token", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); - - const tokens = await refreshAuthorization("https://auth.example.com", { - metadata: metadataWithPostOnly, - clientInformation: validClientInfo, - refreshToken: "refresh123", - }); - - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; - - // Check no Authorization header - expect(request.headers.get("Authorization")).toBeNull(); - - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - expect(body.get("refresh_token")).toBe("refresh123"); + it('uses client_secret_basic for refresh token', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await refreshAuthorization('https://auth.example.com', { + metadata: metadataWithBasicOnly, + clientInformation: validClientInfo, + refreshToken: 'refresh123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header + const authHeader = request.headers.get('Authorization'); + const expected = 'Basic ' + btoa('client123:secret123'); + expect(authHeader).toBe(expected); + + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBeNull(); // should not be in body + expect(body.get('client_secret')).toBeNull(); // should not be in body + expect(body.get('refresh_token')).toBe('refresh123'); + }); + + it('uses client_secret_post for refresh token', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens + }); + + const tokens = await refreshAuthorization('https://auth.example.com', { + metadata: metadataWithPostOnly, + clientInformation: validClientInfo, + refreshToken: 'refresh123' + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get('Authorization')).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get('client_id')).toBe('client123'); + expect(body.get('client_secret')).toBe('secret123'); + expect(body.get('refresh_token')).toBe('refresh123'); + }); }); - }); - }); diff --git a/src/client/auth.ts b/src/client/auth.ts index fcc320f17..d5d39cad4 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -1,27 +1,32 @@ -import pkceChallenge from "pkce-challenge"; -import { LATEST_PROTOCOL_VERSION } from "../types.js"; +import pkceChallenge from 'pkce-challenge'; +import { LATEST_PROTOCOL_VERSION } from '../types.js'; import { - OAuthClientMetadata, - OAuthClientInformation, - OAuthTokens, - OAuthMetadata, - OAuthClientInformationFull, - OAuthProtectedResourceMetadata, - OAuthErrorResponseSchema, - AuthorizationServerMetadata, - OpenIdProviderDiscoveryMetadataSchema -} from "../shared/auth.js"; -import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthProtectedResourceMetadataSchema, OAuthTokensSchema } from "../shared/auth.js"; -import { checkResourceAllowed, resourceUrlFromServerUrl } from "../shared/auth-utils.js"; + OAuthClientMetadata, + OAuthClientInformation, + OAuthTokens, + OAuthMetadata, + OAuthClientInformationFull, + OAuthProtectedResourceMetadata, + OAuthErrorResponseSchema, + AuthorizationServerMetadata, + OpenIdProviderDiscoveryMetadataSchema +} from '../shared/auth.js'; import { - InvalidClientError, - InvalidGrantError, - OAUTH_ERRORS, - OAuthError, - ServerError, - UnauthorizedClientError -} from "../server/auth/errors.js"; -import { FetchLike } from "../shared/transport.js"; + OAuthClientInformationFullSchema, + OAuthMetadataSchema, + OAuthProtectedResourceMetadataSchema, + OAuthTokensSchema +} from '../shared/auth.js'; +import { checkResourceAllowed, resourceUrlFromServerUrl } from '../shared/auth-utils.js'; +import { + InvalidClientError, + InvalidGrantError, + OAUTH_ERRORS, + OAuthError, + ServerError, + UnauthorizedClientError +} from '../server/auth/errors.js'; +import { FetchLike } from '../shared/transport.js'; /** * Implements an end-to-end OAuth client to be used with one MCP server. @@ -31,110 +36,115 @@ import { FetchLike } from "../shared/transport.js"; * code verifiers should not cross different sessions. */ export interface OAuthClientProvider { - /** - * The URL to redirect the user agent to after authorization. - */ - get redirectUrl(): string | URL; - - /** - * Metadata about this OAuth client. - */ - get clientMetadata(): OAuthClientMetadata; - - /** - * Returns a OAuth2 state parameter. - */ - state?(): string | Promise; - - /** - * Loads information about this OAuth client, as registered already with the - * server, or returns `undefined` if the client is not registered with the - * server. - */ - clientInformation(): OAuthClientInformation | undefined | Promise; - - /** - * If implemented, this permits the OAuth client to dynamically register with - * the server. Client information saved this way should later be read via - * `clientInformation()`. - * - * This method is not required to be implemented if client information is - * statically known (e.g., pre-registered). - */ - saveClientInformation?(clientInformation: OAuthClientInformationFull): void | Promise; - - /** - * Loads any existing OAuth tokens for the current session, or returns - * `undefined` if there are no saved tokens. - */ - tokens(): OAuthTokens | undefined | Promise; - - /** - * Stores new OAuth tokens for the current session, after a successful - * authorization. - */ - saveTokens(tokens: OAuthTokens): void | Promise; - - /** - * Invoked to redirect the user agent to the given URL to begin the authorization flow. - */ - redirectToAuthorization(authorizationUrl: URL): void | Promise; - - /** - * Saves a PKCE code verifier for the current session, before redirecting to - * the authorization flow. - */ - saveCodeVerifier(codeVerifier: string): void | Promise; - - /** - * Loads the PKCE code verifier for the current session, necessary to validate - * the authorization result. - */ - codeVerifier(): string | Promise; - - /** - * Adds custom client authentication to OAuth token requests. - * - * This optional method allows implementations to customize how client credentials - * are included in token exchange and refresh requests. When provided, this method - * is called instead of the default authentication logic, giving full control over - * the authentication mechanism. - * - * Common use cases include: - * - Supporting authentication methods beyond the standard OAuth 2.0 methods - * - Adding custom headers for proprietary authentication schemes - * - Implementing client assertion-based authentication (e.g., JWT bearer tokens) - * - * @param headers - The request headers (can be modified to add authentication) - * @param params - The request body parameters (can be modified to add credentials) - * @param url - The token endpoint URL being called - * @param metadata - Optional OAuth metadata for the server, which may include supported authentication methods - */ - addClientAuthentication?(headers: Headers, params: URLSearchParams, url: string | URL, metadata?: AuthorizationServerMetadata): void | Promise; - - /** - * If defined, overrides the selection and validation of the - * RFC 8707 Resource Indicator. If left undefined, default - * validation behavior will be used. - * - * Implementations must verify the returned resource matches the MCP server. - */ - validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; - - /** - * If implemented, provides a way for the client to invalidate (e.g. delete) the specified - * credentials, in the case where the server has indicated that they are no longer valid. - * This avoids requiring the user to intervene manually. - */ - invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise; + /** + * The URL to redirect the user agent to after authorization. + */ + get redirectUrl(): string | URL; + + /** + * Metadata about this OAuth client. + */ + get clientMetadata(): OAuthClientMetadata; + + /** + * Returns a OAuth2 state parameter. + */ + state?(): string | Promise; + + /** + * Loads information about this OAuth client, as registered already with the + * server, or returns `undefined` if the client is not registered with the + * server. + */ + clientInformation(): OAuthClientInformation | undefined | Promise; + + /** + * If implemented, this permits the OAuth client to dynamically register with + * the server. Client information saved this way should later be read via + * `clientInformation()`. + * + * This method is not required to be implemented if client information is + * statically known (e.g., pre-registered). + */ + saveClientInformation?(clientInformation: OAuthClientInformationFull): void | Promise; + + /** + * Loads any existing OAuth tokens for the current session, or returns + * `undefined` if there are no saved tokens. + */ + tokens(): OAuthTokens | undefined | Promise; + + /** + * Stores new OAuth tokens for the current session, after a successful + * authorization. + */ + saveTokens(tokens: OAuthTokens): void | Promise; + + /** + * Invoked to redirect the user agent to the given URL to begin the authorization flow. + */ + redirectToAuthorization(authorizationUrl: URL): void | Promise; + + /** + * Saves a PKCE code verifier for the current session, before redirecting to + * the authorization flow. + */ + saveCodeVerifier(codeVerifier: string): void | Promise; + + /** + * Loads the PKCE code verifier for the current session, necessary to validate + * the authorization result. + */ + codeVerifier(): string | Promise; + + /** + * Adds custom client authentication to OAuth token requests. + * + * This optional method allows implementations to customize how client credentials + * are included in token exchange and refresh requests. When provided, this method + * is called instead of the default authentication logic, giving full control over + * the authentication mechanism. + * + * Common use cases include: + * - Supporting authentication methods beyond the standard OAuth 2.0 methods + * - Adding custom headers for proprietary authentication schemes + * - Implementing client assertion-based authentication (e.g., JWT bearer tokens) + * + * @param headers - The request headers (can be modified to add authentication) + * @param params - The request body parameters (can be modified to add credentials) + * @param url - The token endpoint URL being called + * @param metadata - Optional OAuth metadata for the server, which may include supported authentication methods + */ + addClientAuthentication?( + headers: Headers, + params: URLSearchParams, + url: string | URL, + metadata?: AuthorizationServerMetadata + ): void | Promise; + + /** + * If defined, overrides the selection and validation of the + * RFC 8707 Resource Indicator. If left undefined, default + * validation behavior will be used. + * + * Implementations must verify the returned resource matches the MCP server. + */ + validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; + + /** + * If implemented, provides a way for the client to invalidate (e.g. delete) the specified + * credentials, in the case where the server has indicated that they are no longer valid. + * This avoids requiring the user to intervene manually. + */ + invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise; } -export type AuthResult = "AUTHORIZED" | "REDIRECT"; +export type AuthResult = 'AUTHORIZED' | 'REDIRECT'; export class UnauthorizedError extends Error { - constructor(message?: string) { - super(message ?? "Unauthorized"); - } + constructor(message?: string) { + super(message ?? 'Unauthorized'); + } } type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; @@ -151,32 +161,29 @@ type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; * @param supportedMethods - Authentication methods supported by the authorization server * @returns The selected authentication method */ -function selectClientAuthMethod( - clientInformation: OAuthClientInformation, - supportedMethods: string[] -): ClientAuthMethod { - const hasClientSecret = clientInformation.client_secret !== undefined; - - // If server doesn't specify supported methods, use RFC 6749 defaults - if (supportedMethods.length === 0) { - return hasClientSecret ? "client_secret_post" : "none"; - } - - // Try methods in priority order (most secure first) - if (hasClientSecret && supportedMethods.includes("client_secret_basic")) { - return "client_secret_basic"; - } - - if (hasClientSecret && supportedMethods.includes("client_secret_post")) { - return "client_secret_post"; - } - - if (supportedMethods.includes("none")) { - return "none"; - } - - // Fallback: use what we have - return hasClientSecret ? "client_secret_post" : "none"; +function selectClientAuthMethod(clientInformation: OAuthClientInformation, supportedMethods: string[]): ClientAuthMethod { + const hasClientSecret = clientInformation.client_secret !== undefined; + + // If server doesn't specify supported methods, use RFC 6749 defaults + if (supportedMethods.length === 0) { + return hasClientSecret ? 'client_secret_post' : 'none'; + } + + // Try methods in priority order (most secure first) + if (hasClientSecret && supportedMethods.includes('client_secret_basic')) { + return 'client_secret_basic'; + } + + if (hasClientSecret && supportedMethods.includes('client_secret_post')) { + return 'client_secret_post'; + } + + if (supportedMethods.includes('none')) { + return 'none'; + } + + // Fallback: use what we have + return hasClientSecret ? 'client_secret_post' : 'none'; } /** @@ -194,55 +201,55 @@ function selectClientAuthMethod( * @throws {Error} When required credentials are missing */ function applyClientAuthentication( - method: ClientAuthMethod, - clientInformation: OAuthClientInformation, - headers: Headers, - params: URLSearchParams + method: ClientAuthMethod, + clientInformation: OAuthClientInformation, + headers: Headers, + params: URLSearchParams ): void { - const { client_id, client_secret } = clientInformation; - - switch (method) { - case "client_secret_basic": - applyBasicAuth(client_id, client_secret, headers); - return; - case "client_secret_post": - applyPostAuth(client_id, client_secret, params); - return; - case "none": - applyPublicAuth(client_id, params); - return; - default: - throw new Error(`Unsupported client authentication method: ${method}`); - } + const { client_id, client_secret } = clientInformation; + + switch (method) { + case 'client_secret_basic': + applyBasicAuth(client_id, client_secret, headers); + return; + case 'client_secret_post': + applyPostAuth(client_id, client_secret, params); + return; + case 'none': + applyPublicAuth(client_id, params); + return; + default: + throw new Error(`Unsupported client authentication method: ${method}`); + } } /** * Applies HTTP Basic authentication (RFC 6749 Section 2.3.1) */ function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { - if (!clientSecret) { - throw new Error("client_secret_basic authentication requires a client_secret"); - } + if (!clientSecret) { + throw new Error('client_secret_basic authentication requires a client_secret'); + } - const credentials = btoa(`${clientId}:${clientSecret}`); - headers.set("Authorization", `Basic ${credentials}`); + const credentials = btoa(`${clientId}:${clientSecret}`); + headers.set('Authorization', `Basic ${credentials}`); } /** * Applies POST body authentication (RFC 6749 Section 2.3.1) */ function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { - params.set("client_id", clientId); - if (clientSecret) { - params.set("client_secret", clientSecret); - } + params.set('client_id', clientId); + if (clientSecret) { + params.set('client_secret', clientSecret); + } } /** * Applies public client authentication (RFC 6749 Section 2.1) */ function applyPublicAuth(clientId: string, params: URLSearchParams): void { - params.set("client_id", clientId); + params.set('client_id', clientId); } /** @@ -257,19 +264,19 @@ function applyPublicAuth(clientId: string, params: URLSearchParams): void { * @returns A Promise that resolves to an OAuthError instance */ export async function parseErrorResponse(input: Response | string): Promise { - const statusCode = input instanceof Response ? input.status : undefined; - const body = input instanceof Response ? await input.text() : input; - - try { - const result = OAuthErrorResponseSchema.parse(JSON.parse(body)); - const { error, error_description, error_uri } = result; - const errorClass = OAUTH_ERRORS[error] || ServerError; - return new errorClass(error_description || '', error_uri); - } catch (error) { - // Not a valid OAuth error response, but try to inform the user of the raw data anyway - const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`; - return new ServerError(errorMessage); - } + const statusCode = input instanceof Response ? input.status : undefined; + const body = input instanceof Response ? await input.text() : input; + + try { + const result = OAuthErrorResponseSchema.parse(JSON.parse(body)); + const { error, error_description, error_uri } = result; + const errorClass = OAUTH_ERRORS[error] || ServerError; + return new errorClass(error_description || '', error_uri); + } catch (error) { + // Not a valid OAuth error response, but try to inform the user of the raw data anyway + const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`; + return new ServerError(errorMessage); + } } /** @@ -279,203 +286,207 @@ export async function parseErrorResponse(input: Response | string): Promise { - try { - return await authInternal(provider, options); - } catch (error) { - // Handle recoverable error types by invalidating credentials and retrying - if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) { - await provider.invalidateCredentials?.('all'); - return await authInternal(provider, options); - } else if (error instanceof InvalidGrantError) { - await provider.invalidateCredentials?.('tokens'); - return await authInternal(provider, options); + provider: OAuthClientProvider, + options: { + serverUrl: string | URL; + authorizationCode?: string; + scope?: string; + resourceMetadataUrl?: URL; + fetchFn?: FetchLike; + } +): Promise { + try { + return await authInternal(provider, options); + } catch (error) { + // Handle recoverable error types by invalidating credentials and retrying + if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) { + await provider.invalidateCredentials?.('all'); + return await authInternal(provider, options); + } else if (error instanceof InvalidGrantError) { + await provider.invalidateCredentials?.('tokens'); + return await authInternal(provider, options); + } + + // Throw otherwise + throw error; } - - // Throw otherwise - throw error - } } async function authInternal( - provider: OAuthClientProvider, - { serverUrl, - authorizationCode, - scope, - resourceMetadataUrl, - fetchFn, - }: { - serverUrl: string | URL; - authorizationCode?: string; - scope?: string; - resourceMetadataUrl?: URL; - fetchFn?: FetchLike; - }, -): Promise { - - let resourceMetadata: OAuthProtectedResourceMetadata | undefined; - let authorizationServerUrl: string | URL | undefined; - try { - resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }, fetchFn); - if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) { - authorizationServerUrl = resourceMetadata.authorization_servers[0]; + provider: OAuthClientProvider, + { + serverUrl, + authorizationCode, + scope, + resourceMetadataUrl, + fetchFn + }: { + serverUrl: string | URL; + authorizationCode?: string; + scope?: string; + resourceMetadataUrl?: URL; + fetchFn?: FetchLike; } - } catch { - // Ignore errors and fall back to /.well-known/oauth-authorization-server - } - - /** - * If we don't get a valid authorization server metadata from protected resource metadata, - * fallback to the legacy MCP spec's implementation (version 2025-03-26): MCP server acts as the Authorization server. - */ - if (!authorizationServerUrl) { - authorizationServerUrl = serverUrl; - } - - const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata); - - const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, { - fetchFn, - }); - - // Handle client registration if needed - let clientInformation = await Promise.resolve(provider.clientInformation()); - if (!clientInformation) { - if (authorizationCode !== undefined) { - throw new Error("Existing OAuth client information is required when exchanging an authorization code"); +): Promise { + let resourceMetadata: OAuthProtectedResourceMetadata | undefined; + let authorizationServerUrl: string | URL | undefined; + try { + resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }, fetchFn); + if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) { + authorizationServerUrl = resourceMetadata.authorization_servers[0]; + } + } catch { + // Ignore errors and fall back to /.well-known/oauth-authorization-server } - if (!provider.saveClientInformation) { - throw new Error("OAuth client information must be saveable for dynamic registration"); + /** + * If we don't get a valid authorization server metadata from protected resource metadata, + * fallback to the legacy MCP spec's implementation (version 2025-03-26): MCP server acts as the Authorization server. + */ + if (!authorizationServerUrl) { + authorizationServerUrl = serverUrl; } - const fullInformation = await registerClient(authorizationServerUrl, { - metadata, - clientMetadata: provider.clientMetadata, - fetchFn, - }); + const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata); - await provider.saveClientInformation(fullInformation); - clientInformation = fullInformation; - } - - // Exchange authorization code for tokens - if (authorizationCode !== undefined) { - const codeVerifier = await provider.codeVerifier(); - const tokens = await exchangeAuthorization(authorizationServerUrl, { - metadata, - clientInformation, - authorizationCode, - codeVerifier, - redirectUri: provider.redirectUrl, - resource, - addClientAuthentication: provider.addClientAuthentication, - fetchFn: fetchFn, + const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, { + fetchFn }); - await provider.saveTokens(tokens); - return "AUTHORIZED" - } + // Handle client registration if needed + let clientInformation = await Promise.resolve(provider.clientInformation()); + if (!clientInformation) { + if (authorizationCode !== undefined) { + throw new Error('Existing OAuth client information is required when exchanging an authorization code'); + } + + if (!provider.saveClientInformation) { + throw new Error('OAuth client information must be saveable for dynamic registration'); + } + + const fullInformation = await registerClient(authorizationServerUrl, { + metadata, + clientMetadata: provider.clientMetadata, + fetchFn + }); + + await provider.saveClientInformation(fullInformation); + clientInformation = fullInformation; + } + + // Exchange authorization code for tokens + if (authorizationCode !== undefined) { + const codeVerifier = await provider.codeVerifier(); + const tokens = await exchangeAuthorization(authorizationServerUrl, { + metadata, + clientInformation, + authorizationCode, + codeVerifier, + redirectUri: provider.redirectUrl, + resource, + addClientAuthentication: provider.addClientAuthentication, + fetchFn: fetchFn + }); + + await provider.saveTokens(tokens); + return 'AUTHORIZED'; + } + + const tokens = await provider.tokens(); + + // Handle token refresh or new authorization + if (tokens?.refresh_token) { + try { + // Attempt to refresh the token + const newTokens = await refreshAuthorization(authorizationServerUrl, { + metadata, + clientInformation, + refreshToken: tokens.refresh_token, + resource, + addClientAuthentication: provider.addClientAuthentication, + fetchFn + }); + + await provider.saveTokens(newTokens); + return 'AUTHORIZED'; + } catch (error) { + // If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry. + if (!(error instanceof OAuthError) || error instanceof ServerError) { + // Could not refresh OAuth tokens + } else { + // Refresh failed for another reason, re-throw + throw error; + } + } + } - const tokens = await provider.tokens(); + const state = provider.state ? await provider.state() : undefined; - // Handle token refresh or new authorization - if (tokens?.refresh_token) { - try { - // Attempt to refresh the token - const newTokens = await refreshAuthorization(authorizationServerUrl, { + // Start new authorization flow + const { authorizationUrl, codeVerifier } = await startAuthorization(authorizationServerUrl, { metadata, clientInformation, - refreshToken: tokens.refresh_token, - resource, - addClientAuthentication: provider.addClientAuthentication, - fetchFn, - }); + state, + redirectUrl: provider.redirectUrl, + scope: scope || provider.clientMetadata.scope, + resource + }); - await provider.saveTokens(newTokens); - return "AUTHORIZED" - } catch (error) { - // If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry. - if (!(error instanceof OAuthError) || error instanceof ServerError) { - // Could not refresh OAuth tokens - } else { - // Refresh failed for another reason, re-throw - throw error; - } - } - } - - const state = provider.state ? await provider.state() : undefined; - - // Start new authorization flow - const { authorizationUrl, codeVerifier } = await startAuthorization(authorizationServerUrl, { - metadata, - clientInformation, - state, - redirectUrl: provider.redirectUrl, - scope: scope || provider.clientMetadata.scope, - resource, - }); - - await provider.saveCodeVerifier(codeVerifier); - await provider.redirectToAuthorization(authorizationUrl); - return "REDIRECT" + await provider.saveCodeVerifier(codeVerifier); + await provider.redirectToAuthorization(authorizationUrl); + return 'REDIRECT'; } -export async function selectResourceURL(serverUrl: string | URL, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise { - const defaultResource = resourceUrlFromServerUrl(serverUrl); +export async function selectResourceURL( + serverUrl: string | URL, + provider: OAuthClientProvider, + resourceMetadata?: OAuthProtectedResourceMetadata +): Promise { + const defaultResource = resourceUrlFromServerUrl(serverUrl); - // If provider has custom validation, delegate to it - if (provider.validateResourceURL) { - return await provider.validateResourceURL(defaultResource, resourceMetadata?.resource); - } + // If provider has custom validation, delegate to it + if (provider.validateResourceURL) { + return await provider.validateResourceURL(defaultResource, resourceMetadata?.resource); + } - // Only include resource parameter when Protected Resource Metadata is present - if (!resourceMetadata) { - return undefined; - } - - // Validate that the metadata's resource is compatible with our request - if (!checkResourceAllowed({ requestedResource: defaultResource, configuredResource: resourceMetadata.resource })) { - throw new Error(`Protected resource ${resourceMetadata.resource} does not match expected ${defaultResource} (or origin)`); - } - // Prefer the resource from metadata since it's what the server is telling us to request - return new URL(resourceMetadata.resource); + // Only include resource parameter when Protected Resource Metadata is present + if (!resourceMetadata) { + return undefined; + } + + // Validate that the metadata's resource is compatible with our request + if (!checkResourceAllowed({ requestedResource: defaultResource, configuredResource: resourceMetadata.resource })) { + throw new Error(`Protected resource ${resourceMetadata.resource} does not match expected ${defaultResource} (or origin)`); + } + // Prefer the resource from metadata since it's what the server is telling us to request + return new URL(resourceMetadata.resource); } /** * Extract resource_metadata from response header. */ export function extractResourceMetadataUrl(res: Response): URL | undefined { + const authenticateHeader = res.headers.get('WWW-Authenticate'); + if (!authenticateHeader) { + return undefined; + } - const authenticateHeader = res.headers.get("WWW-Authenticate"); - if (!authenticateHeader) { - return undefined; - } - - const [type, scheme] = authenticateHeader.split(' '); - if (type.toLowerCase() !== 'bearer' || !scheme) { - return undefined; - } - const regex = /resource_metadata="([^"]*)"/; - const match = regex.exec(authenticateHeader); + const [type, scheme] = authenticateHeader.split(' '); + if (type.toLowerCase() !== 'bearer' || !scheme) { + return undefined; + } + const regex = /resource_metadata="([^"]*)"/; + const match = regex.exec(authenticateHeader); - if (!match) { - return undefined; - } + if (!match) { + return undefined; + } - try { - return new URL(match[1]); - } catch { - return undefined; - } + try { + return new URL(match[1]); + } catch { + return undefined; + } } /** @@ -485,126 +496,109 @@ export function extractResourceMetadataUrl(res: Response): URL | undefined { * return `undefined`. Any other errors will be thrown as exceptions. */ export async function discoverOAuthProtectedResourceMetadata( - serverUrl: string | URL, - opts?: { protocolVersion?: string, resourceMetadataUrl?: string | URL }, - fetchFn: FetchLike = fetch, + serverUrl: string | URL, + opts?: { protocolVersion?: string; resourceMetadataUrl?: string | URL }, + fetchFn: FetchLike = fetch ): Promise { - const response = await discoverMetadataWithFallback( - serverUrl, - 'oauth-protected-resource', - fetchFn, - { - protocolVersion: opts?.protocolVersion, - metadataUrl: opts?.resourceMetadataUrl, - }, - ); - - if (!response || response.status === 404) { - throw new Error(`Resource server does not implement OAuth 2.0 Protected Resource Metadata.`); - } - - if (!response.ok) { - throw new Error( - `HTTP ${response.status} trying to load well-known OAuth protected resource metadata.`, - ); - } - return OAuthProtectedResourceMetadataSchema.parse(await response.json()); + const response = await discoverMetadataWithFallback(serverUrl, 'oauth-protected-resource', fetchFn, { + protocolVersion: opts?.protocolVersion, + metadataUrl: opts?.resourceMetadataUrl + }); + + if (!response || response.status === 404) { + throw new Error(`Resource server does not implement OAuth 2.0 Protected Resource Metadata.`); + } + + if (!response.ok) { + throw new Error(`HTTP ${response.status} trying to load well-known OAuth protected resource metadata.`); + } + return OAuthProtectedResourceMetadataSchema.parse(await response.json()); } /** * Helper function to handle fetch with CORS retry logic */ -async function fetchWithCorsRetry( - url: URL, - headers?: Record, - fetchFn: FetchLike = fetch, -): Promise { - try { - return await fetchFn(url, { headers }); - } catch (error) { - if (error instanceof TypeError) { - if (headers) { - // CORS errors come back as TypeError, retry without headers - return fetchWithCorsRetry(url, undefined, fetchFn) - } else { - // We're getting CORS errors on retry too, return undefined - return undefined - } +async function fetchWithCorsRetry(url: URL, headers?: Record, fetchFn: FetchLike = fetch): Promise { + try { + return await fetchFn(url, { headers }); + } catch (error) { + if (error instanceof TypeError) { + if (headers) { + // CORS errors come back as TypeError, retry without headers + return fetchWithCorsRetry(url, undefined, fetchFn); + } else { + // We're getting CORS errors on retry too, return undefined + return undefined; + } + } + throw error; } - throw error; - } } /** * Constructs the well-known path for auth-related metadata discovery */ function buildWellKnownPath( - wellKnownPrefix: 'oauth-authorization-server' | 'oauth-protected-resource' | 'openid-configuration', - pathname: string = '', - options: { prependPathname?: boolean } = {} + wellKnownPrefix: 'oauth-authorization-server' | 'oauth-protected-resource' | 'openid-configuration', + pathname: string = '', + options: { prependPathname?: boolean } = {} ): string { - // Strip trailing slash from pathname to avoid double slashes - if (pathname.endsWith('/')) { - pathname = pathname.slice(0, -1); - } - - return options.prependPathname - ? `${pathname}/.well-known/${wellKnownPrefix}` - : `/.well-known/${wellKnownPrefix}${pathname}`; + // Strip trailing slash from pathname to avoid double slashes + if (pathname.endsWith('/')) { + pathname = pathname.slice(0, -1); + } + + return options.prependPathname ? `${pathname}/.well-known/${wellKnownPrefix}` : `/.well-known/${wellKnownPrefix}${pathname}`; } /** * Tries to discover OAuth metadata at a specific URL */ -async function tryMetadataDiscovery( - url: URL, - protocolVersion: string, - fetchFn: FetchLike = fetch, -): Promise { - const headers = { - "MCP-Protocol-Version": protocolVersion - }; - return await fetchWithCorsRetry(url, headers, fetchFn); +async function tryMetadataDiscovery(url: URL, protocolVersion: string, fetchFn: FetchLike = fetch): Promise { + const headers = { + 'MCP-Protocol-Version': protocolVersion + }; + return await fetchWithCorsRetry(url, headers, fetchFn); } /** * Determines if fallback to root discovery should be attempted */ function shouldAttemptFallback(response: Response | undefined, pathname: string): boolean { - return !response || (response.status >= 400 && response.status < 500) && pathname !== '/'; + return !response || (response.status >= 400 && response.status < 500 && pathname !== '/'); } /** * Generic function for discovering OAuth metadata with fallback support */ async function discoverMetadataWithFallback( - serverUrl: string | URL, - wellKnownType: 'oauth-authorization-server' | 'oauth-protected-resource', - fetchFn: FetchLike, - opts?: { protocolVersion?: string; metadataUrl?: string | URL, metadataServerUrl?: string | URL }, + serverUrl: string | URL, + wellKnownType: 'oauth-authorization-server' | 'oauth-protected-resource', + fetchFn: FetchLike, + opts?: { protocolVersion?: string; metadataUrl?: string | URL; metadataServerUrl?: string | URL } ): Promise { - const issuer = new URL(serverUrl); - const protocolVersion = opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION; - - let url: URL; - if (opts?.metadataUrl) { - url = new URL(opts.metadataUrl); - } else { - // Try path-aware discovery first - const wellKnownPath = buildWellKnownPath(wellKnownType, issuer.pathname); - url = new URL(wellKnownPath, opts?.metadataServerUrl ?? issuer); - url.search = issuer.search; - } - - let response = await tryMetadataDiscovery(url, protocolVersion, fetchFn); - - // If path-aware discovery fails with 404 and we're not already at root, try fallback to root discovery - if (!opts?.metadataUrl && shouldAttemptFallback(response, issuer.pathname)) { - const rootUrl = new URL(`/.well-known/${wellKnownType}`, issuer); - response = await tryMetadataDiscovery(rootUrl, protocolVersion, fetchFn); - } - - return response; + const issuer = new URL(serverUrl); + const protocolVersion = opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION; + + let url: URL; + if (opts?.metadataUrl) { + url = new URL(opts.metadataUrl); + } else { + // Try path-aware discovery first + const wellKnownPath = buildWellKnownPath(wellKnownType, issuer.pathname); + url = new URL(wellKnownPath, opts?.metadataServerUrl ?? issuer); + url.search = issuer.search; + } + + let response = await tryMetadataDiscovery(url, protocolVersion, fetchFn); + + // If path-aware discovery fails with 404 and we're not already at root, try fallback to root discovery + if (!opts?.metadataUrl && shouldAttemptFallback(response, issuer.pathname)) { + const rootUrl = new URL(`/.well-known/${wellKnownType}`, issuer); + response = await tryMetadataDiscovery(rootUrl, protocolVersion, fetchFn); + } + + return response; } /** @@ -616,50 +610,42 @@ async function discoverMetadataWithFallback( * @deprecated This function is deprecated in favor of `discoverAuthorizationServerMetadata`. */ export async function discoverOAuthMetadata( - issuer: string | URL, - { - authorizationServerUrl, - protocolVersion, - }: { - authorizationServerUrl?: string | URL, - protocolVersion?: string, - } = {}, - fetchFn: FetchLike = fetch, -): Promise { - if (typeof issuer === 'string') { - issuer = new URL(issuer); - } - if (!authorizationServerUrl) { - authorizationServerUrl = issuer; - } - if (typeof authorizationServerUrl === 'string') { - authorizationServerUrl = new URL(authorizationServerUrl); - } - protocolVersion ??= LATEST_PROTOCOL_VERSION ; - - const response = await discoverMetadataWithFallback( - authorizationServerUrl, - 'oauth-authorization-server', - fetchFn, + issuer: string | URL, { - protocolVersion, - metadataServerUrl: authorizationServerUrl, - }, - ); + authorizationServerUrl, + protocolVersion + }: { + authorizationServerUrl?: string | URL; + protocolVersion?: string; + } = {}, + fetchFn: FetchLike = fetch +): Promise { + if (typeof issuer === 'string') { + issuer = new URL(issuer); + } + if (!authorizationServerUrl) { + authorizationServerUrl = issuer; + } + if (typeof authorizationServerUrl === 'string') { + authorizationServerUrl = new URL(authorizationServerUrl); + } + protocolVersion ??= LATEST_PROTOCOL_VERSION; - if (!response || response.status === 404) { - return undefined; - } + const response = await discoverMetadataWithFallback(authorizationServerUrl, 'oauth-authorization-server', fetchFn, { + protocolVersion, + metadataServerUrl: authorizationServerUrl + }); - if (!response.ok) { - throw new Error( - `HTTP ${response.status} trying to load well-known OAuth metadata`, - ); - } + if (!response || response.status === 404) { + return undefined; + } - return OAuthMetadataSchema.parse(await response.json()); -} + if (!response.ok) { + throw new Error(`HTTP ${response.status} trying to load well-known OAuth metadata`); + } + return OAuthMetadataSchema.parse(await response.json()); +} /** * Builds a list of discovery URLs to try for authorization server metadata. @@ -669,59 +655,58 @@ export async function discoverOAuthMetadata( * 3. OIDC metadata endpoints */ export function buildDiscoveryUrls(authorizationServerUrl: string | URL): { url: URL; type: 'oauth' | 'oidc' }[] { - const url = typeof authorizationServerUrl === 'string' ? new URL(authorizationServerUrl) : authorizationServerUrl; - const hasPath = url.pathname !== '/'; - const urlsToTry: { url: URL; type: 'oauth' | 'oidc' }[] = []; + const url = typeof authorizationServerUrl === 'string' ? new URL(authorizationServerUrl) : authorizationServerUrl; + const hasPath = url.pathname !== '/'; + const urlsToTry: { url: URL; type: 'oauth' | 'oidc' }[] = []; + + if (!hasPath) { + // Root path: https://example.com/.well-known/oauth-authorization-server + urlsToTry.push({ + url: new URL('/.well-known/oauth-authorization-server', url.origin), + type: 'oauth' + }); + + // OIDC: https://example.com/.well-known/openid-configuration + urlsToTry.push({ + url: new URL(`/.well-known/openid-configuration`, url.origin), + type: 'oidc' + }); + + return urlsToTry; + } + + // Strip trailing slash from pathname to avoid double slashes + let pathname = url.pathname; + if (pathname.endsWith('/')) { + pathname = pathname.slice(0, -1); + } + // 1. OAuth metadata at the given URL + // Insert well-known before the path: https://example.com/.well-known/oauth-authorization-server/tenant1 + urlsToTry.push({ + url: new URL(`/.well-known/oauth-authorization-server${pathname}`, url.origin), + type: 'oauth' + }); - if (!hasPath) { // Root path: https://example.com/.well-known/oauth-authorization-server urlsToTry.push({ - url: new URL('/.well-known/oauth-authorization-server', url.origin), - type: 'oauth' + url: new URL('/.well-known/oauth-authorization-server', url.origin), + type: 'oauth' }); - // OIDC: https://example.com/.well-known/openid-configuration + // 3. OIDC metadata endpoints + // RFC 8414 style: Insert /.well-known/openid-configuration before the path urlsToTry.push({ - url: new URL(`/.well-known/openid-configuration`, url.origin), - type: 'oidc' + url: new URL(`/.well-known/openid-configuration${pathname}`, url.origin), + type: 'oidc' + }); + // OIDC Discovery 1.0 style: Append /.well-known/openid-configuration after the path + urlsToTry.push({ + url: new URL(`${pathname}/.well-known/openid-configuration`, url.origin), + type: 'oidc' }); return urlsToTry; - } - - // Strip trailing slash from pathname to avoid double slashes - let pathname = url.pathname; - if (pathname.endsWith('/')) { - pathname = pathname.slice(0, -1); - } - - // 1. OAuth metadata at the given URL - // Insert well-known before the path: https://example.com/.well-known/oauth-authorization-server/tenant1 - urlsToTry.push({ - url: new URL(`/.well-known/oauth-authorization-server${pathname}`, url.origin), - type: 'oauth' - }); - - // Root path: https://example.com/.well-known/oauth-authorization-server - urlsToTry.push({ - url: new URL('/.well-known/oauth-authorization-server', url.origin), - type: 'oauth' - }); - - // 3. OIDC metadata endpoints - // RFC 8414 style: Insert /.well-known/openid-configuration before the path - urlsToTry.push({ - url: new URL(`/.well-known/openid-configuration${pathname}`, url.origin), - type: 'oidc' - }); - // OIDC Discovery 1.0 style: Append /.well-known/openid-configuration after the path - urlsToTry.push({ - url: new URL(`${pathname}/.well-known/openid-configuration`, url.origin), - type: 'oidc' - }); - - return urlsToTry; } /** @@ -741,140 +726,132 @@ export function buildDiscoveryUrls(authorizationServerUrl: string | URL): { url: * @returns Promise resolving to authorization server metadata, or undefined if discovery fails */ export async function discoverAuthorizationServerMetadata( - authorizationServerUrl: string | URL, - { - fetchFn = fetch, - protocolVersion = LATEST_PROTOCOL_VERSION, - }: { - fetchFn?: FetchLike; - protocolVersion?: string; - } = {} + authorizationServerUrl: string | URL, + { + fetchFn = fetch, + protocolVersion = LATEST_PROTOCOL_VERSION + }: { + fetchFn?: FetchLike; + protocolVersion?: string; + } = {} ): Promise { - const headers = { 'MCP-Protocol-Version': protocolVersion }; - - // Get the list of URLs to try - const urlsToTry = buildDiscoveryUrls(authorizationServerUrl); - - // Try each URL in order - for (const { url: endpointUrl, type } of urlsToTry) { - const response = await fetchWithCorsRetry(endpointUrl, headers, fetchFn); - - if (!response) { - /** - * CORS error occurred - don't throw as the endpoint may not allow CORS, - * continue trying other possible endpoints - */ - continue; + const headers = { 'MCP-Protocol-Version': protocolVersion }; + + // Get the list of URLs to try + const urlsToTry = buildDiscoveryUrls(authorizationServerUrl); + + // Try each URL in order + for (const { url: endpointUrl, type } of urlsToTry) { + const response = await fetchWithCorsRetry(endpointUrl, headers, fetchFn); + + if (!response) { + /** + * CORS error occurred - don't throw as the endpoint may not allow CORS, + * continue trying other possible endpoints + */ + continue; + } + + if (!response.ok) { + // Continue looking for any 4xx response code. + if (response.status >= 400 && response.status < 500) { + continue; // Try next URL + } + throw new Error( + `HTTP ${response.status} trying to load ${type === 'oauth' ? 'OAuth' : 'OpenID provider'} metadata from ${endpointUrl}` + ); + } + + // Parse and validate based on type + if (type === 'oauth') { + return OAuthMetadataSchema.parse(await response.json()); + } else { + const metadata = OpenIdProviderDiscoveryMetadataSchema.parse(await response.json()); + + // MCP spec requires OIDC providers to support S256 PKCE + if (!metadata.code_challenge_methods_supported?.includes('S256')) { + throw new Error( + `Incompatible OIDC provider at ${endpointUrl}: does not support S256 code challenge method required by MCP specification` + ); + } + + return metadata; + } } - if (!response.ok) { - // Continue looking for any 4xx response code. - if (response.status >= 400 && response.status < 500) { - continue; // Try next URL - } - throw new Error(`HTTP ${response.status} trying to load ${type === 'oauth' ? 'OAuth' : 'OpenID provider'} metadata from ${endpointUrl}`); - } - - // Parse and validate based on type - if (type === 'oauth') { - return OAuthMetadataSchema.parse(await response.json()); - } else { - const metadata = OpenIdProviderDiscoveryMetadataSchema.parse(await response.json()); - - // MCP spec requires OIDC providers to support S256 PKCE - if (!metadata.code_challenge_methods_supported?.includes('S256')) { - throw new Error( - `Incompatible OIDC provider at ${endpointUrl}: does not support S256 code challenge method required by MCP specification` - ); - } - - return metadata; - } - } - - return undefined; + return undefined; } /** * Begins the authorization flow with the given server, by generating a PKCE challenge and constructing the authorization URL. */ export async function startAuthorization( - authorizationServerUrl: string | URL, - { - metadata, - clientInformation, - redirectUrl, - scope, - state, - resource, - }: { - metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; - redirectUrl: string | URL; - scope?: string; - state?: string; - resource?: URL; - }, + authorizationServerUrl: string | URL, + { + metadata, + clientInformation, + redirectUrl, + scope, + state, + resource + }: { + metadata?: AuthorizationServerMetadata; + clientInformation: OAuthClientInformation; + redirectUrl: string | URL; + scope?: string; + state?: string; + resource?: URL; + } ): Promise<{ authorizationUrl: URL; codeVerifier: string }> { - const responseType = "code"; - const codeChallengeMethod = "S256"; + const responseType = 'code'; + const codeChallengeMethod = 'S256'; + + let authorizationUrl: URL; + if (metadata) { + authorizationUrl = new URL(metadata.authorization_endpoint); + + if (!metadata.response_types_supported.includes(responseType)) { + throw new Error(`Incompatible auth server: does not support response type ${responseType}`); + } + + if (!metadata.code_challenge_methods_supported || !metadata.code_challenge_methods_supported.includes(codeChallengeMethod)) { + throw new Error(`Incompatible auth server: does not support code challenge method ${codeChallengeMethod}`); + } + } else { + authorizationUrl = new URL('/authorize', authorizationServerUrl); + } + + // Generate PKCE challenge + const challenge = await pkceChallenge(); + const codeVerifier = challenge.code_verifier; + const codeChallenge = challenge.code_challenge; + + authorizationUrl.searchParams.set('response_type', responseType); + authorizationUrl.searchParams.set('client_id', clientInformation.client_id); + authorizationUrl.searchParams.set('code_challenge', codeChallenge); + authorizationUrl.searchParams.set('code_challenge_method', codeChallengeMethod); + authorizationUrl.searchParams.set('redirect_uri', String(redirectUrl)); + + if (state) { + authorizationUrl.searchParams.set('state', state); + } - let authorizationUrl: URL; - if (metadata) { - authorizationUrl = new URL(metadata.authorization_endpoint); + if (scope) { + authorizationUrl.searchParams.set('scope', scope); + } - if (!metadata.response_types_supported.includes(responseType)) { - throw new Error( - `Incompatible auth server: does not support response type ${responseType}`, - ); + if (scope?.includes('offline_access')) { + // if the request includes the OIDC-only "offline_access" scope, + // we need to set the prompt to "consent" to ensure the user is prompted to grant offline access + // https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess + authorizationUrl.searchParams.append('prompt', 'consent'); } - if ( - !metadata.code_challenge_methods_supported || - !metadata.code_challenge_methods_supported.includes(codeChallengeMethod) - ) { - throw new Error( - `Incompatible auth server: does not support code challenge method ${codeChallengeMethod}`, - ); + if (resource) { + authorizationUrl.searchParams.set('resource', resource.href); } - } else { - authorizationUrl = new URL("/authorize", authorizationServerUrl); - } - - // Generate PKCE challenge - const challenge = await pkceChallenge(); - const codeVerifier = challenge.code_verifier; - const codeChallenge = challenge.code_challenge; - - authorizationUrl.searchParams.set("response_type", responseType); - authorizationUrl.searchParams.set("client_id", clientInformation.client_id); - authorizationUrl.searchParams.set("code_challenge", codeChallenge); - authorizationUrl.searchParams.set( - "code_challenge_method", - codeChallengeMethod, - ); - authorizationUrl.searchParams.set("redirect_uri", String(redirectUrl)); - - if (state) { - authorizationUrl.searchParams.set("state", state); - } - - if (scope) { - authorizationUrl.searchParams.set("scope", scope); - } - - if (scope?.includes("offline_access")) { - // if the request includes the OIDC-only "offline_access" scope, - // we need to set the prompt to "consent" to ensure the user is prompted to grant offline access - // https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess - authorizationUrl.searchParams.append("prompt", "consent"); - } - - if (resource) { - authorizationUrl.searchParams.set("resource", resource.href); - } - - return { authorizationUrl, codeVerifier }; + + return { authorizationUrl, codeVerifier }; } /** @@ -890,79 +867,72 @@ export async function startAuthorization( * @throws {Error} When token exchange fails or authentication is invalid */ export async function exchangeAuthorization( - authorizationServerUrl: string | URL, - { - metadata, - clientInformation, - authorizationCode, - codeVerifier, - redirectUri, - resource, - addClientAuthentication, - fetchFn, - }: { - metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; - authorizationCode: string; - codeVerifier: string; - redirectUri: string | URL; - resource?: URL; - addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; - fetchFn?: FetchLike; - }, + authorizationServerUrl: string | URL, + { + metadata, + clientInformation, + authorizationCode, + codeVerifier, + redirectUri, + resource, + addClientAuthentication, + fetchFn + }: { + metadata?: AuthorizationServerMetadata; + clientInformation: OAuthClientInformation; + authorizationCode: string; + codeVerifier: string; + redirectUri: string | URL; + resource?: URL; + addClientAuthentication?: OAuthClientProvider['addClientAuthentication']; + fetchFn?: FetchLike; + } ): Promise { - const grantType = "authorization_code"; - - const tokenUrl = metadata?.token_endpoint - ? new URL(metadata.token_endpoint) - : new URL("/token", authorizationServerUrl); - - if ( - metadata?.grant_types_supported && - !metadata.grant_types_supported.includes(grantType) - ) { - throw new Error( - `Incompatible auth server: does not support grant type ${grantType}`, - ); - } - - // Exchange code for tokens - const headers = new Headers({ - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - }); - const params = new URLSearchParams({ - grant_type: grantType, - code: authorizationCode, - code_verifier: codeVerifier, - redirect_uri: String(redirectUri), - }); - - if (addClientAuthentication) { - addClientAuthentication(headers, params, authorizationServerUrl, metadata); - } else { - // Determine and apply client authentication method - const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; - const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); - - applyClientAuthentication(authMethod, clientInformation, headers, params); - } - - if (resource) { - params.set("resource", resource.href); - } - - const response = await (fetchFn ?? fetch)(tokenUrl, { - method: "POST", - headers, - body: params, - }); - - if (!response.ok) { - throw await parseErrorResponse(response); - } - - return OAuthTokensSchema.parse(await response.json()); + const grantType = 'authorization_code'; + + const tokenUrl = metadata?.token_endpoint ? new URL(metadata.token_endpoint) : new URL('/token', authorizationServerUrl); + + if (metadata?.grant_types_supported && !metadata.grant_types_supported.includes(grantType)) { + throw new Error(`Incompatible auth server: does not support grant type ${grantType}`); + } + + // Exchange code for tokens + const headers = new Headers({ + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json' + }); + const params = new URLSearchParams({ + grant_type: grantType, + code: authorizationCode, + code_verifier: codeVerifier, + redirect_uri: String(redirectUri) + }); + + if (addClientAuthentication) { + addClientAuthentication(headers, params, authorizationServerUrl, metadata); + } else { + // Determine and apply client authentication method + const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; + const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); + + applyClientAuthentication(authMethod, clientInformation, headers, params); + } + + if (resource) { + params.set('resource', resource.href); + } + + const response = await (fetchFn ?? fetch)(tokenUrl, { + method: 'POST', + headers, + body: params + }); + + if (!response.ok) { + throw await parseErrorResponse(response); + } + + return OAuthTokensSchema.parse(await response.json()); } /** @@ -978,114 +948,109 @@ export async function exchangeAuthorization( * @throws {Error} When token refresh fails or authentication is invalid */ export async function refreshAuthorization( - authorizationServerUrl: string | URL, - { - metadata, - clientInformation, - refreshToken, - resource, - addClientAuthentication, - fetchFn, - }: { - metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; - refreshToken: string; - resource?: URL; - addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; - fetchFn?: FetchLike; - } + authorizationServerUrl: string | URL, + { + metadata, + clientInformation, + refreshToken, + resource, + addClientAuthentication, + fetchFn + }: { + metadata?: AuthorizationServerMetadata; + clientInformation: OAuthClientInformation; + refreshToken: string; + resource?: URL; + addClientAuthentication?: OAuthClientProvider['addClientAuthentication']; + fetchFn?: FetchLike; + } ): Promise { - const grantType = "refresh_token"; - - let tokenUrl: URL; - if (metadata) { - tokenUrl = new URL(metadata.token_endpoint); - - if ( - metadata.grant_types_supported && - !metadata.grant_types_supported.includes(grantType) - ) { - throw new Error( - `Incompatible auth server: does not support grant type ${grantType}`, - ); + const grantType = 'refresh_token'; + + let tokenUrl: URL; + if (metadata) { + tokenUrl = new URL(metadata.token_endpoint); + + if (metadata.grant_types_supported && !metadata.grant_types_supported.includes(grantType)) { + throw new Error(`Incompatible auth server: does not support grant type ${grantType}`); + } + } else { + tokenUrl = new URL('/token', authorizationServerUrl); } - } else { - tokenUrl = new URL("/token", authorizationServerUrl); - } - - // Exchange refresh token - const headers = new Headers({ - "Content-Type": "application/x-www-form-urlencoded", - }); - const params = new URLSearchParams({ - grant_type: grantType, - refresh_token: refreshToken, - }); - - if (addClientAuthentication) { - addClientAuthentication(headers, params, authorizationServerUrl, metadata); - } else { - // Determine and apply client authentication method - const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; - const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); - - applyClientAuthentication(authMethod, clientInformation, headers, params); - } - - if (resource) { - params.set("resource", resource.href); - } - - const response = await (fetchFn ?? fetch)(tokenUrl, { - method: "POST", - headers, - body: params, - }); - if (!response.ok) { - throw await parseErrorResponse(response); - } - - return OAuthTokensSchema.parse({ refresh_token: refreshToken, ...(await response.json()) }); + + // Exchange refresh token + const headers = new Headers({ + 'Content-Type': 'application/x-www-form-urlencoded' + }); + const params = new URLSearchParams({ + grant_type: grantType, + refresh_token: refreshToken + }); + + if (addClientAuthentication) { + addClientAuthentication(headers, params, authorizationServerUrl, metadata); + } else { + // Determine and apply client authentication method + const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; + const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); + + applyClientAuthentication(authMethod, clientInformation, headers, params); + } + + if (resource) { + params.set('resource', resource.href); + } + + const response = await (fetchFn ?? fetch)(tokenUrl, { + method: 'POST', + headers, + body: params + }); + if (!response.ok) { + throw await parseErrorResponse(response); + } + + return OAuthTokensSchema.parse({ refresh_token: refreshToken, ...(await response.json()) }); } /** * Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591. */ export async function registerClient( - authorizationServerUrl: string | URL, - { - metadata, - clientMetadata, - fetchFn, - }: { - metadata?: AuthorizationServerMetadata; - clientMetadata: OAuthClientMetadata; - fetchFn?: FetchLike; - }, + authorizationServerUrl: string | URL, + { + metadata, + clientMetadata, + fetchFn + }: { + metadata?: AuthorizationServerMetadata; + clientMetadata: OAuthClientMetadata; + fetchFn?: FetchLike; + } ): Promise { - let registrationUrl: URL; + let registrationUrl: URL; - if (metadata) { - if (!metadata.registration_endpoint) { - throw new Error("Incompatible auth server: does not support dynamic client registration"); - } + if (metadata) { + if (!metadata.registration_endpoint) { + throw new Error('Incompatible auth server: does not support dynamic client registration'); + } - registrationUrl = new URL(metadata.registration_endpoint); - } else { - registrationUrl = new URL("/register", authorizationServerUrl); - } + registrationUrl = new URL(metadata.registration_endpoint); + } else { + registrationUrl = new URL('/register', authorizationServerUrl); + } - const response = await (fetchFn ?? fetch)(registrationUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(clientMetadata), - }); + const response = await (fetchFn ?? fetch)(registrationUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(clientMetadata) + }); - if (!response.ok) { - throw await parseErrorResponse(response); - } + if (!response.ok) { + throw await parseErrorResponse(response); + } - return OAuthClientInformationFullSchema.parse(await response.json()); + return OAuthClientInformationFullSchema.parse(await response.json()); } diff --git a/src/client/cross-spawn.test.ts b/src/client/cross-spawn.test.ts index 8480d94f7..ca2a5005c 100644 --- a/src/client/cross-spawn.test.ts +++ b/src/client/cross-spawn.test.ts @@ -1,152 +1,152 @@ -import { StdioClientTransport, getDefaultEnvironment } from "./stdio.js"; -import spawn from "cross-spawn"; -import { JSONRPCMessage } from "../types.js"; -import { ChildProcess } from "node:child_process"; +import { StdioClientTransport, getDefaultEnvironment } from './stdio.js'; +import spawn from 'cross-spawn'; +import { JSONRPCMessage } from '../types.js'; +import { ChildProcess } from 'node:child_process'; // mock cross-spawn -jest.mock("cross-spawn"); +jest.mock('cross-spawn'); const mockSpawn = spawn as jest.MockedFunction; -describe("StdioClientTransport using cross-spawn", () => { - beforeEach(() => { - // mock cross-spawn's return value - mockSpawn.mockImplementation(() => { - const mockProcess: { - on: jest.Mock; - stdin?: { on: jest.Mock; write: jest.Mock }; - stdout?: { on: jest.Mock }; - stderr?: null; - } = { - on: jest.fn((event: string, callback: () => void) => { - if (event === "spawn") { - callback(); - } - return mockProcess; - }), - stdin: { - on: jest.fn(), - write: jest.fn().mockReturnValue(true) - }, - stdout: { - on: jest.fn() - }, - stderr: null - }; - return mockProcess as unknown as ChildProcess; +describe('StdioClientTransport using cross-spawn', () => { + beforeEach(() => { + // mock cross-spawn's return value + mockSpawn.mockImplementation(() => { + const mockProcess: { + on: jest.Mock; + stdin?: { on: jest.Mock; write: jest.Mock }; + stdout?: { on: jest.Mock }; + stderr?: null; + } = { + on: jest.fn((event: string, callback: () => void) => { + if (event === 'spawn') { + callback(); + } + return mockProcess; + }), + stdin: { + on: jest.fn(), + write: jest.fn().mockReturnValue(true) + }, + stdout: { + on: jest.fn() + }, + stderr: null + }; + return mockProcess as unknown as ChildProcess; + }); }); - }); - afterEach(() => { - jest.clearAllMocks(); - }); - - test("should call cross-spawn correctly", async () => { - const transport = new StdioClientTransport({ - command: "test-command", - args: ["arg1", "arg2"] + afterEach(() => { + jest.clearAllMocks(); }); - await transport.start(); - - // verify spawn is called correctly - expect(mockSpawn).toHaveBeenCalledWith( - "test-command", - ["arg1", "arg2"], - expect.objectContaining({ - shell: false - }) - ); - }); - - test("should pass environment variables correctly", async () => { - const customEnv = { TEST_VAR: "test-value" }; - const transport = new StdioClientTransport({ - command: "test-command", - env: customEnv + test('should call cross-spawn correctly', async () => { + const transport = new StdioClientTransport({ + command: 'test-command', + args: ['arg1', 'arg2'] + }); + + await transport.start(); + + // verify spawn is called correctly + expect(mockSpawn).toHaveBeenCalledWith( + 'test-command', + ['arg1', 'arg2'], + expect.objectContaining({ + shell: false + }) + ); }); - await transport.start(); - - // verify environment variables are merged correctly - expect(mockSpawn).toHaveBeenCalledWith( - "test-command", - [], - expect.objectContaining({ - env: { - ...getDefaultEnvironment(), - ...customEnv - } - }) - ); - }); - - test("should use default environment when env is undefined", async () => { - const transport = new StdioClientTransport({ - command: "test-command", - env: undefined + test('should pass environment variables correctly', async () => { + const customEnv = { TEST_VAR: 'test-value' }; + const transport = new StdioClientTransport({ + command: 'test-command', + env: customEnv + }); + + await transport.start(); + + // verify environment variables are merged correctly + expect(mockSpawn).toHaveBeenCalledWith( + 'test-command', + [], + expect.objectContaining({ + env: { + ...getDefaultEnvironment(), + ...customEnv + } + }) + ); }); - await transport.start(); - - // verify default environment is used - expect(mockSpawn).toHaveBeenCalledWith( - "test-command", - [], - expect.objectContaining({ - env: getDefaultEnvironment() - }) - ); - }); - - test("should send messages correctly", async () => { - const transport = new StdioClientTransport({ - command: "test-command" + test('should use default environment when env is undefined', async () => { + const transport = new StdioClientTransport({ + command: 'test-command', + env: undefined + }); + + await transport.start(); + + // verify default environment is used + expect(mockSpawn).toHaveBeenCalledWith( + 'test-command', + [], + expect.objectContaining({ + env: getDefaultEnvironment() + }) + ); }); - // get the mock process object - const mockProcess: { - on: jest.Mock; - stdin: { - on: jest.Mock; - write: jest.Mock; - once: jest.Mock; - }; - stdout: { - on: jest.Mock; - }; - stderr: null; - } = { - on: jest.fn((event: string, callback: () => void) => { - if (event === "spawn") { - callback(); - } - return mockProcess; - }), - stdin: { - on: jest.fn(), - write: jest.fn().mockReturnValue(true), - once: jest.fn() - }, - stdout: { - on: jest.fn() - }, - stderr: null - }; - - mockSpawn.mockReturnValue(mockProcess as unknown as ChildProcess); - - await transport.start(); - - // 关键修复:确保 jsonrpc 是字面量 "2.0" - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "test-id", - method: "test-method" - }; - - await transport.send(message); - - // verify message is sent correctly - expect(mockProcess.stdin.write).toHaveBeenCalled(); - }); -}); \ No newline at end of file + test('should send messages correctly', async () => { + const transport = new StdioClientTransport({ + command: 'test-command' + }); + + // get the mock process object + const mockProcess: { + on: jest.Mock; + stdin: { + on: jest.Mock; + write: jest.Mock; + once: jest.Mock; + }; + stdout: { + on: jest.Mock; + }; + stderr: null; + } = { + on: jest.fn((event: string, callback: () => void) => { + if (event === 'spawn') { + callback(); + } + return mockProcess; + }), + stdin: { + on: jest.fn(), + write: jest.fn().mockReturnValue(true), + once: jest.fn() + }, + stdout: { + on: jest.fn() + }, + stderr: null + }; + + mockSpawn.mockReturnValue(mockProcess as unknown as ChildProcess); + + await transport.start(); + + // 关键修复:确保 jsonrpc 是字面量 "2.0" + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'test-id', + method: 'test-method' + }; + + await transport.send(message); + + // verify message is sent correctly + expect(mockProcess.stdin.write).toHaveBeenCalled(); + }); +}); diff --git a/src/client/index.test.ts b/src/client/index.test.ts index abd0c34e4..de37b2d90 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -1,1303 +1,1241 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable no-constant-binary-expression */ /* eslint-disable @typescript-eslint/no-unused-expressions */ -import { Client } from "./index.js"; -import { z } from "zod"; +import { Client } from './index.js'; +import { z } from 'zod'; import { - RequestSchema, - NotificationSchema, - ResultSchema, - LATEST_PROTOCOL_VERSION, - SUPPORTED_PROTOCOL_VERSIONS, - InitializeRequestSchema, - ListResourcesRequestSchema, - ListToolsRequestSchema, - CallToolRequestSchema, - CreateMessageRequestSchema, - ElicitRequestSchema, - ListRootsRequestSchema, - ErrorCode, -} from "../types.js"; -import { Transport } from "../shared/transport.js"; -import { Server } from "../server/index.js"; -import { InMemoryTransport } from "../inMemory.js"; + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, + InitializeRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + CallToolRequestSchema, + CreateMessageRequestSchema, + ElicitRequestSchema, + ListRootsRequestSchema, + ErrorCode +} from '../types.js'; +import { Transport } from '../shared/transport.js'; +import { Server } from '../server/index.js'; +import { InMemoryTransport } from '../inMemory.js'; /*** * Test: Initialize with Matching Protocol Version */ -test("should initialize with matching protocol version", async () => { - const clientTransport: Transport = { - start: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - send: jest.fn().mockImplementation((message) => { - if (message.method === "initialize") { - clientTransport.onmessage?.({ - jsonrpc: "2.0", - id: message.id, - result: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: {}, - serverInfo: { - name: "test", - version: "1.0", - }, - instructions: "test instructions", - }, - }); - } - return Promise.resolve(); - }), - }; - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - await client.connect(clientTransport); - - // Should have sent initialize with latest version - expect(clientTransport.send).toHaveBeenCalledWith( - expect.objectContaining({ - method: "initialize", - params: expect.objectContaining({ - protocolVersion: LATEST_PROTOCOL_VERSION, - }), - }), - expect.objectContaining({ - relatedRequestId: undefined, - }), - ); - - // Should have the instructions returned - expect(client.getInstructions()).toEqual("test instructions"); +test('should initialize with matching protocol version', async () => { + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation(message => { + if (message.method === 'initialize') { + clientTransport.onmessage?.({ + jsonrpc: '2.0', + id: message.id, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: 'test', + version: '1.0' + }, + instructions: 'test instructions' + } + }); + } + return Promise.resolve(); + }) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + await client.connect(clientTransport); + + // Should have sent initialize with latest version + expect(clientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'initialize', + params: expect.objectContaining({ + protocolVersion: LATEST_PROTOCOL_VERSION + }) + }), + expect.objectContaining({ + relatedRequestId: undefined + }) + ); + + // Should have the instructions returned + expect(client.getInstructions()).toEqual('test instructions'); }); /*** * Test: Initialize with Supported Older Protocol Version */ -test("should initialize with supported older protocol version", async () => { - const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; - const clientTransport: Transport = { - start: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - send: jest.fn().mockImplementation((message) => { - if (message.method === "initialize") { - clientTransport.onmessage?.({ - jsonrpc: "2.0", - id: message.id, - result: { - protocolVersion: OLD_VERSION, - capabilities: {}, - serverInfo: { - name: "test", - version: "1.0", - }, - }, - }); - } - return Promise.resolve(); - }), - }; - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - await client.connect(clientTransport); - - // Connection should succeed with the older version - expect(client.getServerVersion()).toEqual({ - name: "test", - version: "1.0", - }); - - // Expect no instructions - expect(client.getInstructions()).toBeUndefined(); +test('should initialize with supported older protocol version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation(message => { + if (message.method === 'initialize') { + clientTransport.onmessage?.({ + jsonrpc: '2.0', + id: message.id, + result: { + protocolVersion: OLD_VERSION, + capabilities: {}, + serverInfo: { + name: 'test', + version: '1.0' + } + } + }); + } + return Promise.resolve(); + }) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + await client.connect(clientTransport); + + // Connection should succeed with the older version + expect(client.getServerVersion()).toEqual({ + name: 'test', + version: '1.0' + }); + + // Expect no instructions + expect(client.getInstructions()).toBeUndefined(); }); /*** * Test: Reject Unsupported Protocol Version */ -test("should reject unsupported protocol version", async () => { - const clientTransport: Transport = { - start: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - send: jest.fn().mockImplementation((message) => { - if (message.method === "initialize") { - clientTransport.onmessage?.({ - jsonrpc: "2.0", - id: message.id, - result: { - protocolVersion: "invalid-version", - capabilities: {}, - serverInfo: { - name: "test", - version: "1.0", - }, - }, - }); - } - return Promise.resolve(); - }), - }; - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - await expect(client.connect(clientTransport)).rejects.toThrow( - "Server's protocol version is not supported: invalid-version", - ); - - expect(clientTransport.close).toHaveBeenCalled(); +test('should reject unsupported protocol version', async () => { + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation(message => { + if (message.method === 'initialize') { + clientTransport.onmessage?.({ + jsonrpc: '2.0', + id: message.id, + result: { + protocolVersion: 'invalid-version', + capabilities: {}, + serverInfo: { + name: 'test', + version: '1.0' + } + } + }); + } + return Promise.resolve(); + }) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + await expect(client.connect(clientTransport)).rejects.toThrow("Server's protocol version is not supported: invalid-version"); + + expect(clientTransport.close).toHaveBeenCalled(); }); /*** * Test: Connect New Client to Old Supported Server Version */ -test("should connect new client to old, supported server version", async () => { - const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - resources: {}, - tools: {}, +test('should connect new client to old, supported server version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const server = new Server( + { + name: 'test server', + version: '1.0' }, - }, - ); - - server.setRequestHandler(InitializeRequestSchema, (_request) => ({ - protocolVersion: OLD_VERSION, - capabilities: { - resources: {}, - tools: {}, - }, - serverInfo: { - name: "old server", - version: "1.0", - }, - })); - - server.setRequestHandler(ListResourcesRequestSchema, () => ({ - resources: [], - })); - - server.setRequestHandler(ListToolsRequestSchema, () => ({ - tools: [], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: "new client", - version: "1.0", - protocolVersion: LATEST_PROTOCOL_VERSION, - }, - { + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: OLD_VERSION, capabilities: { - sampling: {}, + resources: {}, + tools: {} + }, + serverInfo: { + name: 'old server', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'new client', + version: '1.0', + protocolVersion: LATEST_PROTOCOL_VERSION }, - enforceStrictCapabilities: true, - }, - ); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - expect(client.getServerVersion()).toEqual({ - name: "old server", - version: "1.0", - }); + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(client.getServerVersion()).toEqual({ + name: 'old server', + version: '1.0' + }); }); /*** * Test: Version Negotiation with Old Client and Newer Server */ -test("should negotiate version when client is old, and newer server supports its version", async () => { - const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; - const server = new Server( - { - name: "new server", - version: "1.0", - }, - { - capabilities: { - resources: {}, - tools: {}, +test('should negotiate version when client is old, and newer server supports its version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const server = new Server( + { + name: 'new server', + version: '1.0' }, - }, - ); - - server.setRequestHandler(InitializeRequestSchema, (_request) => ({ - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: { - resources: {}, - tools: {}, - }, - serverInfo: { - name: "new server", - version: "1.0", - }, - })); - - server.setRequestHandler(ListResourcesRequestSchema, () => ({ - resources: [], - })); - - server.setRequestHandler(ListToolsRequestSchema, () => ({ - tools: [], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: "old client", - version: "1.0", - protocolVersion: OLD_VERSION, - }, - { + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, capabilities: { - sampling: {}, + resources: {}, + tools: {} + }, + serverInfo: { + name: 'new server', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'old client', + version: '1.0', + protocolVersion: OLD_VERSION }, - enforceStrictCapabilities: true, - }, - ); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - expect(client.getServerVersion()).toEqual({ - name: "new server", - version: "1.0", - }); + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(client.getServerVersion()).toEqual({ + name: 'new server', + version: '1.0' + }); }); /*** * Test: Throw when Old Client and Server Version Mismatch */ test("should throw when client is old, and server doesn't support its version", async () => { - const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; - const FUTURE_VERSION = "FUTURE_VERSION"; - const server = new Server( - { - name: "new server", - version: "1.0", - }, - { - capabilities: { - resources: {}, - tools: {}, + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const FUTURE_VERSION = 'FUTURE_VERSION'; + const server = new Server( + { + name: 'new server', + version: '1.0' }, - }, - ); - - server.setRequestHandler(InitializeRequestSchema, (_request) => ({ - protocolVersion: FUTURE_VERSION, - capabilities: { - resources: {}, - tools: {}, - }, - serverInfo: { - name: "new server", - version: "1.0", - }, - })); - - server.setRequestHandler(ListResourcesRequestSchema, () => ({ - resources: [], - })); - - server.setRequestHandler(ListToolsRequestSchema, () => ({ - tools: [], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: "old client", - version: "1.0", - protocolVersion: OLD_VERSION, - }, - { + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: FUTURE_VERSION, capabilities: { - sampling: {}, + resources: {}, + tools: {} }, - enforceStrictCapabilities: true, - }, - ); + serverInfo: { + name: 'new server', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - expect(client.connect(clientTransport)).rejects.toThrow( - "Server's protocol version is not supported: FUTURE_VERSION" - ), - server.connect(serverTransport), - ]); + const client = new Client( + { + name: 'old client', + version: '1.0', + protocolVersion: OLD_VERSION + }, + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + await Promise.all([ + expect(client.connect(clientTransport)).rejects.toThrow("Server's protocol version is not supported: FUTURE_VERSION"), + server.connect(serverTransport) + ]); }); /*** * Test: Respect Server Capabilities */ -test("should respect server capabilities", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { +test('should respect server capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: { + resources: {}, + tools: {} + }, + serverInfo: { + name: 'test', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server supports resources and tools, but not prompts + expect(client.getServerCapabilities()).toEqual({ resources: {}, - tools: {}, - }, - }, - ); - - server.setRequestHandler(InitializeRequestSchema, (_request) => ({ - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: { - resources: {}, - tools: {}, - }, - serverInfo: { - name: "test", - version: "1.0", - }, - })); - - server.setRequestHandler(ListResourcesRequestSchema, () => ({ - resources: [], - })); - - server.setRequestHandler(ListToolsRequestSchema, () => ({ - tools: [], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Server supports resources and tools, but not prompts - expect(client.getServerCapabilities()).toEqual({ - resources: {}, - tools: {}, - }); - - // These should work - await expect(client.listResources()).resolves.not.toThrow(); - await expect(client.listTools()).resolves.not.toThrow(); - - // These should throw because prompts, logging, and completions are not supported - await expect(client.listPrompts()).rejects.toThrow( - "Server does not support prompts", - ); - await expect(client.setLoggingLevel("error")).rejects.toThrow( - "Server does not support logging", - ); - await expect( - client.complete({ - ref: { type: "ref/prompt", name: "test" }, - argument: { name: "test", value: "test" }, - }), - ).rejects.toThrow("Server does not support completions"); + tools: {} + }); + + // These should work + await expect(client.listResources()).resolves.not.toThrow(); + await expect(client.listTools()).resolves.not.toThrow(); + + // These should throw because prompts, logging, and completions are not supported + await expect(client.listPrompts()).rejects.toThrow('Server does not support prompts'); + await expect(client.setLoggingLevel('error')).rejects.toThrow('Server does not support logging'); + await expect( + client.complete({ + ref: { type: 'ref/prompt', name: 'test' }, + argument: { name: 'test', value: 'test' } + }) + ).rejects.toThrow('Server does not support completions'); }); /*** * Test: Respect Client Notification Capabilities */ -test("should respect client notification capabilities", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: {}, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - roots: { - listChanged: true, +test('should respect client notification capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: {} + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: { + listChanged: true + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // This should work because the client has the roots.listChanged capability + await expect(client.sendRootsListChanged()).resolves.not.toThrow(); + + // Create a new client without the roots.listChanged capability + const clientWithoutCapability = new Client( + { + name: 'test client without capability', + version: '1.0' }, - }, - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // This should work because the client has the roots.listChanged capability - await expect(client.sendRootsListChanged()).resolves.not.toThrow(); - - // Create a new client without the roots.listChanged capability - const clientWithoutCapability = new Client( - { - name: "test client without capability", - version: "1.0", - }, - { - capabilities: {}, - enforceStrictCapabilities: true, - }, - ); - - await clientWithoutCapability.connect(clientTransport); - - // This should throw because the client doesn't have the roots.listChanged capability - await expect(clientWithoutCapability.sendRootsListChanged()).rejects.toThrow( - /^Client does not support/, - ); + { + capabilities: {}, + enforceStrictCapabilities: true + } + ); + + await clientWithoutCapability.connect(clientTransport); + + // This should throw because the client doesn't have the roots.listChanged capability + await expect(clientWithoutCapability.sendRootsListChanged()).rejects.toThrow(/^Client does not support/); }); /*** * Test: Respect Server Notification Capabilities */ -test("should respect server notification capabilities", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - logging: {}, - resources: { - listChanged: true, +test('should respect server notification capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' }, - }, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: {}, - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // These should work because the server has the corresponding capabilities - await expect( - server.sendLoggingMessage({ level: "info", data: "Test" }), - ).resolves.not.toThrow(); - await expect(server.sendResourceListChanged()).resolves.not.toThrow(); - - // This should throw because the server doesn't have the tools capability - await expect(server.sendToolListChanged()).rejects.toThrow( - "Server does not support notifying of tool list changes", - ); -}); + { + capabilities: { + logging: {}, + resources: { + listChanged: true + } + } + } + ); -/*** - * Test: Only Allow setRequestHandler for Declared Capabilities - */ -test("should only allow setRequestHandler for declared capabilities", () => { - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - // This should work because sampling is a declared capability - expect(() => { - client.setRequestHandler(CreateMessageRequestSchema, () => ({ - model: "test-model", - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - })); - }).not.toThrow(); + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); - // This should throw because roots listing is not a declared capability - expect(() => { - client.setRequestHandler(ListRootsRequestSchema, () => ({})); - }).toThrow("Client does not support roots capability"); -}); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); -test("should allow setRequestHandler for declared elicitation capability", () => { - const client = new Client( - { - name: "test-client", - version: "1.0.0", - }, - { - capabilities: { - elicitation: {}, - }, - }, - ); - - // This should work because elicitation is a declared capability - expect(() => { - client.setRequestHandler(ElicitRequestSchema, () => ({ - action: "accept", - content: { - username: "test-user", - confirmed: true, - }, - })); - }).not.toThrow(); - - // This should throw because sampling is not a declared capability - expect(() => { - client.setRequestHandler(CreateMessageRequestSchema, () => ({ - model: "test-model", - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - })); - }).toThrow("Client does not support sampling capability"); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // These should work because the server has the corresponding capabilities + await expect(server.sendLoggingMessage({ level: 'info', data: 'Test' })).resolves.not.toThrow(); + await expect(server.sendResourceListChanged()).resolves.not.toThrow(); + + // This should throw because the server doesn't have the tools capability + await expect(server.sendToolListChanged()).rejects.toThrow('Server does not support notifying of tool list changes'); }); /*** - * Test: Type Checking - * Test that custom request/notification/result schemas can be used with the Client class. + * Test: Only Allow setRequestHandler for Declared Capabilities */ -test("should typecheck", () => { - const GetWeatherRequestSchema = RequestSchema.extend({ - method: z.literal("weather/get"), - params: z.object({ - city: z.string(), - }), - }); - - const GetForecastRequestSchema = RequestSchema.extend({ - method: z.literal("weather/forecast"), - params: z.object({ - city: z.string(), - days: z.number(), - }), - }); - - const WeatherForecastNotificationSchema = NotificationSchema.extend({ - method: z.literal("weather/alert"), - params: z.object({ - severity: z.enum(["warning", "watch"]), - message: z.string(), - }), - }); - - const WeatherRequestSchema = GetWeatherRequestSchema.or( - GetForecastRequestSchema, - ); - const WeatherNotificationSchema = WeatherForecastNotificationSchema; - const WeatherResultSchema = ResultSchema.extend({ - temperature: z.number(), - conditions: z.string(), - }); - - type WeatherRequest = z.infer; - type WeatherNotification = z.infer; - type WeatherResult = z.infer; - - // Create a typed Client for weather data - const weatherClient = new Client< - WeatherRequest, - WeatherNotification, - WeatherResult - >( - { - name: "WeatherClient", - version: "1.0.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - // Typecheck that only valid weather requests/notifications/results are allowed - false && - weatherClient.request( - { - method: "weather/get", - params: { - city: "Seattle", +test('should only allow setRequestHandler for declared capabilities', () => { + const client = new Client( + { + name: 'test client', + version: '1.0' }, - }, - WeatherResultSchema, + { + capabilities: { + sampling: {} + } + } ); - false && - weatherClient.notification({ - method: "weather/alert", - params: { - severity: "warning", - message: "Storm approaching", - }, - }); + // This should work because sampling is a declared capability + expect(() => { + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + })); + }).not.toThrow(); + + // This should throw because roots listing is not a declared capability + expect(() => { + client.setRequestHandler(ListRootsRequestSchema, () => ({})); + }).toThrow('Client does not support roots capability'); }); -/*** - * Test: Handle Client Cancelling a Request - */ -test("should handle client cancelling a request", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - resources: {}, - }, - }, - ); +test('should allow setRequestHandler for declared elicitation capability', () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); - // Set up server to delay responding to listResources - server.setRequestHandler( - ListResourcesRequestSchema, - async (request, extra) => { - await new Promise((resolve) => setTimeout(resolve, 1000)); - return { - resources: [], - }; - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: {}, - }, - ); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Set up abort controller - const controller = new AbortController(); - - // Issue request but cancel it immediately - const listResourcesPromise = client.listResources(undefined, { - signal: controller.signal, - }); - controller.abort("Cancelled by test"); - - // Request should be rejected - await expect(listResourcesPromise).rejects.toBe("Cancelled by test"); + // This should work because elicitation is a declared capability + expect(() => { + client.setRequestHandler(ElicitRequestSchema, () => ({ + action: 'accept', + content: { + username: 'test-user', + confirmed: true + } + })); + }).not.toThrow(); + + // This should throw because sampling is not a declared capability + expect(() => { + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + })); + }).toThrow('Client does not support sampling capability'); }); /*** - * Test: Handle Request Timeout + * Test: Type Checking + * Test that custom request/notification/result schemas can be used with the Client class. */ -test("should handle request timeout", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - resources: {}, - }, - }, - ); +test('should typecheck', () => { + const GetWeatherRequestSchema = RequestSchema.extend({ + method: z.literal('weather/get'), + params: z.object({ + city: z.string() + }) + }); - // Set up server with a delayed response - server.setRequestHandler( - ListResourcesRequestSchema, - async (_request, extra) => { - const timer = new Promise((resolve) => { - const timeout = setTimeout(resolve, 100); - extra.signal.addEventListener("abort", () => clearTimeout(timeout)); - }); - - await timer; - return { - resources: [], - }; - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: {}, - }, - ); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Request with 0 msec timeout should fail immediately - await expect( - client.listResources(undefined, { timeout: 0 }), - ).rejects.toMatchObject({ - code: ErrorCode.RequestTimeout, - }); -}); + const GetForecastRequestSchema = RequestSchema.extend({ + method: z.literal('weather/forecast'), + params: z.object({ + city: z.string(), + days: z.number() + }) + }); -describe('outputSchema validation', () => { - /*** - * Test: Validate structuredContent Against outputSchema - */ - test('should validate structuredContent against outputSchema', async () => { - const server = new Server({ - name: 'test-server', - version: '1.0.0', - }, { - capabilities: { - tools: {}, - }, + const WeatherForecastNotificationSchema = NotificationSchema.extend({ + method: z.literal('weather/alert'), + params: z.object({ + severity: z.enum(['warning', 'watch']), + message: z.string() + }) }); - // Set up server handlers - server.setRequestHandler(InitializeRequestSchema, async (request) => ({ - protocolVersion: request.params.protocolVersion, - capabilities: {}, - serverInfo: { - name: 'test-server', - version: '1.0.0', - } - })); + const WeatherRequestSchema = GetWeatherRequestSchema.or(GetForecastRequestSchema); + const WeatherNotificationSchema = WeatherForecastNotificationSchema; + const WeatherResultSchema = ResultSchema.extend({ + temperature: z.number(), + conditions: z.string() + }); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ + type WeatherRequest = z.infer; + type WeatherNotification = z.infer; + type WeatherResult = z.infer; + + // Create a typed Client for weather data + const weatherClient = new Client( + { + name: 'WeatherClient', + version: '1.0.0' + }, { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {}, - }, - outputSchema: { - type: 'object', - properties: { - result: { type: 'string' }, - count: { type: 'number' }, + capabilities: { + sampling: {} + } + } + ); + + // Typecheck that only valid weather requests/notifications/results are allowed + false && + weatherClient.request( + { + method: 'weather/get', + params: { + city: 'Seattle' + } }, - required: ['result', 'count'], - additionalProperties: false, - }, + WeatherResultSchema + ); + + false && + weatherClient.notification({ + method: 'weather/alert', + params: { + severity: 'warning', + message: 'Storm approaching' + } + }); +}); + +/*** + * Test: Handle Client Cancelling a Request + */ +test('should handle client cancelling a request', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' }, - ], - })); + { + capabilities: { + resources: {} + } + } + ); - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (request.params.name === 'test-tool') { + // Set up server to delay responding to listResources + server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { + await new Promise(resolve => setTimeout(resolve, 1000)); return { - structuredContent: { result: 'success', count: 42 }, + resources: [] }; - } - throw new Error('Unknown tool'); - }); - - const client = new Client({ - name: 'test-client', - version: '1.0.0', }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Set up abort controller + const controller = new AbortController(); - // List tools to cache the schemas - await client.listTools(); - - // Call the tool - should validate successfully - const result = await client.callTool({ name: 'test-tool' }); - expect(result.structuredContent).toEqual({ result: 'success', count: 42 }); - }); - - /*** - * Test: Throw Error when structuredContent Does Not Match Schema - */ - test('should throw error when structuredContent does not match schema', async () => { - const server = new Server({ - name: 'test-server', - version: '1.0.0', - }, { - capabilities: { - tools: {}, - }, + // Issue request but cancel it immediately + const listResourcesPromise = client.listResources(undefined, { + signal: controller.signal }); + controller.abort('Cancelled by test'); - // Set up server handlers - server.setRequestHandler(InitializeRequestSchema, async (request) => ({ - protocolVersion: request.params.protocolVersion, - capabilities: {}, - serverInfo: { - name: 'test-server', - version: '1.0.0', - } - })); + // Request should be rejected + await expect(listResourcesPromise).rejects.toBe('Cancelled by test'); +}); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ +/*** + * Test: Handle Request Timeout + */ +test('should handle request timeout', async () => { + const server = new Server( { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {}, - }, - outputSchema: { - type: 'object', - properties: { - result: { type: 'string' }, - count: { type: 'number' }, - }, - required: ['result', 'count'], - additionalProperties: false, - }, + name: 'test server', + version: '1.0' }, - ], - })); + { + capabilities: { + resources: {} + } + } + ); - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (request.params.name === 'test-tool') { - // Return invalid structured content (count is string instead of number) + // Set up server with a delayed response + server.setRequestHandler(ListResourcesRequestSchema, async (_request, extra) => { + const timer = new Promise(resolve => { + const timeout = setTimeout(resolve, 100); + extra.signal.addEventListener('abort', () => clearTimeout(timeout)); + }); + + await timer; return { - structuredContent: { result: 'success', count: 'not a number' }, + resources: [] }; - } - throw new Error('Unknown tool'); - }); - - const client = new Client({ - name: 'test-client', - version: '1.0.0', }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); - // List tools to cache the schemas - await client.listTools(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Call the tool - should throw validation error - await expect(client.callTool({ name: 'test-tool' })).rejects.toThrow( - /Structured content does not match the tool's output schema/ - ); - }); - - /*** - * Test: Throw Error when Tool with outputSchema Returns No structuredContent - */ - test('should throw error when tool with outputSchema returns no structuredContent', async () => { - const server = new Server({ - name: 'test-server', - version: '1.0.0', - }, { - capabilities: { - tools: {}, - }, + // Request with 0 msec timeout should fail immediately + await expect(client.listResources(undefined, { timeout: 0 })).rejects.toMatchObject({ + code: ErrorCode.RequestTimeout }); +}); - // Set up server handlers - server.setRequestHandler(InitializeRequestSchema, async (request) => ({ - protocolVersion: request.params.protocolVersion, - capabilities: {}, - serverInfo: { - name: 'test-server', - version: '1.0.0', - } - })); - - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {}, - }, - outputSchema: { - type: 'object', - properties: { - result: { type: 'string' }, +describe('outputSchema validation', () => { + /*** + * Test: Validate structuredContent Against outputSchema + */ + test('should validate structuredContent against outputSchema', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' }, - required: ['result'], - }, - }, - ], - })); - - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (request.params.name === 'test-tool') { - // Return content instead of structuredContent - return { - content: [{ type: 'text', text: 'This should be structured content' }], - }; - } - throw new Error('Unknown tool'); - }); + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' }, + count: { type: 'number' } + }, + required: ['result', 'count'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + structuredContent: { result: 'success', count: 42 } + }; + } + throw new Error('Unknown tool'); + }); - const client = new Client({ - name: 'test-client', - version: '1.0.0', - }); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // List tools to cache the schemas - await client.listTools(); + // List tools to cache the schemas + await client.listTools(); - // Call the tool - should throw error - await expect(client.callTool({ name: 'test-tool' })).rejects.toThrow( - /Tool test-tool has an output schema but did not return structured content/ - ); - }); - - /*** - * Test: Handle Tools Without outputSchema Normally - */ - test('should handle tools without outputSchema normally', async () => { - const server = new Server({ - name: 'test-server', - version: '1.0.0', - }, { - capabilities: { - tools: {}, - }, + // Call the tool - should validate successfully + const result = await client.callTool({ name: 'test-tool' }); + expect(result.structuredContent).toEqual({ result: 'success', count: 42 }); }); - // Set up server handlers - server.setRequestHandler(InitializeRequestSchema, async (request) => ({ - protocolVersion: request.params.protocolVersion, - capabilities: {}, - serverInfo: { - name: 'test-server', - version: '1.0.0', - } - })); + /*** + * Test: Throw Error when structuredContent Does Not Match Schema + */ + test('should throw error when structuredContent does not match schema', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' }, + count: { type: 'number' } + }, + required: ['result', 'count'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Return invalid structured content (count is string instead of number) + return { + structuredContent: { result: 'success', count: 'not a number' } + }; + } + throw new Error('Unknown tool'); + }); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {}, - }, - // No outputSchema - }, - ], - })); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (request.params.name === 'test-tool') { - // Return regular content - return { - content: [{ type: 'text', text: 'Normal response' }], - }; - } - throw new Error('Unknown tool'); - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const client = new Client({ - name: 'test-client', - version: '1.0.0', + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should throw validation error + await expect(client.callTool({ name: 'test-tool' })).rejects.toThrow(/Structured content does not match the tool's output schema/); }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + /*** + * Test: Throw Error when Tool with outputSchema Returns No structuredContent + */ + test('should throw error when tool with outputSchema returns no structuredContent', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' } + }, + required: ['result'] + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Return content instead of structuredContent + return { + content: [{ type: 'text', text: 'This should be structured content' }] + }; + } + throw new Error('Unknown tool'); + }); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - // List tools to cache the schemas - await client.listTools(); - - // Call the tool - should work normally without validation - const result = await client.callTool({ name: 'test-tool' }); - expect(result.content).toEqual([{ type: 'text', text: 'Normal response' }]); - }); - - /*** - * Test: Handle Complex JSON Schema Validation - */ - test('should handle complex JSON schema validation', async () => { - const server = new Server({ - name: 'test-server', - version: '1.0.0', - }, { - capabilities: { - tools: {}, - }, - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Set up server handlers - server.setRequestHandler(InitializeRequestSchema, async (request) => ({ - protocolVersion: request.params.protocolVersion, - capabilities: {}, - serverInfo: { - name: 'test-server', - version: '1.0.0', - } - })); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'complex-tool', - description: 'A tool with complex schema', - inputSchema: { - type: 'object', - properties: {}, - }, - outputSchema: { - type: 'object', - properties: { - name: { type: 'string', minLength: 3 }, - age: { type: 'integer', minimum: 0, maximum: 120 }, - active: { type: 'boolean' }, - tags: { - type: 'array', - items: { type: 'string' }, - minItems: 1, - }, - metadata: { - type: 'object', - properties: { - created: { type: 'string' }, - }, - required: ['created'], - }, - }, - required: ['name', 'age', 'active', 'tags', 'metadata'], - additionalProperties: false, - }, - }, - ], - })); + // List tools to cache the schemas + await client.listTools(); - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (request.params.name === 'complex-tool') { - return { - structuredContent: { - name: 'John Doe', - age: 30, - active: true, - tags: ['user', 'admin'], - metadata: { - created: '2023-01-01T00:00:00Z', - }, - }, - }; - } - throw new Error('Unknown tool'); + // Call the tool - should throw error + await expect(client.callTool({ name: 'test-tool' })).rejects.toThrow( + /Tool test-tool has an output schema but did not return structured content/ + ); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0', - }); + /*** + * Test: Handle Tools Without outputSchema Normally + */ + test('should handle tools without outputSchema normally', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + // No outputSchema + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Return regular content + return { + content: [{ type: 'text', text: 'Normal response' }] + }; + } + throw new Error('Unknown tool'); + }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // List tools to cache the schemas - await client.listTools(); - - // Call the tool - should validate successfully - const result = await client.callTool({ name: 'complex-tool' }); - expect(result.structuredContent).toBeDefined(); - const structuredContent = result.structuredContent as { name: string; age: number }; - expect(structuredContent.name).toBe('John Doe'); - expect(structuredContent.age).toBe(30); - }); - - /*** - * Test: Fail Validation with Additional Properties When Not Allowed - */ - test('should fail validation with additional properties when not allowed', async () => { - const server = new Server({ - name: 'test-server', - version: '1.0.0', - }, { - capabilities: { - tools: {}, - }, - }); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - // Set up server handlers - server.setRequestHandler(InitializeRequestSchema, async (request) => ({ - protocolVersion: request.params.protocolVersion, - capabilities: {}, - serverInfo: { - name: 'test-server', - version: '1.0.0', - } - })); + // List tools to cache the schemas + await client.listTools(); - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: [ - { - name: 'strict-tool', - description: 'A tool with strict schema', - inputSchema: { - type: 'object', - properties: {}, - }, - outputSchema: { - type: 'object', - properties: { - name: { type: 'string' }, + // Call the tool - should work normally without validation + const result = await client.callTool({ name: 'test-tool' }); + expect(result.content).toEqual([{ type: 'text', text: 'Normal response' }]); + }); + + /*** + * Test: Handle Complex JSON Schema Validation + */ + test('should handle complex JSON schema validation', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' }, - required: ['name'], - additionalProperties: false, - }, - }, - ], - })); + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'complex-tool', + description: 'A tool with complex schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + name: { type: 'string', minLength: 3 }, + age: { type: 'integer', minimum: 0, maximum: 120 }, + active: { type: 'boolean' }, + tags: { + type: 'array', + items: { type: 'string' }, + minItems: 1 + }, + metadata: { + type: 'object', + properties: { + created: { type: 'string' } + }, + required: ['created'] + } + }, + required: ['name', 'age', 'active', 'tags', 'metadata'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'complex-tool') { + return { + structuredContent: { + name: 'John Doe', + age: 30, + active: true, + tags: ['user', 'admin'], + metadata: { + created: '2023-01-01T00:00:00Z' + } + } + }; + } + throw new Error('Unknown tool'); + }); - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (request.params.name === 'strict-tool') { - // Return structured content with extra property - return { - structuredContent: { - name: 'John', - extraField: 'not allowed', - }, - }; - } - throw new Error('Unknown tool'); - }); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); - const client = new Client({ - name: 'test-client', - version: '1.0.0', + // Call the tool - should validate successfully + const result = await client.callTool({ name: 'complex-tool' }); + expect(result.structuredContent).toBeDefined(); + const structuredContent = result.structuredContent as { name: string; age: number }; + expect(structuredContent.name).toBe('John Doe'); + expect(structuredContent.age).toBe(30); }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + /*** + * Test: Fail Validation with Additional Properties When Not Allowed + */ + test('should fail validation with additional properties when not allowed', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'strict-tool', + description: 'A tool with strict schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + name: { type: 'string' } + }, + required: ['name'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'strict-tool') { + // Return structured content with extra property + return { + structuredContent: { + name: 'John', + extraField: 'not allowed' + } + }; + } + throw new Error('Unknown tool'); + }); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - // List tools to cache the schemas - await client.listTools(); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Call the tool - should throw validation error due to additional property - await expect(client.callTool({ name: 'strict-tool' })).rejects.toThrow( - /Structured content does not match the tool's output schema/ - ); - }); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + // List tools to cache the schemas + await client.listTools(); + // Call the tool - should throw validation error due to additional property + await expect(client.callTool({ name: 'strict-tool' })).rejects.toThrow( + /Structured content does not match the tool's output schema/ + ); + }); }); diff --git a/src/client/index.ts b/src/client/index.ts index 3e8d8ec80..856eb18e5 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,56 +1,51 @@ +import { mergeCapabilities, Protocol, ProtocolOptions, RequestOptions } from '../shared/protocol.js'; +import { Transport } from '../shared/transport.js'; import { - mergeCapabilities, - Protocol, - ProtocolOptions, - RequestOptions, -} from "../shared/protocol.js"; -import { Transport } from "../shared/transport.js"; -import { - CallToolRequest, - CallToolResultSchema, - ClientCapabilities, - ClientNotification, - ClientRequest, - ClientResult, - CompatibilityCallToolResultSchema, - CompleteRequest, - CompleteResultSchema, - EmptyResultSchema, - GetPromptRequest, - GetPromptResultSchema, - Implementation, - InitializeResultSchema, - LATEST_PROTOCOL_VERSION, - ListPromptsRequest, - ListPromptsResultSchema, - ListResourcesRequest, - ListResourcesResultSchema, - ListResourceTemplatesRequest, - ListResourceTemplatesResultSchema, - ListToolsRequest, - ListToolsResultSchema, - LoggingLevel, - Notification, - ReadResourceRequest, - ReadResourceResultSchema, - Request, - Result, - ServerCapabilities, - SubscribeRequest, - SUPPORTED_PROTOCOL_VERSIONS, - UnsubscribeRequest, - Tool, - ErrorCode, - McpError, -} from "../types.js"; -import Ajv from "ajv"; -import type { ValidateFunction } from "ajv"; + CallToolRequest, + CallToolResultSchema, + ClientCapabilities, + ClientNotification, + ClientRequest, + ClientResult, + CompatibilityCallToolResultSchema, + CompleteRequest, + CompleteResultSchema, + EmptyResultSchema, + GetPromptRequest, + GetPromptResultSchema, + Implementation, + InitializeResultSchema, + LATEST_PROTOCOL_VERSION, + ListPromptsRequest, + ListPromptsResultSchema, + ListResourcesRequest, + ListResourcesResultSchema, + ListResourceTemplatesRequest, + ListResourceTemplatesResultSchema, + ListToolsRequest, + ListToolsResultSchema, + LoggingLevel, + Notification, + ReadResourceRequest, + ReadResourceResultSchema, + Request, + Result, + ServerCapabilities, + SubscribeRequest, + SUPPORTED_PROTOCOL_VERSIONS, + UnsubscribeRequest, + Tool, + ErrorCode, + McpError +} from '../types.js'; +import Ajv from 'ajv'; +import type { ValidateFunction } from 'ajv'; export type ClientOptions = ProtocolOptions & { - /** - * Capabilities to advertise as being supported by this client. - */ - capabilities?: ClientCapabilities; + /** + * Capabilities to advertise as being supported by this client. + */ + capabilities?: ClientCapabilities; }; /** @@ -79,441 +74,333 @@ export type ClientOptions = ProtocolOptions & { * ``` */ export class Client< - RequestT extends Request = Request, - NotificationT extends Notification = Notification, - ResultT extends Result = Result, -> extends Protocol< - ClientRequest | RequestT, - ClientNotification | NotificationT, - ClientResult | ResultT -> { - private _serverCapabilities?: ServerCapabilities; - private _serverVersion?: Implementation; - private _capabilities: ClientCapabilities; - private _instructions?: string; - private _cachedToolOutputValidators: Map = new Map(); - private _ajv: InstanceType; - - /** - * Initializes this client with the given name and version information. - */ - constructor( - private _clientInfo: Implementation, - options?: ClientOptions, - ) { - super(options); - this._capabilities = options?.capabilities ?? {}; - this._ajv = new Ajv(); - } - - /** - * Registers new capabilities. This can only be called before connecting to a transport. - * - * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). - */ - public registerCapabilities(capabilities: ClientCapabilities): void { - if (this.transport) { - throw new Error( - "Cannot register capabilities after connecting to transport", - ); + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result +> extends Protocol { + private _serverCapabilities?: ServerCapabilities; + private _serverVersion?: Implementation; + private _capabilities: ClientCapabilities; + private _instructions?: string; + private _cachedToolOutputValidators: Map = new Map(); + private _ajv: InstanceType; + + /** + * Initializes this client with the given name and version information. + */ + constructor( + private _clientInfo: Implementation, + options?: ClientOptions + ) { + super(options); + this._capabilities = options?.capabilities ?? {}; + this._ajv = new Ajv(); } - this._capabilities = mergeCapabilities(this._capabilities, capabilities); - } - - protected assertCapability( - capability: keyof ServerCapabilities, - method: string, - ): void { - if (!this._serverCapabilities?.[capability]) { - throw new Error( - `Server does not support ${capability} (required for ${method})`, - ); - } - } - - override async connect(transport: Transport, options?: RequestOptions): Promise { - await super.connect(transport); - // When transport sessionId is already set this means we are trying to reconnect. - // In this case we don't need to initialize again. - if (transport.sessionId !== undefined) { - return; - } - try { - const result = await this.request( - { - method: "initialize", - params: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: this._capabilities, - clientInfo: this._clientInfo, - }, - }, - InitializeResultSchema, - options - ); - - if (result === undefined) { - throw new Error(`Server sent invalid initialize result: ${result}`); - } - - if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) { - throw new Error( - `Server's protocol version is not supported: ${result.protocolVersion}`, - ); - } - - this._serverCapabilities = result.capabilities; - this._serverVersion = result.serverInfo; - // HTTP transports must set the protocol version in each header after initialization. - if (transport.setProtocolVersion) { - transport.setProtocolVersion(result.protocolVersion); - } - - this._instructions = result.instructions; - - await this.notification({ - method: "notifications/initialized", - }); - } catch (error) { - // Disconnect if initialization fails. - void this.close(); - throw error; - } - } - - /** - * After initialization has completed, this will be populated with the server's reported capabilities. - */ - getServerCapabilities(): ServerCapabilities | undefined { - return this._serverCapabilities; - } - - /** - * After initialization has completed, this will be populated with information about the server's name and version. - */ - getServerVersion(): Implementation | undefined { - return this._serverVersion; - } - - /** - * After initialization has completed, this may be populated with information about the server's instructions. - */ - getInstructions(): string | undefined { - return this._instructions; - } - - protected assertCapabilityForMethod(method: RequestT["method"]): void { - switch (method as ClientRequest["method"]) { - case "logging/setLevel": - if (!this._serverCapabilities?.logging) { - throw new Error( - `Server does not support logging (required for ${method})`, - ); + /** + * Registers new capabilities. This can only be called before connecting to a transport. + * + * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). + */ + public registerCapabilities(capabilities: ClientCapabilities): void { + if (this.transport) { + throw new Error('Cannot register capabilities after connecting to transport'); } - break; - - case "prompts/get": - case "prompts/list": - if (!this._serverCapabilities?.prompts) { - throw new Error( - `Server does not support prompts (required for ${method})`, - ); + + this._capabilities = mergeCapabilities(this._capabilities, capabilities); + } + + protected assertCapability(capability: keyof ServerCapabilities, method: string): void { + if (!this._serverCapabilities?.[capability]) { + throw new Error(`Server does not support ${capability} (required for ${method})`); } - break; - - case "resources/list": - case "resources/templates/list": - case "resources/read": - case "resources/subscribe": - case "resources/unsubscribe": - if (!this._serverCapabilities?.resources) { - throw new Error( - `Server does not support resources (required for ${method})`, - ); + } + + override async connect(transport: Transport, options?: RequestOptions): Promise { + await super.connect(transport); + // When transport sessionId is already set this means we are trying to reconnect. + // In this case we don't need to initialize again. + if (transport.sessionId !== undefined) { + return; } + try { + const result = await this.request( + { + method: 'initialize', + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: this._capabilities, + clientInfo: this._clientInfo + } + }, + InitializeResultSchema, + options + ); + + if (result === undefined) { + throw new Error(`Server sent invalid initialize result: ${result}`); + } + + if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) { + throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`); + } + + this._serverCapabilities = result.capabilities; + this._serverVersion = result.serverInfo; + // HTTP transports must set the protocol version in each header after initialization. + if (transport.setProtocolVersion) { + transport.setProtocolVersion(result.protocolVersion); + } - if ( - method === "resources/subscribe" && - !this._serverCapabilities.resources.subscribe - ) { - throw new Error( - `Server does not support resource subscriptions (required for ${method})`, - ); + this._instructions = result.instructions; + + await this.notification({ + method: 'notifications/initialized' + }); + } catch (error) { + // Disconnect if initialization fails. + void this.close(); + throw error; } + } - break; + /** + * After initialization has completed, this will be populated with the server's reported capabilities. + */ + getServerCapabilities(): ServerCapabilities | undefined { + return this._serverCapabilities; + } + + /** + * After initialization has completed, this will be populated with information about the server's name and version. + */ + getServerVersion(): Implementation | undefined { + return this._serverVersion; + } - case "tools/call": - case "tools/list": - if (!this._serverCapabilities?.tools) { - throw new Error( - `Server does not support tools (required for ${method})`, - ); + /** + * After initialization has completed, this may be populated with information about the server's instructions. + */ + getInstructions(): string | undefined { + return this._instructions; + } + + protected assertCapabilityForMethod(method: RequestT['method']): void { + switch (method as ClientRequest['method']) { + case 'logging/setLevel': + if (!this._serverCapabilities?.logging) { + throw new Error(`Server does not support logging (required for ${method})`); + } + break; + + case 'prompts/get': + case 'prompts/list': + if (!this._serverCapabilities?.prompts) { + throw new Error(`Server does not support prompts (required for ${method})`); + } + break; + + case 'resources/list': + case 'resources/templates/list': + case 'resources/read': + case 'resources/subscribe': + case 'resources/unsubscribe': + if (!this._serverCapabilities?.resources) { + throw new Error(`Server does not support resources (required for ${method})`); + } + + if (method === 'resources/subscribe' && !this._serverCapabilities.resources.subscribe) { + throw new Error(`Server does not support resource subscriptions (required for ${method})`); + } + + break; + + case 'tools/call': + case 'tools/list': + if (!this._serverCapabilities?.tools) { + throw new Error(`Server does not support tools (required for ${method})`); + } + break; + + case 'completion/complete': + if (!this._serverCapabilities?.completions) { + throw new Error(`Server does not support completions (required for ${method})`); + } + break; + + case 'initialize': + // No specific capability required for initialize + break; + + case 'ping': + // No specific capability required for ping + break; } - break; + } - case "completion/complete": - if (!this._serverCapabilities?.completions) { - throw new Error( - `Server does not support completions (required for ${method})`, - ); + protected assertNotificationCapability(method: NotificationT['method']): void { + switch (method as ClientNotification['method']) { + case 'notifications/roots/list_changed': + if (!this._capabilities.roots?.listChanged) { + throw new Error(`Client does not support roots list changed notifications (required for ${method})`); + } + break; + + case 'notifications/initialized': + // No specific capability required for initialized + break; + + case 'notifications/cancelled': + // Cancellation notifications are always allowed + break; + + case 'notifications/progress': + // Progress notifications are always allowed + break; } - break; + } - case "initialize": - // No specific capability required for initialize - break; + protected assertRequestHandlerCapability(method: string): void { + switch (method) { + case 'sampling/createMessage': + if (!this._capabilities.sampling) { + throw new Error(`Client does not support sampling capability (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!this._capabilities.elicitation) { + throw new Error(`Client does not support elicitation capability (required for ${method})`); + } + break; + + case 'roots/list': + if (!this._capabilities.roots) { + throw new Error(`Client does not support roots capability (required for ${method})`); + } + break; + + case 'ping': + // No specific capability required for ping + break; + } + } - case "ping": - // No specific capability required for ping - break; + async ping(options?: RequestOptions) { + return this.request({ method: 'ping' }, EmptyResultSchema, options); } - } - - protected assertNotificationCapability( - method: NotificationT["method"], - ): void { - switch (method as ClientNotification["method"]) { - case "notifications/roots/list_changed": - if (!this._capabilities.roots?.listChanged) { - throw new Error( - `Client does not support roots list changed notifications (required for ${method})`, - ); - } - break; - case "notifications/initialized": - // No specific capability required for initialized - break; + async complete(params: CompleteRequest['params'], options?: RequestOptions) { + return this.request({ method: 'completion/complete', params }, CompleteResultSchema, options); + } - case "notifications/cancelled": - // Cancellation notifications are always allowed - break; + async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) { + return this.request({ method: 'logging/setLevel', params: { level } }, EmptyResultSchema, options); + } - case "notifications/progress": - // Progress notifications are always allowed - break; + async getPrompt(params: GetPromptRequest['params'], options?: RequestOptions) { + return this.request({ method: 'prompts/get', params }, GetPromptResultSchema, options); } - } - - protected assertRequestHandlerCapability(method: string): void { - switch (method) { - case "sampling/createMessage": - if (!this._capabilities.sampling) { - throw new Error( - `Client does not support sampling capability (required for ${method})`, - ); - } - break; - case "elicitation/create": - if (!this._capabilities.elicitation) { - throw new Error( - `Client does not support elicitation capability (required for ${method})`, - ); - } - break; + async listPrompts(params?: ListPromptsRequest['params'], options?: RequestOptions) { + return this.request({ method: 'prompts/list', params }, ListPromptsResultSchema, options); + } + + async listResources(params?: ListResourcesRequest['params'], options?: RequestOptions) { + return this.request({ method: 'resources/list', params }, ListResourcesResultSchema, options); + } + + async listResourceTemplates(params?: ListResourceTemplatesRequest['params'], options?: RequestOptions) { + return this.request({ method: 'resources/templates/list', params }, ListResourceTemplatesResultSchema, options); + } + + async readResource(params: ReadResourceRequest['params'], options?: RequestOptions) { + return this.request({ method: 'resources/read', params }, ReadResourceResultSchema, options); + } + + async subscribeResource(params: SubscribeRequest['params'], options?: RequestOptions) { + return this.request({ method: 'resources/subscribe', params }, EmptyResultSchema, options); + } + + async unsubscribeResource(params: UnsubscribeRequest['params'], options?: RequestOptions) { + return this.request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); + } - case "roots/list": - if (!this._capabilities.roots) { - throw new Error( - `Client does not support roots capability (required for ${method})`, - ); + async callTool( + params: CallToolRequest['params'], + resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + options?: RequestOptions + ) { + const result = await this.request({ method: 'tools/call', params }, resultSchema, options); + + // Check if the tool has an outputSchema + const validator = this.getToolOutputValidator(params.name); + if (validator) { + // If tool has outputSchema, it MUST return structuredContent (unless it's an error) + if (!result.structuredContent && !result.isError) { + throw new McpError( + ErrorCode.InvalidRequest, + `Tool ${params.name} has an output schema but did not return structured content` + ); + } + + // Only validate structured content if present (not when there's an error) + if (result.structuredContent) { + try { + // Validate the structured content (which is already an object) against the schema + const isValid = validator(result.structuredContent); + + if (!isValid) { + throw new McpError( + ErrorCode.InvalidParams, + `Structured content does not match the tool's output schema: ${this._ajv.errorsText(validator.errors)}` + ); + } + } catch (error) { + if (error instanceof McpError) { + throw error; + } + throw new McpError( + ErrorCode.InvalidParams, + `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` + ); + } + } } - break; - case "ping": - // No specific capability required for ping - break; + return result; } - } - - async ping(options?: RequestOptions) { - return this.request({ method: "ping" }, EmptyResultSchema, options); - } - - async complete(params: CompleteRequest["params"], options?: RequestOptions) { - return this.request( - { method: "completion/complete", params }, - CompleteResultSchema, - options, - ); - } - - async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) { - return this.request( - { method: "logging/setLevel", params: { level } }, - EmptyResultSchema, - options, - ); - } - - async getPrompt( - params: GetPromptRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "prompts/get", params }, - GetPromptResultSchema, - options, - ); - } - - async listPrompts( - params?: ListPromptsRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "prompts/list", params }, - ListPromptsResultSchema, - options, - ); - } - - async listResources( - params?: ListResourcesRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "resources/list", params }, - ListResourcesResultSchema, - options, - ); - } - - async listResourceTemplates( - params?: ListResourceTemplatesRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "resources/templates/list", params }, - ListResourceTemplatesResultSchema, - options, - ); - } - - async readResource( - params: ReadResourceRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "resources/read", params }, - ReadResourceResultSchema, - options, - ); - } - - async subscribeResource( - params: SubscribeRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "resources/subscribe", params }, - EmptyResultSchema, - options, - ); - } - - async unsubscribeResource( - params: UnsubscribeRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "resources/unsubscribe", params }, - EmptyResultSchema, - options, - ); - } - - async callTool( - params: CallToolRequest["params"], - resultSchema: - | typeof CallToolResultSchema - | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, - options?: RequestOptions, - ) { - const result = await this.request( - { method: "tools/call", params }, - resultSchema, - options, - ); - - // Check if the tool has an outputSchema - const validator = this.getToolOutputValidator(params.name); - if (validator) { - // If tool has outputSchema, it MUST return structuredContent (unless it's an error) - if (!result.structuredContent && !result.isError) { - throw new McpError( - ErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ); - } - - // Only validate structured content if present (not when there's an error) - if (result.structuredContent) { - try { - // Validate the structured content (which is already an object) against the schema - const isValid = validator(result.structuredContent); - if (!isValid) { - throw new McpError( - ErrorCode.InvalidParams, - `Structured content does not match the tool's output schema: ${this._ajv.errorsText(validator.errors)}` - ); - } - } catch (error) { - if (error instanceof McpError) { - throw error; - } - throw new McpError( - ErrorCode.InvalidParams, - `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` - ); + private cacheToolOutputSchemas(tools: Tool[]) { + this._cachedToolOutputValidators.clear(); + + for (const tool of tools) { + // If the tool has an outputSchema, create and cache the Ajv validator + if (tool.outputSchema) { + try { + const validator = this._ajv.compile(tool.outputSchema); + this._cachedToolOutputValidators.set(tool.name, validator); + } catch { + // Ignore schema compilation errors + } + } } - } } - return result; - } + private getToolOutputValidator(toolName: string): ValidateFunction | undefined { + return this._cachedToolOutputValidators.get(toolName); + } - private cacheToolOutputSchemas(tools: Tool[]) { - this._cachedToolOutputValidators.clear(); + async listTools(params?: ListToolsRequest['params'], options?: RequestOptions) { + const result = await this.request({ method: 'tools/list', params }, ListToolsResultSchema, options); - for (const tool of tools) { - // If the tool has an outputSchema, create and cache the Ajv validator - if (tool.outputSchema) { - try { - const validator = this._ajv.compile(tool.outputSchema); - this._cachedToolOutputValidators.set(tool.name, validator); - } catch { - // Ignore schema compilation errors - } - } + // Cache the tools and their output schemas for future validation + this.cacheToolOutputSchemas(result.tools); + + return result; + } + + async sendRootsListChanged() { + return this.notification({ method: 'notifications/roots/list_changed' }); } - } - - private getToolOutputValidator(toolName: string): ValidateFunction | undefined { - return this._cachedToolOutputValidators.get(toolName); - } - - async listTools( - params?: ListToolsRequest["params"], - options?: RequestOptions, - ) { - const result = await this.request( - { method: "tools/list", params }, - ListToolsResultSchema, - options, - ); - - // Cache the tools and their output schemas for future validation - this.cacheToolOutputSchemas(result.tools); - - return result; - } - - async sendRootsListChanged() { - return this.notification({ method: "notifications/roots/list_changed" }); - } } diff --git a/src/client/middleware.test.ts b/src/client/middleware.test.ts index 265aa70d6..c0420514b 100644 --- a/src/client/middleware.test.ts +++ b/src/client/middleware.test.ts @@ -1,1213 +1,1105 @@ -import { - withOAuth, - withLogging, - applyMiddlewares, - createMiddleware, -} from "./middleware.js"; -import { OAuthClientProvider } from "./auth.js"; -import { FetchLike } from "../shared/transport.js"; - -jest.mock("../client/auth.js", () => { - const actual = jest.requireActual("../client/auth.js"); - return { - ...actual, - auth: jest.fn(), - extractResourceMetadataUrl: jest.fn(), - }; +import { withOAuth, withLogging, applyMiddlewares, createMiddleware } from './middleware.js'; +import { OAuthClientProvider } from './auth.js'; +import { FetchLike } from '../shared/transport.js'; + +jest.mock('../client/auth.js', () => { + const actual = jest.requireActual('../client/auth.js'); + return { + ...actual, + auth: jest.fn(), + extractResourceMetadataUrl: jest.fn() + }; }); -import { auth, extractResourceMetadataUrl } from "./auth.js"; +import { auth, extractResourceMetadataUrl } from './auth.js'; const mockAuth = auth as jest.MockedFunction; -const mockExtractResourceMetadataUrl = - extractResourceMetadataUrl as jest.MockedFunction< - typeof extractResourceMetadataUrl - >; - -describe("withOAuth", () => { - let mockProvider: jest.Mocked; - let mockFetch: jest.MockedFunction; - - beforeEach(() => { - jest.clearAllMocks(); - - mockProvider = { - get redirectUrl() { - return "http://localhost/callback"; - }, - get clientMetadata() { - return { redirect_uris: ["http://localhost/callback"] }; - }, - tokens: jest.fn(), - saveTokens: jest.fn(), - clientInformation: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn(), - invalidateCredentials: jest.fn(), - }; +const mockExtractResourceMetadataUrl = extractResourceMetadataUrl as jest.MockedFunction; + +describe('withOAuth', () => { + let mockProvider: jest.Mocked; + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + + mockProvider = { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { redirect_uris: ['http://localhost/callback'] }; + }, + tokens: jest.fn(), + saveTokens: jest.fn(), + clientInformation: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + invalidateCredentials: jest.fn() + }; + + mockFetch = jest.fn(); + }); + + it('should add Authorization header when tokens are available (with explicit baseUrl)', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); - mockFetch = jest.fn(); - }); + mockFetch.mockResolvedValue(new Response('success', { status: 200 })); - it("should add Authorization header when tokens are available (with explicit baseUrl)", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch); + + await enhancedFetch('https://api.example.com/data'); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.example.com/data', + expect.objectContaining({ + headers: expect.any(Headers) + }) + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBe('Bearer test-token'); }); - mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + it('should add Authorization header when tokens are available (without baseUrl)', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); - const enhancedFetch = withOAuth( - mockProvider, - "https://api.example.com", - )(mockFetch); + mockFetch.mockResolvedValue(new Response('success', { status: 200 })); - await enhancedFetch("https://api.example.com/data"); + // Test without baseUrl - should extract from request URL + const enhancedFetch = withOAuth(mockProvider)(mockFetch); - expect(mockFetch).toHaveBeenCalledWith( - "https://api.example.com/data", - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); + await enhancedFetch('https://api.example.com/data'); - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBe("Bearer test-token"); - }); + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.example.com/data', + expect.objectContaining({ + headers: expect.any(Headers) + }) + ); - it("should add Authorization header when tokens are available (without baseUrl)", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBe('Bearer test-token'); }); - mockFetch.mockResolvedValue(new Response("success", { status: 200 })); - - // Test without baseUrl - should extract from request URL - const enhancedFetch = withOAuth(mockProvider)(mockFetch); - - await enhancedFetch("https://api.example.com/data"); - - expect(mockFetch).toHaveBeenCalledWith( - "https://api.example.com/data", - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); - - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBe("Bearer test-token"); - }); - - it("should handle requests without tokens (without baseUrl)", async () => { - mockProvider.tokens.mockResolvedValue(undefined); - mockFetch.mockResolvedValue(new Response("success", { status: 200 })); - - // Test without baseUrl - const enhancedFetch = withOAuth(mockProvider)(mockFetch); - - await enhancedFetch("https://api.example.com/data"); - - expect(mockFetch).toHaveBeenCalledTimes(1); - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBeNull(); - }); - - it("should retry request after successful auth on 401 response (with explicit baseUrl)", async () => { - mockProvider.tokens - .mockResolvedValueOnce({ - access_token: "old-token", - token_type: "Bearer", - expires_in: 3600, - }) - .mockResolvedValueOnce({ - access_token: "new-token", - token_type: "Bearer", - expires_in: 3600, - }); - - const unauthorizedResponse = new Response("Unauthorized", { - status: 401, - headers: { "www-authenticate": 'Bearer realm="oauth"' }, - }); - const successResponse = new Response("success", { status: 200 }); - - mockFetch - .mockResolvedValueOnce(unauthorizedResponse) - .mockResolvedValueOnce(successResponse); - - const mockResourceUrl = new URL( - "https://oauth.example.com/.well-known/oauth-protected-resource", - ); - mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl); - mockAuth.mockResolvedValue("AUTHORIZED"); - - const enhancedFetch = withOAuth( - mockProvider, - "https://api.example.com", - )(mockFetch); - - const result = await enhancedFetch("https://api.example.com/data"); - - expect(result).toBe(successResponse); - expect(mockFetch).toHaveBeenCalledTimes(2); - expect(mockAuth).toHaveBeenCalledWith(mockProvider, { - serverUrl: "https://api.example.com", - resourceMetadataUrl: mockResourceUrl, - fetchFn: mockFetch, - }); + it('should handle requests without tokens (without baseUrl)', async () => { + mockProvider.tokens.mockResolvedValue(undefined); + mockFetch.mockResolvedValue(new Response('success', { status: 200 })); + + // Test without baseUrl + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + await enhancedFetch('https://api.example.com/data'); - // Verify the retry used the new token - const retryCallArgs = mockFetch.mock.calls[1]; - const retryHeaders = retryCallArgs[1]?.headers as Headers; - expect(retryHeaders.get("Authorization")).toBe("Bearer new-token"); - }); - - it("should retry request after successful auth on 401 response (without baseUrl)", async () => { - mockProvider.tokens - .mockResolvedValueOnce({ - access_token: "old-token", - token_type: "Bearer", - expires_in: 3600, - }) - .mockResolvedValueOnce({ - access_token: "new-token", - token_type: "Bearer", - expires_in: 3600, - }); - - const unauthorizedResponse = new Response("Unauthorized", { - status: 401, - headers: { "www-authenticate": 'Bearer realm="oauth"' }, + expect(mockFetch).toHaveBeenCalledTimes(1); + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBeNull(); }); - const successResponse = new Response("success", { status: 200 }); - mockFetch - .mockResolvedValueOnce(unauthorizedResponse) - .mockResolvedValueOnce(successResponse); + it('should retry request after successful auth on 401 response (with explicit baseUrl)', async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: 'old-token', + token_type: 'Bearer', + expires_in: 3600 + }) + .mockResolvedValueOnce({ + access_token: 'new-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const unauthorizedResponse = new Response('Unauthorized', { + status: 401, + headers: { 'www-authenticate': 'Bearer realm="oauth"' } + }); + const successResponse = new Response('success', { status: 200 }); + + mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse); - const mockResourceUrl = new URL( - "https://oauth.example.com/.well-known/oauth-protected-resource", - ); - mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl); - mockAuth.mockResolvedValue("AUTHORIZED"); + const mockResourceUrl = new URL('https://oauth.example.com/.well-known/oauth-protected-resource'); + mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl); + mockAuth.mockResolvedValue('AUTHORIZED'); - // Test without baseUrl - should extract from request URL - const enhancedFetch = withOAuth(mockProvider)(mockFetch); + const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch); - const result = await enhancedFetch("https://api.example.com/data"); + const result = await enhancedFetch('https://api.example.com/data'); - expect(result).toBe(successResponse); - expect(mockFetch).toHaveBeenCalledTimes(2); - expect(mockAuth).toHaveBeenCalledWith(mockProvider, { - serverUrl: "https://api.example.com", // Should be extracted from request URL - resourceMetadataUrl: mockResourceUrl, - fetchFn: mockFetch, + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: 'https://api.example.com', + resourceMetadataUrl: mockResourceUrl, + fetchFn: mockFetch + }); + + // Verify the retry used the new token + const retryCallArgs = mockFetch.mock.calls[1]; + const retryHeaders = retryCallArgs[1]?.headers as Headers; + expect(retryHeaders.get('Authorization')).toBe('Bearer new-token'); }); - // Verify the retry used the new token - const retryCallArgs = mockFetch.mock.calls[1]; - const retryHeaders = retryCallArgs[1]?.headers as Headers; - expect(retryHeaders.get("Authorization")).toBe("Bearer new-token"); - }); - - it("should throw UnauthorizedError when auth returns REDIRECT (without baseUrl)", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + it('should retry request after successful auth on 401 response (without baseUrl)', async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: 'old-token', + token_type: 'Bearer', + expires_in: 3600 + }) + .mockResolvedValueOnce({ + access_token: 'new-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const unauthorizedResponse = new Response('Unauthorized', { + status: 401, + headers: { 'www-authenticate': 'Bearer realm="oauth"' } + }); + const successResponse = new Response('success', { status: 200 }); + + mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse); + + const mockResourceUrl = new URL('https://oauth.example.com/.well-known/oauth-protected-resource'); + mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl); + mockAuth.mockResolvedValue('AUTHORIZED'); + + // Test without baseUrl - should extract from request URL + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + const result = await enhancedFetch('https://api.example.com/data'); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: 'https://api.example.com', // Should be extracted from request URL + resourceMetadataUrl: mockResourceUrl, + fetchFn: mockFetch + }); + + // Verify the retry used the new token + const retryCallArgs = mockFetch.mock.calls[1]; + const retryHeaders = retryCallArgs[1]?.headers as Headers; + expect(retryHeaders.get('Authorization')).toBe('Bearer new-token'); }); - mockFetch.mockResolvedValue(new Response("Unauthorized", { status: 401 })); - mockExtractResourceMetadataUrl.mockReturnValue(undefined); - mockAuth.mockResolvedValue("REDIRECT"); + it('should throw UnauthorizedError when auth returns REDIRECT (without baseUrl)', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); - // Test without baseUrl - const enhancedFetch = withOAuth(mockProvider)(mockFetch); + mockFetch.mockResolvedValue(new Response('Unauthorized', { status: 401 })); + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockResolvedValue('REDIRECT'); - await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( - "Authentication requires user authorization - redirect initiated", - ); - }); + // Test without baseUrl + const enhancedFetch = withOAuth(mockProvider)(mockFetch); - it("should throw UnauthorizedError when auth fails", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + await expect(enhancedFetch('https://api.example.com/data')).rejects.toThrow( + 'Authentication requires user authorization - redirect initiated' + ); }); - mockFetch.mockResolvedValue(new Response("Unauthorized", { status: 401 })); - mockExtractResourceMetadataUrl.mockReturnValue(undefined); - mockAuth.mockRejectedValue(new Error("Network error")); - - const enhancedFetch = withOAuth( - mockProvider, - "https://api.example.com", - )(mockFetch); - - await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( - "Failed to re-authenticate: Network error", - ); - }); - - it("should handle persistent 401 responses after auth", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + it('should throw UnauthorizedError when auth fails', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + mockFetch.mockResolvedValue(new Response('Unauthorized', { status: 401 })); + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockRejectedValue(new Error('Network error')); + + const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch); + + await expect(enhancedFetch('https://api.example.com/data')).rejects.toThrow('Failed to re-authenticate: Network error'); }); - // Always return 401 - mockFetch.mockResolvedValue(new Response("Unauthorized", { status: 401 })); - mockExtractResourceMetadataUrl.mockReturnValue(undefined); - mockAuth.mockResolvedValue("AUTHORIZED"); - - const enhancedFetch = withOAuth( - mockProvider, - "https://api.example.com", - )(mockFetch); - - await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( - "Authentication failed for https://api.example.com/data", - ); - - // Should have made initial request + 1 retry after auth = 2 total - expect(mockFetch).toHaveBeenCalledTimes(2); - expect(mockAuth).toHaveBeenCalledTimes(1); - }); - - it("should preserve original request method and body", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + it('should handle persistent 401 responses after auth', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + // Always return 401 + mockFetch.mockResolvedValue(new Response('Unauthorized', { status: 401 })); + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockResolvedValue('AUTHORIZED'); + + const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch); + + await expect(enhancedFetch('https://api.example.com/data')).rejects.toThrow( + 'Authentication failed for https://api.example.com/data' + ); + + // Should have made initial request + 1 retry after auth = 2 total + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledTimes(1); }); - mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + it('should preserve original request method and body', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + mockFetch.mockResolvedValue(new Response('success', { status: 200 })); - const enhancedFetch = withOAuth( - mockProvider, - "https://api.example.com", - )(mockFetch); + const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch); - const requestBody = JSON.stringify({ data: "test" }); - await enhancedFetch("https://api.example.com/data", { - method: "POST", - body: requestBody, - headers: { "Content-Type": "application/json" }, + const requestBody = JSON.stringify({ data: 'test' }); + await enhancedFetch('https://api.example.com/data', { + method: 'POST', + body: requestBody, + headers: { 'Content-Type': 'application/json' } + }); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.example.com/data', + expect.objectContaining({ + method: 'POST', + body: requestBody, + headers: expect.any(Headers) + }) + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Content-Type')).toBe('application/json'); + expect(headers.get('Authorization')).toBe('Bearer test-token'); }); - expect(mockFetch).toHaveBeenCalledWith( - "https://api.example.com/data", - expect.objectContaining({ - method: "POST", - body: requestBody, - headers: expect.any(Headers), - }), - ); - - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Content-Type")).toBe("application/json"); - expect(headers.get("Authorization")).toBe("Bearer test-token"); - }); - - it("should handle non-401 errors normally", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + it('should handle non-401 errors normally', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const serverErrorResponse = new Response('Server Error', { status: 500 }); + mockFetch.mockResolvedValue(serverErrorResponse); + + const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch); + + const result = await enhancedFetch('https://api.example.com/data'); + + expect(result).toBe(serverErrorResponse); + expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockAuth).not.toHaveBeenCalled(); }); - const serverErrorResponse = new Response("Server Error", { status: 500 }); - mockFetch.mockResolvedValue(serverErrorResponse); + it('should handle URL object as input (without baseUrl)', async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + expires_in: 3600 + }); - const enhancedFetch = withOAuth( - mockProvider, - "https://api.example.com", - )(mockFetch); + mockFetch.mockResolvedValue(new Response('success', { status: 200 })); - const result = await enhancedFetch("https://api.example.com/data"); + // Test URL object without baseUrl - should extract origin from URL object + const enhancedFetch = withOAuth(mockProvider)(mockFetch); - expect(result).toBe(serverErrorResponse); - expect(mockFetch).toHaveBeenCalledTimes(1); - expect(mockAuth).not.toHaveBeenCalled(); - }); + await enhancedFetch(new URL('https://api.example.com/data')); - it("should handle URL object as input (without baseUrl)", async () => { - mockProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - expires_in: 3600, + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: expect.any(Headers) + }) + ); }); - mockFetch.mockResolvedValue(new Response("success", { status: 200 })); - - // Test URL object without baseUrl - should extract origin from URL object - const enhancedFetch = withOAuth(mockProvider)(mockFetch); - - await enhancedFetch(new URL("https://api.example.com/data")); - - expect(mockFetch).toHaveBeenCalledWith( - expect.any(URL), - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); - }); - - it("should handle URL object in auth retry (without baseUrl)", async () => { - mockProvider.tokens - .mockResolvedValueOnce({ - access_token: "old-token", - token_type: "Bearer", - expires_in: 3600, - }) - .mockResolvedValueOnce({ - access_token: "new-token", - token_type: "Bearer", - expires_in: 3600, - }); - - const unauthorizedResponse = new Response("Unauthorized", { status: 401 }); - const successResponse = new Response("success", { status: 200 }); - - mockFetch - .mockResolvedValueOnce(unauthorizedResponse) - .mockResolvedValueOnce(successResponse); - - mockExtractResourceMetadataUrl.mockReturnValue(undefined); - mockAuth.mockResolvedValue("AUTHORIZED"); - - const enhancedFetch = withOAuth(mockProvider)(mockFetch); - - const result = await enhancedFetch(new URL("https://api.example.com/data")); - - expect(result).toBe(successResponse); - expect(mockFetch).toHaveBeenCalledTimes(2); - expect(mockAuth).toHaveBeenCalledWith(mockProvider, { - serverUrl: "https://api.example.com", // Should extract origin from URL object - resourceMetadataUrl: undefined, - fetchFn: mockFetch, + it('should handle URL object in auth retry (without baseUrl)', async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: 'old-token', + token_type: 'Bearer', + expires_in: 3600 + }) + .mockResolvedValueOnce({ + access_token: 'new-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const unauthorizedResponse = new Response('Unauthorized', { status: 401 }); + const successResponse = new Response('success', { status: 200 }); + + mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse); + + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockResolvedValue('AUTHORIZED'); + + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + const result = await enhancedFetch(new URL('https://api.example.com/data')); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: 'https://api.example.com', // Should extract origin from URL object + resourceMetadataUrl: undefined, + fetchFn: mockFetch + }); }); - }); }); -describe("withLogging", () => { - let mockFetch: jest.MockedFunction; - let mockLogger: jest.MockedFunction< - (input: { - method: string; - url: string | URL; - status: number; - statusText: string; - duration: number; - requestHeaders?: Headers; - responseHeaders?: Headers; - error?: Error; - }) => void - >; - let consoleErrorSpy: jest.SpyInstance; - let consoleLogSpy: jest.SpyInstance; - - beforeEach(() => { - jest.clearAllMocks(); - - consoleErrorSpy = jest.spyOn(console, "error").mockImplementation(() => {}); - consoleLogSpy = jest.spyOn(console, "log").mockImplementation(() => {}); - - mockFetch = jest.fn(); - mockLogger = jest.fn(); - }); - - afterEach(() => { - consoleErrorSpy.mockRestore(); - consoleLogSpy.mockRestore(); - }); - - it("should log successful requests with default logger", async () => { - const response = new Response("success", { status: 200, statusText: "OK" }); - mockFetch.mockResolvedValue(response); - - const enhancedFetch = withLogging()(mockFetch); - - await enhancedFetch("https://api.example.com/data"); - - expect(consoleLogSpy).toHaveBeenCalledWith( - expect.stringMatching( - /HTTP GET https:\/\/api\.example\.com\/data 200 OK \(\d+\.\d+ms\)/, - ), - ); - }); - - it("should log error responses with default logger", async () => { - const response = new Response("Not Found", { - status: 404, - statusText: "Not Found", +describe('withLogging', () => { + let mockFetch: jest.MockedFunction; + let mockLogger: jest.MockedFunction< + (input: { + method: string; + url: string | URL; + status: number; + statusText: string; + duration: number; + requestHeaders?: Headers; + responseHeaders?: Headers; + error?: Error; + }) => void + >; + let consoleErrorSpy: jest.SpyInstance; + let consoleLogSpy: jest.SpyInstance; + + beforeEach(() => { + jest.clearAllMocks(); + + consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(() => {}); + consoleLogSpy = jest.spyOn(console, 'log').mockImplementation(() => {}); + + mockFetch = jest.fn(); + mockLogger = jest.fn(); }); - mockFetch.mockResolvedValue(response); - const enhancedFetch = withLogging()(mockFetch); + afterEach(() => { + consoleErrorSpy.mockRestore(); + consoleLogSpy.mockRestore(); + }); - await enhancedFetch("https://api.example.com/data"); + it('should log successful requests with default logger', async () => { + const response = new Response('success', { status: 200, statusText: 'OK' }); + mockFetch.mockResolvedValue(response); - expect(consoleErrorSpy).toHaveBeenCalledWith( - expect.stringMatching( - /HTTP GET https:\/\/api\.example\.com\/data 404 Not Found \(\d+\.\d+ms\)/, - ), - ); - }); + const enhancedFetch = withLogging()(mockFetch); - it("should log network errors with default logger", async () => { - const networkError = new Error("Network connection failed"); - mockFetch.mockRejectedValue(networkError); + await enhancedFetch('https://api.example.com/data'); - const enhancedFetch = withLogging()(mockFetch); + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringMatching(/HTTP GET https:\/\/api\.example\.com\/data 200 OK \(\d+\.\d+ms\)/) + ); + }); - await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( - "Network connection failed", - ); + it('should log error responses with default logger', async () => { + const response = new Response('Not Found', { + status: 404, + statusText: 'Not Found' + }); + mockFetch.mockResolvedValue(response); - expect(consoleErrorSpy).toHaveBeenCalledWith( - expect.stringMatching( - /HTTP GET https:\/\/api\.example\.com\/data failed: Network connection failed \(\d+\.\d+ms\)/, - ), - ); - }); + const enhancedFetch = withLogging()(mockFetch); - it("should use custom logger when provided", async () => { - const response = new Response("success", { status: 200, statusText: "OK" }); - mockFetch.mockResolvedValue(response); + await enhancedFetch('https://api.example.com/data'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringMatching(/HTTP GET https:\/\/api\.example\.com\/data 404 Not Found \(\d+\.\d+ms\)/) + ); + }); - const enhancedFetch = withLogging({ logger: mockLogger })(mockFetch); + it('should log network errors with default logger', async () => { + const networkError = new Error('Network connection failed'); + mockFetch.mockRejectedValue(networkError); - await enhancedFetch("https://api.example.com/data", { method: "POST" }); + const enhancedFetch = withLogging()(mockFetch); - expect(mockLogger).toHaveBeenCalledWith({ - method: "POST", - url: "https://api.example.com/data", - status: 200, - statusText: "OK", - duration: expect.any(Number), - requestHeaders: undefined, - responseHeaders: undefined, + await expect(enhancedFetch('https://api.example.com/data')).rejects.toThrow('Network connection failed'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringMatching(/HTTP GET https:\/\/api\.example\.com\/data failed: Network connection failed \(\d+\.\d+ms\)/) + ); }); - expect(consoleLogSpy).not.toHaveBeenCalled(); - }); + it('should use custom logger when provided', async () => { + const response = new Response('success', { status: 200, statusText: 'OK' }); + mockFetch.mockResolvedValue(response); - it("should include request headers when configured", async () => { - const response = new Response("success", { status: 200, statusText: "OK" }); - mockFetch.mockResolvedValue(response); + const enhancedFetch = withLogging({ logger: mockLogger })(mockFetch); - const enhancedFetch = withLogging({ - logger: mockLogger, - includeRequestHeaders: true, - })(mockFetch); + await enhancedFetch('https://api.example.com/data', { method: 'POST' }); - await enhancedFetch("https://api.example.com/data", { - headers: { - Authorization: "Bearer token", - "Content-Type": "application/json", - }, - }); + expect(mockLogger).toHaveBeenCalledWith({ + method: 'POST', + url: 'https://api.example.com/data', + status: 200, + statusText: 'OK', + duration: expect.any(Number), + requestHeaders: undefined, + responseHeaders: undefined + }); - expect(mockLogger).toHaveBeenCalledWith({ - method: "GET", - url: "https://api.example.com/data", - status: 200, - statusText: "OK", - duration: expect.any(Number), - requestHeaders: expect.any(Headers), - responseHeaders: undefined, + expect(consoleLogSpy).not.toHaveBeenCalled(); }); - const logCall = mockLogger.mock.calls[0][0]; - expect(logCall.requestHeaders?.get("Authorization")).toBe("Bearer token"); - expect(logCall.requestHeaders?.get("Content-Type")).toBe( - "application/json", - ); - }); - - it("should include response headers when configured", async () => { - const response = new Response("success", { - status: 200, - statusText: "OK", - headers: { - "Content-Type": "application/json", - "Cache-Control": "no-cache", - }, - }); - mockFetch.mockResolvedValue(response); - - const enhancedFetch = withLogging({ - logger: mockLogger, - includeResponseHeaders: true, - })(mockFetch); - - await enhancedFetch("https://api.example.com/data"); - - const logCall = mockLogger.mock.calls[0][0]; - expect(logCall.responseHeaders?.get("Content-Type")).toBe( - "application/json", - ); - expect(logCall.responseHeaders?.get("Cache-Control")).toBe("no-cache"); - }); - - it("should respect statusLevel option", async () => { - const successResponse = new Response("success", { - status: 200, - statusText: "OK", - }); - const errorResponse = new Response("Server Error", { - status: 500, - statusText: "Internal Server Error", + it('should include request headers when configured', async () => { + const response = new Response('success', { status: 200, statusText: 'OK' }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ + logger: mockLogger, + includeRequestHeaders: true + })(mockFetch); + + await enhancedFetch('https://api.example.com/data', { + headers: { + Authorization: 'Bearer token', + 'Content-Type': 'application/json' + } + }); + + expect(mockLogger).toHaveBeenCalledWith({ + method: 'GET', + url: 'https://api.example.com/data', + status: 200, + statusText: 'OK', + duration: expect.any(Number), + requestHeaders: expect.any(Headers), + responseHeaders: undefined + }); + + const logCall = mockLogger.mock.calls[0][0]; + expect(logCall.requestHeaders?.get('Authorization')).toBe('Bearer token'); + expect(logCall.requestHeaders?.get('Content-Type')).toBe('application/json'); }); - mockFetch - .mockResolvedValueOnce(successResponse) - .mockResolvedValueOnce(errorResponse); - - const enhancedFetch = withLogging({ - logger: mockLogger, - statusLevel: 400, - })(mockFetch); - - // 200 response should not be logged (below statusLevel 400) - await enhancedFetch("https://api.example.com/success"); - expect(mockLogger).not.toHaveBeenCalled(); - - // 500 response should be logged (above statusLevel 400) - await enhancedFetch("https://api.example.com/error"); - expect(mockLogger).toHaveBeenCalledWith({ - method: "GET", - url: "https://api.example.com/error", - status: 500, - statusText: "Internal Server Error", - duration: expect.any(Number), - requestHeaders: undefined, - responseHeaders: undefined, + it('should include response headers when configured', async () => { + const response = new Response('success', { + status: 200, + statusText: 'OK', + headers: { + 'Content-Type': 'application/json', + 'Cache-Control': 'no-cache' + } + }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ + logger: mockLogger, + includeResponseHeaders: true + })(mockFetch); + + await enhancedFetch('https://api.example.com/data'); + + const logCall = mockLogger.mock.calls[0][0]; + expect(logCall.responseHeaders?.get('Content-Type')).toBe('application/json'); + expect(logCall.responseHeaders?.get('Cache-Control')).toBe('no-cache'); }); - }); - - it("should always log network errors regardless of statusLevel", async () => { - const networkError = new Error("Connection timeout"); - mockFetch.mockRejectedValue(networkError); - - const enhancedFetch = withLogging({ - logger: mockLogger, - statusLevel: 500, // Very high log level - })(mockFetch); - - await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( - "Connection timeout", - ); - - expect(mockLogger).toHaveBeenCalledWith({ - method: "GET", - url: "https://api.example.com/data", - status: 0, - statusText: "Network Error", - duration: expect.any(Number), - requestHeaders: undefined, - error: networkError, + + it('should respect statusLevel option', async () => { + const successResponse = new Response('success', { + status: 200, + statusText: 'OK' + }); + const errorResponse = new Response('Server Error', { + status: 500, + statusText: 'Internal Server Error' + }); + + mockFetch.mockResolvedValueOnce(successResponse).mockResolvedValueOnce(errorResponse); + + const enhancedFetch = withLogging({ + logger: mockLogger, + statusLevel: 400 + })(mockFetch); + + // 200 response should not be logged (below statusLevel 400) + await enhancedFetch('https://api.example.com/success'); + expect(mockLogger).not.toHaveBeenCalled(); + + // 500 response should be logged (above statusLevel 400) + await enhancedFetch('https://api.example.com/error'); + expect(mockLogger).toHaveBeenCalledWith({ + method: 'GET', + url: 'https://api.example.com/error', + status: 500, + statusText: 'Internal Server Error', + duration: expect.any(Number), + requestHeaders: undefined, + responseHeaders: undefined + }); }); - }); - it("should include headers in default logger message when configured", async () => { - const response = new Response("success", { - status: 200, - statusText: "OK", - headers: { "Content-Type": "application/json" }, + it('should always log network errors regardless of statusLevel', async () => { + const networkError = new Error('Connection timeout'); + mockFetch.mockRejectedValue(networkError); + + const enhancedFetch = withLogging({ + logger: mockLogger, + statusLevel: 500 // Very high log level + })(mockFetch); + + await expect(enhancedFetch('https://api.example.com/data')).rejects.toThrow('Connection timeout'); + + expect(mockLogger).toHaveBeenCalledWith({ + method: 'GET', + url: 'https://api.example.com/data', + status: 0, + statusText: 'Network Error', + duration: expect.any(Number), + requestHeaders: undefined, + error: networkError + }); }); - mockFetch.mockResolvedValue(response); - const enhancedFetch = withLogging({ - includeRequestHeaders: true, - includeResponseHeaders: true, - })(mockFetch); + it('should include headers in default logger message when configured', async () => { + const response = new Response('success', { + status: 200, + statusText: 'OK', + headers: { 'Content-Type': 'application/json' } + }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ + includeRequestHeaders: true, + includeResponseHeaders: true + })(mockFetch); - await enhancedFetch("https://api.example.com/data", { - headers: { Authorization: "Bearer token" }, - }); + await enhancedFetch('https://api.example.com/data', { + headers: { Authorization: 'Bearer token' } + }); - expect(consoleLogSpy).toHaveBeenCalledWith( - expect.stringContaining("Request Headers: {authorization: Bearer token}"), - ); - expect(consoleLogSpy).toHaveBeenCalledWith( - expect.stringContaining( - "Response Headers: {content-type: application/json}", - ), - ); - }); - - it("should measure request duration accurately", async () => { - // Mock a slow response - const response = new Response("success", { status: 200 }); - mockFetch.mockImplementation(async () => { - await new Promise((resolve) => setTimeout(resolve, 100)); - return response; + expect(consoleLogSpy).toHaveBeenCalledWith(expect.stringContaining('Request Headers: {authorization: Bearer token}')); + expect(consoleLogSpy).toHaveBeenCalledWith(expect.stringContaining('Response Headers: {content-type: application/json}')); }); - const enhancedFetch = withLogging({ logger: mockLogger })(mockFetch); + it('should measure request duration accurately', async () => { + // Mock a slow response + const response = new Response('success', { status: 200 }); + mockFetch.mockImplementation(async () => { + await new Promise(resolve => setTimeout(resolve, 100)); + return response; + }); - await enhancedFetch("https://api.example.com/data"); + const enhancedFetch = withLogging({ logger: mockLogger })(mockFetch); - const logCall = mockLogger.mock.calls[0][0]; - expect(logCall.duration).toBeGreaterThanOrEqual(90); // Allow some margin for timing - }); -}); + await enhancedFetch('https://api.example.com/data'); -describe("applyMiddleware", () => { - let mockFetch: jest.MockedFunction; - - beforeEach(() => { - jest.clearAllMocks(); - mockFetch = jest.fn(); - }); - - it("should compose no middleware correctly", () => { - const response = new Response("success", { status: 200 }); - mockFetch.mockResolvedValue(response); - - const composedFetch = applyMiddlewares()(mockFetch); - - expect(composedFetch).toBe(mockFetch); - }); - - it("should compose single middleware correctly", async () => { - const response = new Response("success", { status: 200 }); - mockFetch.mockResolvedValue(response); - - // Create a middleware that adds a header - const middleware1 = - (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("X-Middleware-1", "applied"); - return next(input, { ...init, headers }); - }; - - const composedFetch = applyMiddlewares(middleware1)(mockFetch); - - await composedFetch("https://api.example.com/data"); - - expect(mockFetch).toHaveBeenCalledWith( - "https://api.example.com/data", - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); - - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("X-Middleware-1")).toBe("applied"); - }); - - it("should compose multiple middleware in order", async () => { - const response = new Response("success", { status: 200 }); - mockFetch.mockResolvedValue(response); - - // Create middleware that add identifying headers - const middleware1 = - (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("X-Middleware-1", "applied"); - return next(input, { ...init, headers }); - }; - - const middleware2 = - (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("X-Middleware-2", "applied"); - return next(input, { ...init, headers }); - }; - - const middleware3 = - (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("X-Middleware-3", "applied"); - return next(input, { ...init, headers }); - }; - - const composedFetch = applyMiddlewares( - middleware1, - middleware2, - middleware3, - )(mockFetch); - - await composedFetch("https://api.example.com/data"); - - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("X-Middleware-1")).toBe("applied"); - expect(headers.get("X-Middleware-2")).toBe("applied"); - expect(headers.get("X-Middleware-3")).toBe("applied"); - }); - - it("should work with real fetch middleware functions", async () => { - const response = new Response("success", { status: 200, statusText: "OK" }); - mockFetch.mockResolvedValue(response); - - // Create middleware that add identifying headers - const oauthMiddleware = - (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("Authorization", "Bearer test-token"); - return next(input, { ...init, headers }); - }; - - // Use custom logger to avoid console output - const mockLogger = jest.fn(); - const composedFetch = applyMiddlewares( - oauthMiddleware, - withLogging({ logger: mockLogger, statusLevel: 0 }), - )(mockFetch); - - await composedFetch("https://api.example.com/data"); - - // Should have both Authorization header and logging - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBe("Bearer test-token"); - expect(mockLogger).toHaveBeenCalledWith({ - method: "GET", - url: "https://api.example.com/data", - status: 200, - statusText: "OK", - duration: expect.any(Number), - requestHeaders: undefined, - responseHeaders: undefined, + const logCall = mockLogger.mock.calls[0][0]; + expect(logCall.duration).toBeGreaterThanOrEqual(90); // Allow some margin for timing }); - }); - - it("should preserve error propagation through middleware", async () => { - const errorMiddleware = - (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { - try { - return await next(input, init); - } catch (error) { - // Add context to the error - throw new Error( - `Middleware error: ${error instanceof Error ? error.message : String(error)}`, - ); - } - }; - - const originalError = new Error("Network failure"); - mockFetch.mockRejectedValue(originalError); - - const composedFetch = applyMiddlewares(errorMiddleware)(mockFetch); - - await expect(composedFetch("https://api.example.com/data")).rejects.toThrow( - "Middleware error: Network failure", - ); - }); }); -describe("Integration Tests", () => { - let mockProvider: jest.Mocked; - let mockFetch: jest.MockedFunction; - - beforeEach(() => { - jest.clearAllMocks(); - - mockProvider = { - get redirectUrl() { - return "http://localhost/callback"; - }, - get clientMetadata() { - return { redirect_uris: ["http://localhost/callback"] }; - }, - tokens: jest.fn(), - saveTokens: jest.fn(), - clientInformation: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn(), - invalidateCredentials: jest.fn(), - }; +describe('applyMiddleware', () => { + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + mockFetch = jest.fn(); + }); - mockFetch = jest.fn(); - }); + it('should compose no middleware correctly', () => { + const response = new Response('success', { status: 200 }); + mockFetch.mockResolvedValue(response); - it("should work with SSE transport pattern", async () => { - // Simulate how SSE transport might use the middleware - mockProvider.tokens.mockResolvedValue({ - access_token: "sse-token", - token_type: "Bearer", - expires_in: 3600, + const composedFetch = applyMiddlewares()(mockFetch); + + expect(composedFetch).toBe(mockFetch); }); - const response = new Response('{"jsonrpc":"2.0","id":1,"result":{}}', { - status: 200, - headers: { "Content-Type": "application/json" }, + it('should compose single middleware correctly', async () => { + const response = new Response('success', { status: 200 }); + mockFetch.mockResolvedValue(response); + + // Create a middleware that adds a header + const middleware1 = (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('X-Middleware-1', 'applied'); + return next(input, { ...init, headers }); + }; + + const composedFetch = applyMiddlewares(middleware1)(mockFetch); + + await composedFetch('https://api.example.com/data'); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.example.com/data', + expect.objectContaining({ + headers: expect.any(Headers) + }) + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('X-Middleware-1')).toBe('applied'); }); - mockFetch.mockResolvedValue(response); - - // Use custom logger to avoid console output - const mockLogger = jest.fn(); - const enhancedFetch = applyMiddlewares( - withOAuth( - mockProvider as OAuthClientProvider, - "https://mcp-server.example.com", - ), - withLogging({ logger: mockLogger, statusLevel: 400 }), // Only log errors - )(mockFetch); - - // Simulate SSE POST request - await enhancedFetch("https://mcp-server.example.com/endpoint", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ - jsonrpc: "2.0", - method: "tools/list", - id: 1, - }), + + it('should compose multiple middleware in order', async () => { + const response = new Response('success', { status: 200 }); + mockFetch.mockResolvedValue(response); + + // Create middleware that add identifying headers + const middleware1 = (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('X-Middleware-1', 'applied'); + return next(input, { ...init, headers }); + }; + + const middleware2 = (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('X-Middleware-2', 'applied'); + return next(input, { ...init, headers }); + }; + + const middleware3 = (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('X-Middleware-3', 'applied'); + return next(input, { ...init, headers }); + }; + + const composedFetch = applyMiddlewares(middleware1, middleware2, middleware3)(mockFetch); + + await composedFetch('https://api.example.com/data'); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('X-Middleware-1')).toBe('applied'); + expect(headers.get('X-Middleware-2')).toBe('applied'); + expect(headers.get('X-Middleware-3')).toBe('applied'); }); - expect(mockFetch).toHaveBeenCalledWith( - "https://mcp-server.example.com/endpoint", - expect.objectContaining({ - method: "POST", - headers: expect.any(Headers), - body: expect.any(String), - }), - ); - - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBe("Bearer sse-token"); - expect(headers.get("Content-Type")).toBe("application/json"); - }); - - it("should work with StreamableHTTP transport pattern", async () => { - // Simulate how StreamableHTTP transport might use the middleware - mockProvider.tokens.mockResolvedValue({ - access_token: "streamable-token", - token_type: "Bearer", - expires_in: 3600, + it('should work with real fetch middleware functions', async () => { + const response = new Response('success', { status: 200, statusText: 'OK' }); + mockFetch.mockResolvedValue(response); + + // Create middleware that add identifying headers + const oauthMiddleware = (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('Authorization', 'Bearer test-token'); + return next(input, { ...init, headers }); + }; + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const composedFetch = applyMiddlewares(oauthMiddleware, withLogging({ logger: mockLogger, statusLevel: 0 }))(mockFetch); + + await composedFetch('https://api.example.com/data'); + + // Should have both Authorization header and logging + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBe('Bearer test-token'); + expect(mockLogger).toHaveBeenCalledWith({ + method: 'GET', + url: 'https://api.example.com/data', + status: 200, + statusText: 'OK', + duration: expect.any(Number), + requestHeaders: undefined, + responseHeaders: undefined + }); }); - const response = new Response(null, { - status: 202, - headers: { "mcp-session-id": "session-123" }, + it('should preserve error propagation through middleware', async () => { + const errorMiddleware = (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + try { + return await next(input, init); + } catch (error) { + // Add context to the error + throw new Error(`Middleware error: ${error instanceof Error ? error.message : String(error)}`); + } + }; + + const originalError = new Error('Network failure'); + mockFetch.mockRejectedValue(originalError); + + const composedFetch = applyMiddlewares(errorMiddleware)(mockFetch); + + await expect(composedFetch('https://api.example.com/data')).rejects.toThrow('Middleware error: Network failure'); }); - mockFetch.mockResolvedValue(response); - - // Use custom logger to avoid console output - const mockLogger = jest.fn(); - const enhancedFetch = applyMiddlewares( - withOAuth( - mockProvider as OAuthClientProvider, - "https://streamable-server.example.com", - ), - withLogging({ - logger: mockLogger, - includeResponseHeaders: true, - statusLevel: 0, - }), - )(mockFetch); - - // Simulate StreamableHTTP initialization request - await enhancedFetch("https://streamable-server.example.com/mcp", { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - }, - body: JSON.stringify({ - jsonrpc: "2.0", - method: "initialize", - params: { protocolVersion: "2025-03-26", clientInfo: { name: "test" } }, - id: 1, - }), +}); + +describe('Integration Tests', () => { + let mockProvider: jest.Mocked; + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + + mockProvider = { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { redirect_uris: ['http://localhost/callback'] }; + }, + tokens: jest.fn(), + saveTokens: jest.fn(), + clientInformation: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + invalidateCredentials: jest.fn() + }; + + mockFetch = jest.fn(); }); - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBe("Bearer streamable-token"); - expect(headers.get("Accept")).toBe("application/json, text/event-stream"); - }); - - it("should handle auth retry in transport-like scenario", async () => { - mockProvider.tokens - .mockResolvedValueOnce({ - access_token: "expired-token", - token_type: "Bearer", - expires_in: 3600, - }) - .mockResolvedValueOnce({ - access_token: "fresh-token", - token_type: "Bearer", - expires_in: 3600, - }); - - const unauthorizedResponse = new Response('{"error":"invalid_token"}', { - status: 401, - headers: { "www-authenticate": 'Bearer realm="mcp"' }, + it('should work with SSE transport pattern', async () => { + // Simulate how SSE transport might use the middleware + mockProvider.tokens.mockResolvedValue({ + access_token: 'sse-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const response = new Response('{"jsonrpc":"2.0","id":1,"result":{}}', { + status: 200, + headers: { 'Content-Type': 'application/json' } + }); + mockFetch.mockResolvedValue(response); + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const enhancedFetch = applyMiddlewares( + withOAuth(mockProvider as OAuthClientProvider, 'https://mcp-server.example.com'), + withLogging({ logger: mockLogger, statusLevel: 400 }) // Only log errors + )(mockFetch); + + // Simulate SSE POST request + await enhancedFetch('https://mcp-server.example.com/endpoint', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'tools/list', + id: 1 + }) + }); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://mcp-server.example.com/endpoint', + expect.objectContaining({ + method: 'POST', + headers: expect.any(Headers), + body: expect.any(String) + }) + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBe('Bearer sse-token'); + expect(headers.get('Content-Type')).toBe('application/json'); }); - const successResponse = new Response( - '{"jsonrpc":"2.0","id":1,"result":{}}', - { - status: 200, - }, - ); - - mockFetch - .mockResolvedValueOnce(unauthorizedResponse) - .mockResolvedValueOnce(successResponse); - - mockExtractResourceMetadataUrl.mockReturnValue( - new URL("https://auth.example.com/.well-known/oauth-protected-resource"), - ); - mockAuth.mockResolvedValue("AUTHORIZED"); - - // Use custom logger to avoid console output - const mockLogger = jest.fn(); - const enhancedFetch = applyMiddlewares( - withOAuth( - mockProvider as OAuthClientProvider, - "https://mcp-server.example.com", - ), - withLogging({ logger: mockLogger, statusLevel: 0 }), - )(mockFetch); - - const result = await enhancedFetch( - "https://mcp-server.example.com/endpoint", - { - method: "POST", - body: JSON.stringify({ jsonrpc: "2.0", method: "test", id: 1 }), - }, - ); - - expect(result).toBe(successResponse); - expect(mockFetch).toHaveBeenCalledTimes(2); - expect(mockAuth).toHaveBeenCalledWith(mockProvider, { - serverUrl: "https://mcp-server.example.com", - resourceMetadataUrl: new URL( - "https://auth.example.com/.well-known/oauth-protected-resource", - ), - fetchFn: mockFetch, + + it('should work with StreamableHTTP transport pattern', async () => { + // Simulate how StreamableHTTP transport might use the middleware + mockProvider.tokens.mockResolvedValue({ + access_token: 'streamable-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const response = new Response(null, { + status: 202, + headers: { 'mcp-session-id': 'session-123' } + }); + mockFetch.mockResolvedValue(response); + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const enhancedFetch = applyMiddlewares( + withOAuth(mockProvider as OAuthClientProvider, 'https://streamable-server.example.com'), + withLogging({ + logger: mockLogger, + includeResponseHeaders: true, + statusLevel: 0 + }) + )(mockFetch); + + // Simulate StreamableHTTP initialization request + await enhancedFetch('https://streamable-server.example.com/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'initialize', + params: { protocolVersion: '2025-03-26', clientInfo: { name: 'test' } }, + id: 1 + }) + }); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBe('Bearer streamable-token'); + expect(headers.get('Accept')).toBe('application/json, text/event-stream'); }); - }); -}); -describe("createMiddleware", () => { - let mockFetch: jest.MockedFunction; + it('should handle auth retry in transport-like scenario', async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: 'expired-token', + token_type: 'Bearer', + expires_in: 3600 + }) + .mockResolvedValueOnce({ + access_token: 'fresh-token', + token_type: 'Bearer', + expires_in: 3600 + }); + + const unauthorizedResponse = new Response('{"error":"invalid_token"}', { + status: 401, + headers: { 'www-authenticate': 'Bearer realm="mcp"' } + }); + const successResponse = new Response('{"jsonrpc":"2.0","id":1,"result":{}}', { + status: 200 + }); - beforeEach(() => { - jest.clearAllMocks(); - mockFetch = jest.fn(); - }); + mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse); - it("should create middleware with cleaner syntax", async () => { - const response = new Response("success", { status: 200 }); - mockFetch.mockResolvedValue(response); + mockExtractResourceMetadataUrl.mockReturnValue(new URL('https://auth.example.com/.well-known/oauth-protected-resource')); + mockAuth.mockResolvedValue('AUTHORIZED'); - const customMiddleware = createMiddleware(async (next, input, init) => { - const headers = new Headers(init?.headers); - headers.set("X-Custom-Header", "custom-value"); - return next(input, { ...init, headers }); + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const enhancedFetch = applyMiddlewares( + withOAuth(mockProvider as OAuthClientProvider, 'https://mcp-server.example.com'), + withLogging({ logger: mockLogger, statusLevel: 0 }) + )(mockFetch); + + const result = await enhancedFetch('https://mcp-server.example.com/endpoint', { + method: 'POST', + body: JSON.stringify({ jsonrpc: '2.0', method: 'test', id: 1 }) + }); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: 'https://mcp-server.example.com', + resourceMetadataUrl: new URL('https://auth.example.com/.well-known/oauth-protected-resource'), + fetchFn: mockFetch + }); }); +}); - const enhancedFetch = customMiddleware(mockFetch); - await enhancedFetch("https://api.example.com/data"); - - expect(mockFetch).toHaveBeenCalledWith( - "https://api.example.com/data", - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); - - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("X-Custom-Header")).toBe("custom-value"); - }); - - it("should support conditional middleware logic", async () => { - const apiResponse = new Response("api response", { status: 200 }); - const publicResponse = new Response("public response", { status: 200 }); - mockFetch - .mockResolvedValueOnce(apiResponse) - .mockResolvedValueOnce(publicResponse); - - const conditionalMiddleware = createMiddleware( - async (next, input, init) => { - const url = typeof input === "string" ? input : input.toString(); - - if (url.includes("/api/")) { - const headers = new Headers(init?.headers); - headers.set("X-API-Version", "v2"); - return next(input, { ...init, headers }); - } - - return next(input, init); - }, - ); - - const enhancedFetch = conditionalMiddleware(mockFetch); - - // Test API route - await enhancedFetch("https://example.com/api/users"); - let callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("X-API-Version")).toBe("v2"); - - // Test non-API route - await enhancedFetch("https://example.com/public/page"); - callArgs = mockFetch.mock.calls[1]; - const maybeHeaders = callArgs[1]?.headers as Headers | undefined; - expect(maybeHeaders?.get("X-API-Version")).toBeUndefined(); - }); - - it("should support short-circuit responses", async () => { - const customMiddleware = createMiddleware(async (next, input, init) => { - const url = typeof input === "string" ? input : input.toString(); - - // Short-circuit for specific URL - if (url.includes("/cached")) { - return new Response("cached data", { status: 200 }); - } - - return next(input, init); +describe('createMiddleware', () => { + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + mockFetch = jest.fn(); }); - const enhancedFetch = customMiddleware(mockFetch); - - // Test cached route (should not call mockFetch) - const cachedResponse = await enhancedFetch( - "https://example.com/cached/data", - ); - expect(await cachedResponse.text()).toBe("cached data"); - expect(mockFetch).not.toHaveBeenCalled(); - - // Test normal route - mockFetch.mockResolvedValue(new Response("fresh data", { status: 200 })); - const normalResponse = await enhancedFetch("https://example.com/normal/data"); - expect(await normalResponse.text()).toBe("fresh data"); - expect(mockFetch).toHaveBeenCalledTimes(1); - }); - - it("should handle response transformation", async () => { - const originalResponse = new Response('{"data": "original"}', { - status: 200, - headers: { "Content-Type": "application/json" }, + it('should create middleware with cleaner syntax', async () => { + const response = new Response('success', { status: 200 }); + mockFetch.mockResolvedValue(response); + + const customMiddleware = createMiddleware(async (next, input, init) => { + const headers = new Headers(init?.headers); + headers.set('X-Custom-Header', 'custom-value'); + return next(input, { ...init, headers }); + }); + + const enhancedFetch = customMiddleware(mockFetch); + await enhancedFetch('https://api.example.com/data'); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.example.com/data', + expect.objectContaining({ + headers: expect.any(Headers) + }) + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('X-Custom-Header')).toBe('custom-value'); }); - mockFetch.mockResolvedValue(originalResponse); - const transformMiddleware = createMiddleware(async (next, input, init) => { - const response = await next(input, init); + it('should support conditional middleware logic', async () => { + const apiResponse = new Response('api response', { status: 200 }); + const publicResponse = new Response('public response', { status: 200 }); + mockFetch.mockResolvedValueOnce(apiResponse).mockResolvedValueOnce(publicResponse); - if (response.headers.get("content-type")?.includes("application/json")) { - const data = await response.json(); - const transformed = { ...data, timestamp: 123456789 }; + const conditionalMiddleware = createMiddleware(async (next, input, init) => { + const url = typeof input === 'string' ? input : input.toString(); - return new Response(JSON.stringify(transformed), { - status: response.status, - statusText: response.statusText, - headers: response.headers, + if (url.includes('/api/')) { + const headers = new Headers(init?.headers); + headers.set('X-API-Version', 'v2'); + return next(input, { ...init, headers }); + } + + return next(input, init); }); - } - return response; - }); + const enhancedFetch = conditionalMiddleware(mockFetch); - const enhancedFetch = transformMiddleware(mockFetch); - const response = await enhancedFetch("https://api.example.com/data"); - const result = await response.json(); + // Test API route + await enhancedFetch('https://example.com/api/users'); + let callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('X-API-Version')).toBe('v2'); - expect(result).toEqual({ - data: "original", - timestamp: 123456789, - }); - }); - - it("should support error handling and recovery", async () => { - let attemptCount = 0; - mockFetch.mockImplementation(async () => { - attemptCount++; - if (attemptCount === 1) { - throw new Error("Network error"); - } - return new Response("success", { status: 200 }); + // Test non-API route + await enhancedFetch('https://example.com/public/page'); + callArgs = mockFetch.mock.calls[1]; + const maybeHeaders = callArgs[1]?.headers as Headers | undefined; + expect(maybeHeaders?.get('X-API-Version')).toBeUndefined(); }); - const retryMiddleware = createMiddleware(async (next, input, init) => { - try { - return await next(input, init); - } catch (error) { - // Retry once on network error - console.log("Retrying request after error:", error); - return await next(input, init); - } + it('should support short-circuit responses', async () => { + const customMiddleware = createMiddleware(async (next, input, init) => { + const url = typeof input === 'string' ? input : input.toString(); + + // Short-circuit for specific URL + if (url.includes('/cached')) { + return new Response('cached data', { status: 200 }); + } + + return next(input, init); + }); + + const enhancedFetch = customMiddleware(mockFetch); + + // Test cached route (should not call mockFetch) + const cachedResponse = await enhancedFetch('https://example.com/cached/data'); + expect(await cachedResponse.text()).toBe('cached data'); + expect(mockFetch).not.toHaveBeenCalled(); + + // Test normal route + mockFetch.mockResolvedValue(new Response('fresh data', { status: 200 })); + const normalResponse = await enhancedFetch('https://example.com/normal/data'); + expect(await normalResponse.text()).toBe('fresh data'); + expect(mockFetch).toHaveBeenCalledTimes(1); }); - const enhancedFetch = retryMiddleware(mockFetch); - const response = await enhancedFetch("https://api.example.com/data"); + it('should handle response transformation', async () => { + const originalResponse = new Response('{"data": "original"}', { + status: 200, + headers: { 'Content-Type': 'application/json' } + }); + mockFetch.mockResolvedValue(originalResponse); + + const transformMiddleware = createMiddleware(async (next, input, init) => { + const response = await next(input, init); + + if (response.headers.get('content-type')?.includes('application/json')) { + const data = await response.json(); + const transformed = { ...data, timestamp: 123456789 }; - expect(await response.text()).toBe("success"); - expect(mockFetch).toHaveBeenCalledTimes(2); - }); + return new Response(JSON.stringify(transformed), { + status: response.status, + statusText: response.statusText, + headers: response.headers + }); + } + + return response; + }); - it("should compose well with other middleware", async () => { - const response = new Response("success", { status: 200 }); - mockFetch.mockResolvedValue(response); + const enhancedFetch = transformMiddleware(mockFetch); + const response = await enhancedFetch('https://api.example.com/data'); + const result = await response.json(); - // Create custom middleware using createMiddleware - const customAuth = createMiddleware(async (next, input, init) => { - const headers = new Headers(init?.headers); - headers.set("Authorization", "Custom token"); - return next(input, { ...init, headers }); + expect(result).toEqual({ + data: 'original', + timestamp: 123456789 + }); }); - const customLogging = createMiddleware(async (next, input, init) => { - const url = typeof input === "string" ? input : input.toString(); - console.log(`Request to: ${url}`); - const response = await next(input, init); - console.log(`Response status: ${response.status}`); - return response; + it('should support error handling and recovery', async () => { + let attemptCount = 0; + mockFetch.mockImplementation(async () => { + attemptCount++; + if (attemptCount === 1) { + throw new Error('Network error'); + } + return new Response('success', { status: 200 }); + }); + + const retryMiddleware = createMiddleware(async (next, input, init) => { + try { + return await next(input, init); + } catch (error) { + // Retry once on network error + console.log('Retrying request after error:', error); + return await next(input, init); + } + }); + + const enhancedFetch = retryMiddleware(mockFetch); + const response = await enhancedFetch('https://api.example.com/data'); + + expect(await response.text()).toBe('success'); + expect(mockFetch).toHaveBeenCalledTimes(2); }); - // Compose with existing middleware - const enhancedFetch = applyMiddlewares( - customAuth, - customLogging, - withLogging({ statusLevel: 400 }), - )(mockFetch); + it('should compose well with other middleware', async () => { + const response = new Response('success', { status: 200 }); + mockFetch.mockResolvedValue(response); - await enhancedFetch("https://api.example.com/data"); + // Create custom middleware using createMiddleware + const customAuth = createMiddleware(async (next, input, init) => { + const headers = new Headers(init?.headers); + headers.set('Authorization', 'Custom token'); + return next(input, { ...init, headers }); + }); + + const customLogging = createMiddleware(async (next, input, init) => { + const url = typeof input === 'string' ? input : input.toString(); + console.log(`Request to: ${url}`); + const response = await next(input, init); + console.log(`Response status: ${response.status}`); + return response; + }); - const callArgs = mockFetch.mock.calls[0]; - const headers = callArgs[1]?.headers as Headers; - expect(headers.get("Authorization")).toBe("Custom token"); - }); + // Compose with existing middleware + const enhancedFetch = applyMiddlewares(customAuth, customLogging, withLogging({ statusLevel: 400 }))(mockFetch); - it("should have access to both input types (string and URL)", async () => { - const response = new Response("success", { status: 200 }); - mockFetch.mockResolvedValue(response); + await enhancedFetch('https://api.example.com/data'); - let capturedInputType: string | undefined; - const inspectMiddleware = createMiddleware(async (next, input, init) => { - capturedInputType = typeof input === "string" ? "string" : "URL"; - return next(input, init); + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get('Authorization')).toBe('Custom token'); }); - const enhancedFetch = inspectMiddleware(mockFetch); + it('should have access to both input types (string and URL)', async () => { + const response = new Response('success', { status: 200 }); + mockFetch.mockResolvedValue(response); - // Test with string input - await enhancedFetch("https://api.example.com/data"); - expect(capturedInputType).toBe("string"); + let capturedInputType: string | undefined; + const inspectMiddleware = createMiddleware(async (next, input, init) => { + capturedInputType = typeof input === 'string' ? 'string' : 'URL'; + return next(input, init); + }); + + const enhancedFetch = inspectMiddleware(mockFetch); - // Test with URL input - await enhancedFetch(new URL("https://api.example.com/data")); - expect(capturedInputType).toBe("URL"); - }); + // Test with string input + await enhancedFetch('https://api.example.com/data'); + expect(capturedInputType).toBe('string'); + + // Test with URL input + await enhancedFetch(new URL('https://api.example.com/data')); + expect(capturedInputType).toBe('URL'); + }); }); diff --git a/src/client/middleware.ts b/src/client/middleware.ts index 3d0661584..a7cbc6c69 100644 --- a/src/client/middleware.ts +++ b/src/client/middleware.ts @@ -1,10 +1,5 @@ -import { - auth, - extractResourceMetadataUrl, - OAuthClientProvider, - UnauthorizedError, -} from "./auth.js"; -import { FetchLike } from "../shared/transport.js"; +import { auth, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from './auth.js'; +import { FetchLike } from '../shared/transport.js'; /** * Middleware function that wraps and enhances fetch functionality. @@ -39,114 +34,106 @@ export type Middleware = (next: FetchLike) => FetchLike; * @returns A fetch middleware function */ export const withOAuth = - (provider: OAuthClientProvider, baseUrl?: string | URL): Middleware => - (next) => { - return async (input, init) => { - const makeRequest = async (): Promise => { - const headers = new Headers(init?.headers); - - // Add authorization header if tokens are available - const tokens = await provider.tokens(); - if (tokens) { - headers.set("Authorization", `Bearer ${tokens.access_token}`); - } - - return await next(input, { ...init, headers }); - }; - - let response = await makeRequest(); - - // Handle 401 responses by attempting re-authentication - if (response.status === 401) { - try { - const resourceMetadataUrl = extractResourceMetadataUrl(response); - - // Use provided baseUrl or extract from request URL - const serverUrl = - baseUrl || - (typeof input === "string" ? new URL(input).origin : input.origin); - - const result = await auth(provider, { - serverUrl, - resourceMetadataUrl, - fetchFn: next, - }); - - if (result === "REDIRECT") { - throw new UnauthorizedError( - "Authentication requires user authorization - redirect initiated", - ); - } - - if (result !== "AUTHORIZED") { - throw new UnauthorizedError( - `Authentication failed with result: ${result}`, - ); - } - - // Retry the request with fresh tokens - response = await makeRequest(); - } catch (error) { - if (error instanceof UnauthorizedError) { - throw error; - } - throw new UnauthorizedError( - `Failed to re-authenticate: ${error instanceof Error ? error.message : String(error)}`, - ); - } - } - - // If we still have a 401 after re-auth attempt, throw an error - if (response.status === 401) { - const url = typeof input === "string" ? input : input.toString(); - throw new UnauthorizedError(`Authentication failed for ${url}`); - } - - return response; + (provider: OAuthClientProvider, baseUrl?: string | URL): Middleware => + next => { + return async (input, init) => { + const makeRequest = async (): Promise => { + const headers = new Headers(init?.headers); + + // Add authorization header if tokens are available + const tokens = await provider.tokens(); + if (tokens) { + headers.set('Authorization', `Bearer ${tokens.access_token}`); + } + + return await next(input, { ...init, headers }); + }; + + let response = await makeRequest(); + + // Handle 401 responses by attempting re-authentication + if (response.status === 401) { + try { + const resourceMetadataUrl = extractResourceMetadataUrl(response); + + // Use provided baseUrl or extract from request URL + const serverUrl = baseUrl || (typeof input === 'string' ? new URL(input).origin : input.origin); + + const result = await auth(provider, { + serverUrl, + resourceMetadataUrl, + fetchFn: next + }); + + if (result === 'REDIRECT') { + throw new UnauthorizedError('Authentication requires user authorization - redirect initiated'); + } + + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(`Authentication failed with result: ${result}`); + } + + // Retry the request with fresh tokens + response = await makeRequest(); + } catch (error) { + if (error instanceof UnauthorizedError) { + throw error; + } + throw new UnauthorizedError(`Failed to re-authenticate: ${error instanceof Error ? error.message : String(error)}`); + } + } + + // If we still have a 401 after re-auth attempt, throw an error + if (response.status === 401) { + const url = typeof input === 'string' ? input : input.toString(); + throw new UnauthorizedError(`Authentication failed for ${url}`); + } + + return response; + }; }; - }; /** * Logger function type for HTTP requests */ export type RequestLogger = (input: { - method: string; - url: string | URL; - status: number; - statusText: string; - duration: number; - requestHeaders?: Headers; - responseHeaders?: Headers; - error?: Error; + method: string; + url: string | URL; + status: number; + statusText: string; + duration: number; + requestHeaders?: Headers; + responseHeaders?: Headers; + error?: Error; }) => void; /** * Configuration options for the logging middleware */ export type LoggingOptions = { - /** - * Custom logger function, defaults to console logging - */ - logger?: RequestLogger; - - /** - * Whether to include request headers in logs - * @default false - */ - includeRequestHeaders?: boolean; - - /** - * Whether to include response headers in logs - * @default false - */ - includeResponseHeaders?: boolean; - - /** - * Status level filter - only log requests with status >= this value - * Set to 0 to log all requests, 400 to log only errors - * @default 0 - */ - statusLevel?: number; + /** + * Custom logger function, defaults to console logging + */ + logger?: RequestLogger; + + /** + * Whether to include request headers in logs + * @default false + */ + includeRequestHeaders?: boolean; + + /** + * Whether to include response headers in logs + * @default false + */ + includeResponseHeaders?: boolean; + + /** + * Status level filter - only log requests with status >= this value + * Set to 0 to log all requests, 400 to log only errors + * @default 0 + */ + statusLevel?: number; }; /** @@ -166,100 +153,82 @@ export type LoggingOptions = { * @returns A fetch middleware function */ export const withLogging = (options: LoggingOptions = {}): Middleware => { - const { - logger, - includeRequestHeaders = false, - includeResponseHeaders = false, - statusLevel = 0, - } = options; - - const defaultLogger: RequestLogger = (input) => { - const { - method, - url, - status, - statusText, - duration, - requestHeaders, - responseHeaders, - error, - } = input; - - let message = error - ? `HTTP ${method} ${url} failed: ${error.message} (${duration}ms)` - : `HTTP ${method} ${url} ${status} ${statusText} (${duration}ms)`; - - // Add headers to message if requested - if (includeRequestHeaders && requestHeaders) { - const reqHeaders = Array.from(requestHeaders.entries()) - .map(([key, value]) => `${key}: ${value}`) - .join(", "); - message += `\n Request Headers: {${reqHeaders}}`; - } - - if (includeResponseHeaders && responseHeaders) { - const resHeaders = Array.from(responseHeaders.entries()) - .map(([key, value]) => `${key}: ${value}`) - .join(", "); - message += `\n Response Headers: {${resHeaders}}`; - } - - if (error || status >= 400) { - // eslint-disable-next-line no-console - console.error(message); - } else { - // eslint-disable-next-line no-console - console.log(message); - } - }; - - const logFn = logger || defaultLogger; - - return (next) => async (input, init) => { - const startTime = performance.now(); - const method = init?.method || "GET"; - const url = typeof input === "string" ? input : input.toString(); - const requestHeaders = includeRequestHeaders - ? new Headers(init?.headers) - : undefined; - - try { - const response = await next(input, init); - const duration = performance.now() - startTime; - - // Only log if status meets the log level threshold - if (response.status >= statusLevel) { - logFn({ - method, - url, - status: response.status, - statusText: response.statusText, - duration, - requestHeaders, - responseHeaders: includeResponseHeaders - ? response.headers - : undefined, - }); - } - - return response; - } catch (error) { - const duration = performance.now() - startTime; - - // Always log errors regardless of log level - logFn({ - method, - url, - status: 0, - statusText: "Network Error", - duration, - requestHeaders, - error: error as Error, - }); - - throw error; - } - }; + const { logger, includeRequestHeaders = false, includeResponseHeaders = false, statusLevel = 0 } = options; + + const defaultLogger: RequestLogger = input => { + const { method, url, status, statusText, duration, requestHeaders, responseHeaders, error } = input; + + let message = error + ? `HTTP ${method} ${url} failed: ${error.message} (${duration}ms)` + : `HTTP ${method} ${url} ${status} ${statusText} (${duration}ms)`; + + // Add headers to message if requested + if (includeRequestHeaders && requestHeaders) { + const reqHeaders = Array.from(requestHeaders.entries()) + .map(([key, value]) => `${key}: ${value}`) + .join(', '); + message += `\n Request Headers: {${reqHeaders}}`; + } + + if (includeResponseHeaders && responseHeaders) { + const resHeaders = Array.from(responseHeaders.entries()) + .map(([key, value]) => `${key}: ${value}`) + .join(', '); + message += `\n Response Headers: {${resHeaders}}`; + } + + if (error || status >= 400) { + // eslint-disable-next-line no-console + console.error(message); + } else { + // eslint-disable-next-line no-console + console.log(message); + } + }; + + const logFn = logger || defaultLogger; + + return next => async (input, init) => { + const startTime = performance.now(); + const method = init?.method || 'GET'; + const url = typeof input === 'string' ? input : input.toString(); + const requestHeaders = includeRequestHeaders ? new Headers(init?.headers) : undefined; + + try { + const response = await next(input, init); + const duration = performance.now() - startTime; + + // Only log if status meets the log level threshold + if (response.status >= statusLevel) { + logFn({ + method, + url, + status: response.status, + statusText: response.statusText, + duration, + requestHeaders, + responseHeaders: includeResponseHeaders ? response.headers : undefined + }); + } + + return response; + } catch (error) { + const duration = performance.now() - startTime; + + // Always log errors regardless of log level + logFn({ + method, + url, + status: 0, + statusText: 'Network Error', + duration, + requestHeaders, + error: error as Error + }); + + throw error; + } + }; }; /** @@ -281,12 +250,10 @@ export const withLogging = (options: LoggingOptions = {}): Middleware => { * @param middleware - Array of fetch middleware to compose into a pipeline * @returns A single composed middleware function */ -export const applyMiddlewares = ( - ...middleware: Middleware[] -): Middleware => { - return (next) => { - return middleware.reduce((handler, mw) => mw(handler), next); - }; +export const applyMiddlewares = (...middleware: Middleware[]): Middleware => { + return next => { + return middleware.reduce((handler, mw) => mw(handler), next); + }; }; /** @@ -347,12 +314,6 @@ export const applyMiddlewares = ( * @param handler - Function that receives the next handler and request parameters * @returns A fetch middleware function */ -export const createMiddleware = ( - handler: ( - next: FetchLike, - input: string | URL, - init?: RequestInit, - ) => Promise, -): Middleware => { - return (next) => (input, init) => handler(next, input as string | URL, init); +export const createMiddleware = (handler: (next: FetchLike, input: string | URL, init?: RequestInit) => Promise): Middleware => { + return next => (input, init) => handler(next, input as string | URL, init); }; diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 4fce9976f..9e4b73e92 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -1,1449 +1,1484 @@ -import { createServer, ServerResponse, type IncomingMessage, type Server } from "http"; -import { AddressInfo } from "net"; -import { JSONRPCMessage } from "../types.js"; -import { SSEClientTransport } from "./sse.js"; -import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { OAuthTokens } from "../shared/auth.js"; -import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; - -describe("SSEClientTransport", () => { - let resourceServer: Server; - let authServer: Server; - let transport: SSEClientTransport; - let resourceBaseUrl: URL; - let authBaseUrl: URL; - let lastServerRequest: IncomingMessage; - let sendServerMessage: ((message: string) => void) | null = null; - - beforeEach((done) => { - // Reset state - lastServerRequest = null as unknown as IncomingMessage; - sendServerMessage = null; - - authServer = createServer((req, res) => { - if (req.url === "/.well-known/oauth-authorization-server") { - res.writeHead(200, { - "Content-Type": "application/json" +import { createServer, ServerResponse, type IncomingMessage, type Server } from 'http'; +import { AddressInfo } from 'net'; +import { JSONRPCMessage } from '../types.js'; +import { SSEClientTransport } from './sse.js'; +import { OAuthClientProvider, UnauthorizedError } from './auth.js'; +import { OAuthTokens } from '../shared/auth.js'; +import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from '../server/auth/errors.js'; + +describe('SSEClientTransport', () => { + let resourceServer: Server; + let authServer: Server; + let transport: SSEClientTransport; + let resourceBaseUrl: URL; + let authBaseUrl: URL; + let lastServerRequest: IncomingMessage; + let sendServerMessage: ((message: string) => void) | null = null; + + beforeEach(done => { + // Reset state + lastServerRequest = null as unknown as IncomingMessage; + sendServerMessage = null; + + authServer = createServer((req, res) => { + if (req.url === '/.well-known/oauth-authorization-server') { + res.writeHead(200, { + 'Content-Type': 'application/json' + }); + res.end( + JSON.stringify({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + ); + return; + } + res.writeHead(401).end(); }); - res.end(JSON.stringify({ - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - registration_endpoint: "https://auth.example.com/register", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - })); - return; - } - res.writeHead(401).end(); - }); - // Create a test server that will receive the EventSource connection - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - // Send SSE headers - res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }); - - // Send the endpoint event - res.write("event: endpoint\n"); - res.write(`data: ${resourceBaseUrl.href}\n\n`); - - // Store reference to send function for tests - sendServerMessage = (message: string) => { - res.write(`data: ${message}\n\n`); - }; - - // Handle request body for POST endpoints - if (req.method === "POST") { - let body = ""; - req.on("data", (chunk) => { - body += chunk; + // Create a test server that will receive the EventSource connection + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + // Send SSE headers + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + + // Send the endpoint event + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + + // Store reference to send function for tests + sendServerMessage = (message: string) => { + res.write(`data: ${message}\n\n`); + }; + + // Handle request body for POST endpoints + if (req.method === 'POST') { + let body = ''; + req.on('data', chunk => { + body += chunk; + }); + req.on('end', () => { + (req as IncomingMessage & { body: string }).body = body; + res.end(); + }); + } }); - req.on("end", () => { - (req as IncomingMessage & { body: string }).body = body; - res.end(); + + // Start server on random port + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + done(); }); - } - }); - // Start server on random port - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - done(); + jest.spyOn(console, 'error').mockImplementation(() => {}); }); - jest.spyOn(console, 'error').mockImplementation(() => {}); - }); + afterEach(async () => { + await transport.close(); + await resourceServer.close(); + await authServer.close(); - afterEach(async () => { - await transport.close(); - await resourceServer.close(); - await authServer.close(); + jest.clearAllMocks(); + }); - jest.clearAllMocks(); - }); + describe('connection handling', () => { + it('establishes SSE connection and receives endpoint', async () => { + transport = new SSEClientTransport(resourceBaseUrl); + await transport.start(); - describe("connection handling", () => { - it("establishes SSE connection and receives endpoint", async () => { - transport = new SSEClientTransport(resourceBaseUrl); - await transport.start(); + expect(lastServerRequest.headers.accept).toBe('text/event-stream'); + expect(lastServerRequest.method).toBe('GET'); + }); - expect(lastServerRequest.headers.accept).toBe("text/event-stream"); - expect(lastServerRequest.method).toBe("GET"); - }); + it('rejects if server returns non-200 status', async () => { + // Create a server that returns 403 + await resourceServer.close(); - it("rejects if server returns non-200 status", async () => { - // Create a server that returns 403 - await resourceServer.close(); + resourceServer = createServer((req, res) => { + res.writeHead(403); + res.end(); + }); - resourceServer = createServer((req, res) => { - res.writeHead(403); - res.end(); - }); + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - await new Promise((resolve) => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + transport = new SSEClientTransport(resourceBaseUrl); + await expect(transport.start()).rejects.toThrow(); }); - }); - transport = new SSEClientTransport(resourceBaseUrl); - await expect(transport.start()).rejects.toThrow(); + it('closes EventSource connection on close()', async () => { + transport = new SSEClientTransport(resourceBaseUrl); + await transport.start(); + + const closePromise = new Promise(resolve => { + lastServerRequest.on('close', resolve); + }); + + await transport.close(); + await closePromise; + }); }); - it("closes EventSource connection on close()", async () => { - transport = new SSEClientTransport(resourceBaseUrl); - await transport.start(); + describe('message handling', () => { + it('receives and parses JSON-RPC messages', async () => { + const receivedMessages: JSONRPCMessage[] = []; + transport = new SSEClientTransport(resourceBaseUrl); + transport.onmessage = msg => receivedMessages.push(msg); - const closePromise = new Promise((resolve) => { - lastServerRequest.on("close", resolve); - }); + await transport.start(); - await transport.close(); - await closePromise; - }); - }); + const testMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'test-1', + method: 'test', + params: { foo: 'bar' } + }; - describe("message handling", () => { - it("receives and parses JSON-RPC messages", async () => { - const receivedMessages: JSONRPCMessage[] = []; - transport = new SSEClientTransport(resourceBaseUrl); - transport.onmessage = (msg) => receivedMessages.push(msg); + sendServerMessage!(JSON.stringify(testMessage)); - await transport.start(); + // Wait for message processing + await new Promise(resolve => setTimeout(resolve, 50)); - const testMessage: JSONRPCMessage = { - jsonrpc: "2.0", - id: "test-1", - method: "test", - params: { foo: "bar" }, - }; + expect(receivedMessages).toHaveLength(1); + expect(receivedMessages[0]).toEqual(testMessage); + }); - sendServerMessage!(JSON.stringify(testMessage)); + it('handles malformed JSON messages', async () => { + const errors: Error[] = []; + transport = new SSEClientTransport(resourceBaseUrl); + transport.onerror = err => errors.push(err); - // Wait for message processing - await new Promise((resolve) => setTimeout(resolve, 50)); + await transport.start(); - expect(receivedMessages).toHaveLength(1); - expect(receivedMessages[0]).toEqual(testMessage); - }); + sendServerMessage!('invalid json'); - it("handles malformed JSON messages", async () => { - const errors: Error[] = []; - transport = new SSEClientTransport(resourceBaseUrl); - transport.onerror = (err) => errors.push(err); + // Wait for message processing + await new Promise(resolve => setTimeout(resolve, 50)); - await transport.start(); + expect(errors).toHaveLength(1); + expect(errors[0].message).toMatch(/JSON/); + }); - sendServerMessage!("invalid json"); + it('handles messages via POST requests', async () => { + transport = new SSEClientTransport(resourceBaseUrl); + await transport.start(); - // Wait for message processing - await new Promise((resolve) => setTimeout(resolve, 50)); + const testMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'test-1', + method: 'test', + params: { foo: 'bar' } + }; - expect(errors).toHaveLength(1); - expect(errors[0].message).toMatch(/JSON/); - }); + await transport.send(testMessage); - it("handles messages via POST requests", async () => { - transport = new SSEClientTransport(resourceBaseUrl); - await transport.start(); - - const testMessage: JSONRPCMessage = { - jsonrpc: "2.0", - id: "test-1", - method: "test", - params: { foo: "bar" }, - }; - - await transport.send(testMessage); - - // Wait for request processing - await new Promise((resolve) => setTimeout(resolve, 50)); - - expect(lastServerRequest.method).toBe("POST"); - expect(lastServerRequest.headers["content-type"]).toBe( - "application/json", - ); - expect( - JSON.parse( - (lastServerRequest as IncomingMessage & { body: string }).body, - ), - ).toEqual(testMessage); - }); + // Wait for request processing + await new Promise(resolve => setTimeout(resolve, 50)); - it("handles POST request failures", async () => { - // Create a server that returns 500 for POST - await resourceServer.close(); - - resourceServer = createServer((req, res) => { - if (req.method === "GET") { - res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }); - res.write("event: endpoint\n"); - res.write(`data: ${resourceBaseUrl.href}\n\n`); - } else { - res.writeHead(500); - res.end("Internal error"); - } - }); - - await new Promise((resolve) => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + expect(lastServerRequest.method).toBe('POST'); + expect(lastServerRequest.headers['content-type']).toBe('application/json'); + expect(JSON.parse((lastServerRequest as IncomingMessage & { body: string }).body)).toEqual(testMessage); }); - }); - transport = new SSEClientTransport(resourceBaseUrl); - await transport.start(); + it('handles POST request failures', async () => { + // Create a server that returns 500 for POST + await resourceServer.close(); + + resourceServer = createServer((req, res) => { + if (req.method === 'GET') { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + } else { + res.writeHead(500); + res.end('Internal error'); + } + }); + + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - const testMessage: JSONRPCMessage = { - jsonrpc: "2.0", - id: "test-1", - method: "test", - params: {}, - }; + transport = new SSEClientTransport(resourceBaseUrl); + await transport.start(); - await expect(transport.send(testMessage)).rejects.toThrow(/500/); + const testMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'test-1', + method: 'test', + params: {} + }; + + await expect(transport.send(testMessage)).rejects.toThrow(/500/); + }); }); - }); - describe("header handling", () => { - it("uses custom fetch implementation from EventSourceInit to add auth headers", async () => { - const authToken = "Bearer test-token"; + describe('header handling', () => { + it('uses custom fetch implementation from EventSourceInit to add auth headers', async () => { + const authToken = 'Bearer test-token'; + + // Create a fetch wrapper that adds auth header + const fetchWithAuth = (url: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('Authorization', authToken); + return fetch(url.toString(), { ...init, headers }); + }; + + transport = new SSEClientTransport(resourceBaseUrl, { + eventSourceInit: { + fetch: fetchWithAuth + } + }); - // Create a fetch wrapper that adds auth header - const fetchWithAuth = (url: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("Authorization", authToken); - return fetch(url.toString(), { ...init, headers }); - }; + await transport.start(); - transport = new SSEClientTransport(resourceBaseUrl, { - eventSourceInit: { - fetch: fetchWithAuth, - }, - }); + // Verify the auth header was received by the server + expect(lastServerRequest.headers.authorization).toBe(authToken); + }); - await transport.start(); + it('uses custom fetch implementation from options', async () => { + const authToken = 'Bearer custom-token'; - // Verify the auth header was received by the server - expect(lastServerRequest.headers.authorization).toBe(authToken); - }); + const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set('Authorization', authToken); + return fetch(url.toString(), { ...init, headers }); + }); - it("uses custom fetch implementation from options", async () => { - const authToken = "Bearer custom-token"; + transport = new SSEClientTransport(resourceBaseUrl, { + fetch: fetchWithAuth + }); - const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("Authorization", authToken); - return fetch(url.toString(), { ...init, headers }); - }); + await transport.start(); - transport = new SSEClientTransport(resourceBaseUrl, { - fetch: fetchWithAuth, - }); + expect(lastServerRequest.headers.authorization).toBe(authToken); - await transport.start(); + // Send a message to verify fetchWithAuth used for POST as well + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; - expect(lastServerRequest.headers.authorization).toBe(authToken); + await transport.send(message); - // Send a message to verify fetchWithAuth used for POST as well - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; + expect(fetchWithAuth).toHaveBeenCalledTimes(2); + expect(lastServerRequest.method).toBe('POST'); + expect(lastServerRequest.headers.authorization).toBe(authToken); + }); - await transport.send(message); + it('passes custom headers to fetch requests', async () => { + const customHeaders = { + Authorization: 'Bearer test-token', + 'X-Custom-Header': 'custom-value' + }; - expect(fetchWithAuth).toHaveBeenCalledTimes(2); - expect(lastServerRequest.method).toBe("POST"); - expect(lastServerRequest.headers.authorization).toBe(authToken); + transport = new SSEClientTransport(resourceBaseUrl, { + requestInit: { + headers: customHeaders + } + }); + + await transport.start(); + + // Store original fetch + const originalFetch = global.fetch; + + try { + // Mock fetch for the message sending test + global.fetch = jest.fn().mockResolvedValue({ + ok: true + }); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; + + await transport.send(message); + + // Verify fetch was called with correct headers + expect(global.fetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: expect.any(Headers) + }) + ); + + const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1].headers; + expect(calledHeaders.get('Authorization')).toBe(customHeaders.Authorization); + expect(calledHeaders.get('X-Custom-Header')).toBe(customHeaders['X-Custom-Header']); + expect(calledHeaders.get('content-type')).toBe('application/json'); + } finally { + // Restore original fetch + global.fetch = originalFetch; + } + }); }); - it("passes custom headers to fetch requests", async () => { - const customHeaders = { - Authorization: "Bearer test-token", - "X-Custom-Header": "custom-value", - }; + describe('auth handling', () => { + const authServerMetadataUrls = ['/.well-known/oauth-authorization-server', '/.well-known/openid-configuration']; + + let mockAuthProvider: jest.Mocked; + + beforeEach(() => { + mockAuthProvider = { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { redirect_uris: ['http://localhost/callback'] }; + }, + clientInformation: jest.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + invalidateCredentials: jest.fn() + }; + }); - transport = new SSEClientTransport(resourceBaseUrl, { - requestInit: { - headers: customHeaders, - }, - }); + it('attaches auth header from provider on SSE connection', async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer' + }); - await transport.start(); + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); - // Store original fetch - const originalFetch = global.fetch; + await transport.start(); - try { - // Mock fetch for the message sending test - global.fetch = jest.fn().mockResolvedValue({ - ok: true, + expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; + it('attaches custom header from provider on initial SSE connection', async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer' + }); + const customHeaders = { + 'X-Custom-Header': 'custom-value' + }; + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders + } + }); - await transport.send(message); - - // Verify fetch was called with correct headers - expect(global.fetch).toHaveBeenCalledWith( - expect.any(URL), - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); - - const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] - .headers; - expect(calledHeaders.get("Authorization")).toBe( - customHeaders.Authorization, - ); - expect(calledHeaders.get("X-Custom-Header")).toBe( - customHeaders["X-Custom-Header"], - ); - expect(calledHeaders.get("content-type")).toBe("application/json"); - } finally { - // Restore original fetch - global.fetch = originalFetch; - } - }); - }); - - describe("auth handling", () => { - const authServerMetadataUrls = [ - "/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration", - ]; - - let mockAuthProvider: jest.Mocked; - - beforeEach(() => { - mockAuthProvider = { - get redirectUrl() { return "http://localhost/callback"; }, - get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, - clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), - tokens: jest.fn(), - saveTokens: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn(), - invalidateCredentials: jest.fn(), - }; - }); + await transport.start(); - it("attaches auth header from provider on SSE connection", async () => { - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer" - }); + expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); + expect(lastServerRequest.headers['x-custom-header']).toBe('custom-value'); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); + it('attaches auth header from provider on POST requests', async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer' + }); - await transport.start(); + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); - expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); - }); + await transport.start(); - it("attaches custom header from provider on initial SSE connection", async () => { - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer" - }); - const customHeaders = { - "X-Custom-Header": "custom-value", - }; - - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - requestInit: { - headers: customHeaders, - }, - }); - - await transport.start(); - - expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); - expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); - }); + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; - it("attaches auth header from provider on POST requests", async () => { - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer" - }); + await transport.send(message); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); + expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); - await transport.start(); + it('attempts auth flow on 401 during SSE connection', async () => { + // Create server that returns 401s + resourceServer.close(); + authServer.close(); + + // Start auth server on random port + await new Promise(resolve => { + authServer.listen(0, '127.0.0.1', () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === '/.well-known/oauth-protected-resource') { + res.writeHead(200, { + 'Content-Type': 'application/json' + }).end( + JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [`${authBaseUrl}`] + }) + ); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + } else { + res.writeHead(401).end(); + } + }); - await transport.send(message); + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); - }); + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); - it("attempts auth flow on 401 during SSE connection", async () => { + await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + }); - // Create server that returns 401s - resourceServer.close(); - authServer.close(); + it('attempts auth flow on 401 during POST request', async () => { + // Create server that accepts SSE but returns 401 on POST + resourceServer.close(); + authServer.close(); + + await new Promise(resolve => { + authServer.listen(0, '127.0.0.1', () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - // Start auth server on random port - await new Promise(resolve => { - authServer.listen(0, "127.0.0.1", () => { - const addr = authServer.address() as AddressInfo; - authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); - }); - }); - - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - if (req.url === "/.well-known/oauth-protected-resource") { - res.writeHead(200, { - 'Content-Type': 'application/json', - }) - .end(JSON.stringify({ - resource: resourceBaseUrl.href, - authorization_servers: [`${authBaseUrl}`], - })); - return; - } - - if (req.url !== "/") { - res.writeHead(404).end(); - } else { - res.writeHead(401).end(); - } - }); - - await new Promise(resolve => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); - }); - }); + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + switch (req.method) { + case 'GET': + if (req.url === '/.well-known/oauth-protected-resource') { + res.writeHead(200, { + 'Content-Type': 'application/json' + }).end( + JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [`${authBaseUrl}`] + }) + ); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + break; + + case 'POST': + res.writeHead(401); + res.end(); + break; + } + }); + + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); - await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); - }); + await transport.start(); - it("attempts auth flow on 401 during POST request", async () => { - // Create server that accepts SSE but returns 401 on POST - resourceServer.close(); - authServer.close(); + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; - await new Promise(resolve => { - authServer.listen(0, "127.0.0.1", () => { - const addr = authServer.address() as AddressInfo; - authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + await expect(() => transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); - }); - - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - switch (req.method) { - case "GET": - if (req.url === "/.well-known/oauth-protected-resource") { - res.writeHead(200, { - 'Content-Type': 'application/json', - }) - .end(JSON.stringify({ - resource: resourceBaseUrl.href, - authorization_servers: [`${authBaseUrl}`], - })); - return; - } - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } + it('respects custom headers when using auth provider', async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer' + }); - res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", + const customHeaders = { + 'X-Custom-Header': 'custom-value' + }; + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders + } }); - res.write("event: endpoint\n"); - res.write(`data: ${resourceBaseUrl.href}\n\n`); - break; - - case "POST": - res.writeHead(401); - res.end(); - break; - } - }); - - await new Promise(resolve => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; + + await transport.send(message); + + expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); + expect(lastServerRequest.headers['x-custom-header']).toBe('custom-value'); }); - }); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); + it('refreshes expired token during SSE connection', async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: 'expired-token', + token_type: 'Bearer', + refresh_token: 'refresh-token' + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation(tokens => { + currentTokens = tokens; + }); - await transport.start(); + // Create server that returns 401 for expired token, then accepts new token + resourceServer.close(); + authServer.close(); + + authServer = createServer((req, res) => { + if (req.url && authServerMetadataUrls.includes(req.url)) { + res.writeHead(404).end(); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token refresh request + let body = ''; + req.on('data', chunk => { + body += chunk; + }); + req.on('end', () => { + const params = new URLSearchParams(body); + if ( + params.get('grant_type') === 'refresh_token' && + params.get('refresh_token') === 'refresh-token' && + params.get('client_id') === 'test-client-id' && + params.get('client_secret') === 'test-client-secret' + ) { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + access_token: 'new-token', + token_type: 'Bearer', + refresh_token: 'new-refresh-token' + }) + ); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + res.writeHead(401).end(); + }); - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; + // Start auth server on random port + await new Promise(resolve => { + authServer.listen(0, '127.0.0.1', () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - await expect(() => transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); - }); + let connectionAttempts = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === '/.well-known/oauth-protected-resource') { + res.writeHead(200, { + 'Content-Type': 'application/json' + }).end( + JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [`${authBaseUrl}`] + }) + ); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + + const auth = req.headers.authorization; + if (auth === 'Bearer expired-token') { + res.writeHead(401).end(); + return; + } + + if (auth === 'Bearer new-token') { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + connectionAttempts++; + return; + } + + res.writeHead(401).end(); + }); - it("respects custom headers when using auth provider", async () => { - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer" - }); + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - const customHeaders = { - "X-Custom-Header": "custom-value", - }; + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - requestInit: { - headers: customHeaders, - }, - }); + await transport.start(); - await transport.start(); + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: 'new-token', + token_type: 'Bearer', + refresh_token: 'new-refresh-token' + }); + expect(connectionAttempts).toBe(1); + expect(lastServerRequest.headers.authorization).toBe('Bearer new-token'); + }); - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; + it('refreshes expired token during POST request', async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: 'expired-token', + token_type: 'Bearer', + refresh_token: 'refresh-token' + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation(tokens => { + currentTokens = tokens; + }); - await transport.send(message); + // Create server that returns 401 for expired token, then accepts new token + resourceServer.close(); + authServer.close(); + + authServer = createServer((req, res) => { + if (req.url && authServerMetadataUrls.includes(req.url)) { + res.writeHead(404).end(); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token refresh request + let body = ''; + req.on('data', chunk => { + body += chunk; + }); + req.on('end', () => { + const params = new URLSearchParams(body); + if ( + params.get('grant_type') === 'refresh_token' && + params.get('refresh_token') === 'refresh-token' && + params.get('client_id') === 'test-client-id' && + params.get('client_secret') === 'test-client-secret' + ) { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + access_token: 'new-token', + token_type: 'Bearer', + refresh_token: 'new-refresh-token' + }) + ); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + res.writeHead(401).end(); + }); - expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); - expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); - }); + // Start auth server on random port + await new Promise(resolve => { + authServer.listen(0, '127.0.0.1', () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - it("refreshes expired token during SSE connection", async () => { - // Mock tokens() to return expired token until saveTokens is called - let currentTokens: OAuthTokens = { - access_token: "expired-token", - token_type: "Bearer", - refresh_token: "refresh-token" - }; - mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.saveTokens.mockImplementation((tokens) => { - currentTokens = tokens; - }); - - // Create server that returns 401 for expired token, then accepts new token - resourceServer.close(); - authServer.close(); - - authServer = createServer((req, res) => { - if (req.url && authServerMetadataUrls.includes(req.url)) { - res.writeHead(404).end(); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token refresh request - let body = ""; - req.on("data", chunk => { body += chunk; }); - req.on("end", () => { - const params = new URLSearchParams(body); - if (params.get("grant_type") === "refresh_token" && - params.get("refresh_token") === "refresh-token" && - params.get("client_id") === "test-client-id" && - params.get("client_secret") === "test-client-secret") { - res.writeHead(200, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ - access_token: "new-token", - token_type: "Bearer", - refresh_token: "new-refresh-token" - })); - } else { - res.writeHead(400).end(); - } - }); - return; - } + let postAttempts = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === '/.well-known/oauth-protected-resource') { + res.writeHead(200, { + 'Content-Type': 'application/json' + }).end( + JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [`${authBaseUrl}`] + }) + ); + return; + } + + switch (req.method) { + case 'GET': + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + break; + + case 'POST': { + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + + const auth = req.headers.authorization; + if (auth === 'Bearer expired-token') { + res.writeHead(401).end(); + return; + } + + if (auth === 'Bearer new-token') { + res.writeHead(200).end(); + postAttempts++; + return; + } + + res.writeHead(401).end(); + break; + } + } + }); - res.writeHead(401).end(); + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); - }); + await transport.start(); - // Start auth server on random port - await new Promise(resolve => { - authServer.listen(0, "127.0.0.1", () => { - const addr = authServer.address() as AddressInfo; - authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; + + await transport.send(message); + + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: 'new-token', + token_type: 'Bearer', + refresh_token: 'new-refresh-token' + }); + expect(postAttempts).toBe(1); + expect(lastServerRequest.headers.authorization).toBe('Bearer new-token'); }); - }); - - let connectionAttempts = 0; - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - if (req.url === "/.well-known/oauth-protected-resource") { - res.writeHead(200, { - 'Content-Type': 'application/json', - }) - .end(JSON.stringify({ - resource: resourceBaseUrl.href, - authorization_servers: [`${authBaseUrl}`], - })); - return; - } - - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } - - const auth = req.headers.authorization; - if (auth === "Bearer expired-token") { - res.writeHead(401).end(); - return; - } - - if (auth === "Bearer new-token") { - res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }); - res.write("event: endpoint\n"); - res.write(`data: ${resourceBaseUrl.href}\n\n`); - connectionAttempts++; - return; - } - - res.writeHead(401).end(); - }); - - await new Promise(resolve => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + + it('redirects to authorization if refresh token flow fails', async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: 'expired-token', + token_type: 'Bearer', + refresh_token: 'refresh-token' + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation(tokens => { + currentTokens = tokens; + }); + + // Create server that returns 401 for all tokens + resourceServer.close(); + authServer.close(); + + authServer = createServer((req, res) => { + if (req.url && authServerMetadataUrls.includes(req.url)) { + res.writeHead(404).end(); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token refresh request - always fail + res.writeHead(400).end(); + return; + } + + res.writeHead(401).end(); + }); + + // Start auth server on random port + await new Promise(resolve => { + authServer.listen(0, '127.0.0.1', () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === '/.well-known/oauth-protected-resource') { + res.writeHead(200, { + 'Content-Type': 'application/json' + }).end( + JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [`${authBaseUrl}`] + }) + ); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider + }); + + await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); }); - }); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); + it('invalidates all credentials on InvalidClientError during token refresh', async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'expired-token', + token_type: 'Bearer', + refresh_token: 'refresh-token' + }); + + let baseUrl = resourceBaseUrl; + + // Create server that returns InvalidClientError on token refresh + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === '/.well-known/oauth-authorization-server' && req.method === 'GET') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + ); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token refresh request - return InvalidClientError + const error = new InvalidClientError('Client authentication failed'); + res.writeHead(400, { 'Content-Type': 'application/json' }).end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); - await transport.start(); + await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ - access_token: "new-token", - token_type: "Bearer", - refresh_token: "new-refresh-token" - }); - expect(connectionAttempts).toBe(1); - expect(lastServerRequest.headers.authorization).toBe("Bearer new-token"); - }); + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider + }); - it("refreshes expired token during POST request", async () => { - // Mock tokens() to return expired token until saveTokens is called - let currentTokens: OAuthTokens = { - access_token: "expired-token", - token_type: "Bearer", - refresh_token: "refresh-token" - }; - mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.saveTokens.mockImplementation((tokens) => { - currentTokens = tokens; - }); - - // Create server that returns 401 for expired token, then accepts new token - resourceServer.close(); - authServer.close(); - - authServer = createServer((req, res) => { - if (req.url && authServerMetadataUrls.includes(req.url)) { - res.writeHead(404).end(); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token refresh request - let body = ""; - req.on("data", chunk => { body += chunk; }); - req.on("end", () => { - const params = new URLSearchParams(body); - if (params.get("grant_type") === "refresh_token" && - params.get("refresh_token") === "refresh-token" && - params.get("client_id") === "test-client-id" && - params.get("client_secret") === "test-client-secret") { - res.writeHead(200, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ - access_token: "new-token", - token_type: "Bearer", - refresh_token: "new-refresh-token" - })); - } else { - res.writeHead(400).end(); - } - }); - return; - } + await expect(() => transport.start()).rejects.toThrow(InvalidClientError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); - res.writeHead(401).end(); + it('invalidates all credentials on UnauthorizedClientError during token refresh', async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'expired-token', + token_type: 'Bearer', + refresh_token: 'refresh-token' + }); + + let baseUrl = resourceBaseUrl; + + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === '/.well-known/oauth-authorization-server' && req.method === 'GET') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + ); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token refresh request - return UnauthorizedClientError + const error = new UnauthorizedClientError('Client not authorized'); + res.writeHead(400, { 'Content-Type': 'application/json' }).end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); - }); + await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - // Start auth server on random port - await new Promise(resolve => { - authServer.listen(0, "127.0.0.1", () => { - const addr = authServer.address() as AddressInfo; - authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider + }); + + await expect(() => transport.start()).rejects.toThrow(UnauthorizedClientError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); }); - }); - - let postAttempts = 0; - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - if (req.url === "/.well-known/oauth-protected-resource") { - res.writeHead(200, { - 'Content-Type': 'application/json', - }) - .end(JSON.stringify({ - resource: resourceBaseUrl.href, - authorization_servers: [`${authBaseUrl}`], - })); - return; - } - - switch (req.method) { - case "GET": - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } - res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", + it('invalidates tokens on InvalidGrantError during token refresh', async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'expired-token', + token_type: 'Bearer', + refresh_token: 'refresh-token' + }); + let baseUrl = resourceBaseUrl; + + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === '/.well-known/oauth-authorization-server' && req.method === 'GET') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + ); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token refresh request - return InvalidGrantError + const error = new InvalidGrantError('Invalid refresh token'); + res.writeHead(400, { 'Content-Type': 'application/json' }).end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== '/') { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); }); - res.write("event: endpoint\n"); - res.write(`data: ${resourceBaseUrl.href}\n\n`); - break; - case "POST": { - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } + await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); - const auth = req.headers.authorization; - if (auth === "Bearer expired-token") { - res.writeHead(401).end(); - return; - } - - if (auth === "Bearer new-token") { - res.writeHead(200).end(); - postAttempts++; - return; - } - - res.writeHead(401).end(); - break; - } - } - }); - - await new Promise(resolve => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider + }); + + await expect(() => transport.start()).rejects.toThrow(InvalidGrantError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); }); - }); - - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); - - await transport.start(); - - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; - - await transport.send(message); - - expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ - access_token: "new-token", - token_type: "Bearer", - refresh_token: "new-refresh-token" - }); - expect(postAttempts).toBe(1); - expect(lastServerRequest.headers.authorization).toBe("Bearer new-token"); }); - it("redirects to authorization if refresh token flow fails", async () => { - // Mock tokens() to return expired token until saveTokens is called - let currentTokens: OAuthTokens = { - access_token: "expired-token", - token_type: "Bearer", - refresh_token: "refresh-token" - }; - mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.saveTokens.mockImplementation((tokens) => { - currentTokens = tokens; - }); - - // Create server that returns 401 for all tokens - resourceServer.close(); - authServer.close(); - - authServer = createServer((req, res) => { - if (req.url && authServerMetadataUrls.includes(req.url)) { - res.writeHead(404).end(); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token refresh request - always fail - res.writeHead(400).end(); - return; - } - - res.writeHead(401).end(); - - }); - - - // Start auth server on random port - await new Promise(resolve => { - authServer.listen(0, "127.0.0.1", () => { - const addr = authServer.address() as AddressInfo; - authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + describe('custom fetch in auth code paths', () => { + let customFetch: jest.MockedFunction; + let globalFetchSpy: jest.SpyInstance; + let mockAuthProvider: jest.Mocked; + let resourceServerHandler: jest.Mock< + void, + [ + IncomingMessage, + ServerResponse & { + req: IncomingMessage; + } + ], + void + >; + + /** + * Helper function to create a mock auth provider with configurable behavior + */ + const createMockAuthProvider = ( + config: { + hasTokens?: boolean; + tokensExpired?: boolean; + hasRefreshToken?: boolean; + clientRegistered?: boolean; + authorizationCode?: string; + } = {} + ): jest.Mocked => { + const tokens = config.hasTokens + ? { + access_token: config.tokensExpired ? 'expired-token' : 'valid-token', + token_type: 'Bearer' as const, + ...(config.hasRefreshToken && { refresh_token: 'refresh-token' }) + } + : undefined; + + const clientInfo = config.clientRegistered + ? { + client_id: 'test-client-id', + client_secret: 'test-client-secret' + } + : undefined; + + return { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { + redirect_uris: ['http://localhost/callback'], + client_name: 'Test Client' + }; + }, + clientInformation: jest.fn().mockResolvedValue(clientInfo), + tokens: jest.fn().mockResolvedValue(tokens), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn().mockResolvedValue('test-verifier'), + invalidateCredentials: jest.fn() + }; + }; + + const createCustomFetchMockAuthServer = async () => { + authServer = createServer((req, res) => { + if (req.url === '/.well-known/oauth-authorization-server') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + issuer: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}`, + authorization_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/authorize`, + token_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/token`, + registration_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/register`, + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + ); + return; + } + + if (req.url === '/token' && req.method === 'POST') { + // Handle token exchange request + let body = ''; + req.on('data', chunk => { + body += chunk; + }); + req.on('end', () => { + const params = new URLSearchParams(body); + if ( + params.get('grant_type') === 'authorization_code' && + params.get('code') === 'test-auth-code' && + params.get('client_id') === 'test-client-id' + ) { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }) + ); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + res.writeHead(404).end(); + }); + + // Start auth server on random port + await new Promise(resolve => { + authServer.listen(0, '127.0.0.1', () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + }; + + const createCustomFetchMockResourceServer = async () => { + // Set up resource server that provides OAuth metadata + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === '/.well-known/oauth-protected-resource') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [authBaseUrl.href] + }) + ); + return; + } + + resourceServerHandler(req, res); + }); + + // Start resource server on random port + await new Promise(resolve => { + resourceServer.listen(0, '127.0.0.1', () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + }; + + beforeEach(async () => { + // Close existing servers to set up custom auth flow servers + resourceServer.close(); + authServer.close(); + + const originalFetch = fetch; + + // Create custom fetch spy that delegates to real fetch + customFetch = jest.fn((url, init) => { + return originalFetch(url.toString(), init); + }); + + // Spy on global fetch to detect unauthorized usage + globalFetchSpy = jest.spyOn(global, 'fetch'); + + // Create mock auth provider with default configuration + mockAuthProvider = createMockAuthProvider({ + hasTokens: false, + clientRegistered: true + }); + + // Set up auth server that handles OAuth discovery and token requests + await createCustomFetchMockAuthServer(); + + // Set up resource server + resourceServerHandler = jest.fn( + ( + _req: IncomingMessage, + res: ServerResponse & { + req: IncomingMessage; + } + ) => { + res.writeHead(404).end(); + } + ); + await createCustomFetchMockResourceServer(); }); - }); - - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - if (req.url === "/.well-known/oauth-protected-resource") { - res.writeHead(200, { - 'Content-Type': 'application/json', - }) - .end(JSON.stringify({ - resource: resourceBaseUrl.href, - authorization_servers: [`${authBaseUrl}`], - })); - return; - } - - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } - res.writeHead(401).end(); - }); - - await new Promise(resolve => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + + afterEach(() => { + globalFetchSpy.mockRestore(); }); - }); - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - }); + it('uses custom fetch during auth flow on SSE connection 401 - no global fetch fallback', async () => { + // Set up resource server that returns 401 on SSE connection and provides OAuth metadata + resourceServerHandler.mockImplementation((req, res) => { + if (req.url === '/') { + // Return 401 to trigger auth flow + res.writeHead(401, { + 'WWW-Authenticate': `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource"` + }); + res.end(); + return; + } + + res.writeHead(404).end(); + }); - await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); - }); + // Create transport with custom fetch and auth provider + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + fetch: customFetch + }); - it("invalidates all credentials on InvalidClientError during token refresh", async () => { - // Mock tokens() to return token with refresh token - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "expired-token", - token_type: "Bearer", - refresh_token: "refresh-token" - }); - - let baseUrl = resourceBaseUrl; - - // Create server that returns InvalidClientError on token refresh - const server = createServer((req, res) => { - lastServerRequest = req; - - // Handle OAuth metadata discovery - if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { - res.writeHead(200, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ - issuer: baseUrl.href, - authorization_endpoint: `${baseUrl.href}authorize`, - token_endpoint: `${baseUrl.href}token`, - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - })); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token refresh request - return InvalidClientError - const error = new InvalidClientError("Client authentication failed"); - res.writeHead(400, { 'Content-Type': 'application/json' }) - .end(JSON.stringify(error.toResponseObject())); - return; - } - - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } - res.writeHead(401).end(); - }); - - await new Promise(resolve => { - server.listen(0, "127.0.0.1", () => { - const addr = server.address() as AddressInfo; - baseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); - }); - }); + // Attempt to start - should trigger auth flow and eventually fail with UnauthorizedError + await expect(transport.start()).rejects.toThrow(UnauthorizedError); - transport = new SSEClientTransport(baseUrl, { - authProvider: mockAuthProvider, - }); + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); - await expect(() => transport.start()).rejects.toThrow(InvalidClientError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); - }); + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); - it("invalidates all credentials on UnauthorizedClientError during token refresh", async () => { - // Mock tokens() to return token with refresh token - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "expired-token", - token_type: "Bearer", - refresh_token: "refresh-token" - }); - - let baseUrl = resourceBaseUrl; - - const server = createServer((req, res) => { - lastServerRequest = req; - - // Handle OAuth metadata discovery - if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { - res.writeHead(200, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ - issuer: baseUrl.href, - authorization_endpoint: `${baseUrl.href}authorize`, - token_endpoint: `${baseUrl.href}token`, - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - })); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token refresh request - return UnauthorizedClientError - const error = new UnauthorizedClientError("Client not authorized"); - res.writeHead(400, { 'Content-Type': 'application/json' }) - .end(JSON.stringify(error.toResponseObject())); - return; - } - - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } - res.writeHead(401).end(); - }); - - await new Promise(resolve => { - server.listen(0, "127.0.0.1", () => { - const addr = server.address() as AddressInfo; - baseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); - }); - }); + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - transport = new SSEClientTransport(baseUrl, { - authProvider: mockAuthProvider, - }); + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); - await expect(() => transport.start()).rejects.toThrow(UnauthorizedClientError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); - }); + // Verify auth provider was called to redirect to authorization + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); - it("invalidates tokens on InvalidGrantError during token refresh", async () => { - // Mock tokens() to return token with refresh token - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "expired-token", - token_type: "Bearer", - refresh_token: "refresh-token" - }); - let baseUrl = resourceBaseUrl; - - const server = createServer((req, res) => { - lastServerRequest = req; - - // Handle OAuth metadata discovery - if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { - res.writeHead(200, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ - issuer: baseUrl.href, - authorization_endpoint: `${baseUrl.href}authorize`, - token_endpoint: `${baseUrl.href}token`, - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - })); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token refresh request - return InvalidGrantError - const error = new InvalidGrantError("Invalid refresh token"); - res.writeHead(400, { 'Content-Type': 'application/json' }) - .end(JSON.stringify(error.toResponseObject())); - return; - } - - if (req.url !== "/") { - res.writeHead(404).end(); - return; - } - res.writeHead(401).end(); - }); - - await new Promise(resolve => { - server.listen(0, "127.0.0.1", () => { - const addr = server.address() as AddressInfo; - baseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + // Global fetch should never have been called + expect(globalFetchSpy).not.toHaveBeenCalled(); }); - }); - transport = new SSEClientTransport(baseUrl, { - authProvider: mockAuthProvider, - }); + it('uses custom fetch during auth flow on POST request 401 - no global fetch fallback', async () => { + // Set up resource server that accepts SSE connection but returns 401 on POST + resourceServerHandler.mockImplementation((req, res) => { + switch (req.method) { + case 'GET': + if (req.url === '/') { + // Accept SSE connection + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + return; + } + break; + + case 'POST': + if (req.url === '/') { + // Return 401 to trigger auth retry + res.writeHead(401, { + 'WWW-Authenticate': `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource"` + }); + res.end(); + return; + } + break; + } + + res.writeHead(404).end(); + }); - await expect(() => transport.start()).rejects.toThrow(InvalidGrantError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); - }); - }); - - describe("custom fetch in auth code paths", () => { - let customFetch: jest.MockedFunction; - let globalFetchSpy: jest.SpyInstance; - let mockAuthProvider: jest.Mocked; - let resourceServerHandler: jest.Mock & { - req: IncomingMessage; - }], void>; - - /** - * Helper function to create a mock auth provider with configurable behavior - */ - const createMockAuthProvider = (config: { - hasTokens?: boolean; - tokensExpired?: boolean; - hasRefreshToken?: boolean; - clientRegistered?: boolean; - authorizationCode?: string; - } = {}): jest.Mocked => { - const tokens = config.hasTokens ? { - access_token: config.tokensExpired ? "expired-token" : "valid-token", - token_type: "Bearer" as const, - ...(config.hasRefreshToken && { refresh_token: "refresh-token" }) - } : undefined; - - const clientInfo = config.clientRegistered ? { - client_id: "test-client-id", - client_secret: "test-client-secret" - } : undefined; - - return { - get redirectUrl() { return "http://localhost/callback"; }, - get clientMetadata() { - return { - redirect_uris: ["http://localhost/callback"], - client_name: "Test Client" - }; - }, - clientInformation: jest.fn().mockResolvedValue(clientInfo), - tokens: jest.fn().mockResolvedValue(tokens), - saveTokens: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn().mockResolvedValue("test-verifier"), - invalidateCredentials: jest.fn(), - }; - }; - - const createCustomFetchMockAuthServer = async () => { - authServer = createServer((req, res) => { - if (req.url === "/.well-known/oauth-authorization-server") { - res.writeHead(200, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ - issuer: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}`, - authorization_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/authorize`, - token_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/token`, - registration_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/register`, - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - })); - return; - } - - if (req.url === "/token" && req.method === "POST") { - // Handle token exchange request - let body = ""; - req.on("data", chunk => { body += chunk; }); - req.on("end", () => { - const params = new URLSearchParams(body); - if (params.get("grant_type") === "authorization_code" && - params.get("code") === "test-auth-code" && - params.get("client_id") === "test-client-id") { - res.writeHead(200, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ - access_token: "new-access-token", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "new-refresh-token" - })); - } else { - res.writeHead(400).end(); - } - }); - return; - } - - res.writeHead(404).end(); - }); - - // Start auth server on random port - await new Promise(resolve => { - authServer.listen(0, "127.0.0.1", () => { - const addr = authServer.address() as AddressInfo; - authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); - }); - }); - }; - - const createCustomFetchMockResourceServer = async () => { - // Set up resource server that provides OAuth metadata - resourceServer = createServer((req, res) => { - lastServerRequest = req; - - if (req.url === "/.well-known/oauth-protected-resource") { - res.writeHead(200, { "Content-Type": "application/json" }); - res.end(JSON.stringify({ - resource: resourceBaseUrl.href, - authorization_servers: [authBaseUrl.href], - })); - return; - } - - resourceServerHandler(req, res); - }); - - // Start resource server on random port - await new Promise(resolve => { - resourceServer.listen(0, "127.0.0.1", () => { - const addr = resourceServer.address() as AddressInfo; - resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); - resolve(); + // Create transport with custom fetch and auth provider + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + fetch: customFetch + }); + + // Start the transport (should succeed) + await transport.start(); + + // Send a message that should trigger 401 and auth retry + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; + + // Attempt to send message - should trigger auth flow and eventually fail + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Should have attempted the POST request that triggered the 401 + const postCalls = customFetchCalls.filter( + ([url, options]) => url.toString() === resourceBaseUrl.href && options?.method === 'POST' + ); + expect(postCalls.length).toBeGreaterThan(0); + + // Verify auth provider was called to redirect to authorization + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + + // Global fetch should never have been called + expect(globalFetchSpy).not.toHaveBeenCalled(); }); - }); - }; - - beforeEach(async () => { - // Close existing servers to set up custom auth flow servers - resourceServer.close(); - authServer.close(); - - const originalFetch = fetch; - - // Create custom fetch spy that delegates to real fetch - customFetch = jest.fn((url, init) => { - return originalFetch(url.toString(), init); - }); - - // Spy on global fetch to detect unauthorized usage - globalFetchSpy = jest.spyOn(global, 'fetch'); - - // Create mock auth provider with default configuration - mockAuthProvider = createMockAuthProvider({ - hasTokens: false, - clientRegistered: true - }); - - // Set up auth server that handles OAuth discovery and token requests - await createCustomFetchMockAuthServer(); - - // Set up resource server - resourceServerHandler = jest.fn((_req: IncomingMessage, res: ServerResponse & { - req: IncomingMessage; - }) => { - res.writeHead(404).end(); - }); - await createCustomFetchMockResourceServer(); - }); - afterEach(() => { - globalFetchSpy.mockRestore(); - }); + it('uses custom fetch in finishAuth method - no global fetch fallback', async () => { + // Create mock auth provider that expects to save tokens + const authProviderWithCode = createMockAuthProvider({ + clientRegistered: true, + authorizationCode: 'test-auth-code' + }); - it("uses custom fetch during auth flow on SSE connection 401 - no global fetch fallback", async () => { - // Set up resource server that returns 401 on SSE connection and provides OAuth metadata - resourceServerHandler.mockImplementation((req, res) => { - if (req.url === "/") { - // Return 401 to trigger auth flow - res.writeHead(401, { - "WWW-Authenticate": `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource"` - }); - res.end(); - return; - } - - res.writeHead(404).end(); - }); - - // Create transport with custom fetch and auth provider - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - fetch: customFetch, - }); - - // Attempt to start - should trigger auth flow and eventually fail with UnauthorizedError - await expect(transport.start()).rejects.toThrow(UnauthorizedError); - - // Verify custom fetch was used - expect(customFetch).toHaveBeenCalled(); - - // Verify specific OAuth endpoints were called with custom fetch - const customFetchCalls = customFetch.mock.calls; - const callUrls = customFetchCalls.map(([url]) => url.toString()); - - // Should have called resource metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - - // Should have called OAuth authorization server metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); - - // Verify auth provider was called to redirect to authorization - expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); - - // Global fetch should never have been called - expect(globalFetchSpy).not.toHaveBeenCalled(); - }); + // Create transport with custom fetch and auth provider + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: authProviderWithCode, + fetch: customFetch + }); - it("uses custom fetch during auth flow on POST request 401 - no global fetch fallback", async () => { - // Set up resource server that accepts SSE connection but returns 401 on POST - resourceServerHandler.mockImplementation((req, res) => { - switch (req.method) { - case "GET": - if (req.url === "/") { - // Accept SSE connection - res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }); - res.write("event: endpoint\n"); - res.write(`data: ${resourceBaseUrl.href}\n\n`); - return; - } - break; - - case "POST": - if (req.url === "/") { - // Return 401 to trigger auth retry - res.writeHead(401, { - "WWW-Authenticate": `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource"` - }); - res.end(); - return; - } - break; - } - - res.writeHead(404).end(); - }); - - // Create transport with custom fetch and auth provider - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: mockAuthProvider, - fetch: customFetch, - }); - - // Start the transport (should succeed) - await transport.start(); - - // Send a message that should trigger 401 and auth retry - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; - - // Attempt to send message - should trigger auth flow and eventually fail - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - - // Verify custom fetch was used - expect(customFetch).toHaveBeenCalled(); - - // Verify specific OAuth endpoints were called with custom fetch - const customFetchCalls = customFetch.mock.calls; - const callUrls = customFetchCalls.map(([url]) => url.toString()); - - // Should have called resource metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - - // Should have called OAuth authorization server metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); - - // Should have attempted the POST request that triggered the 401 - const postCalls = customFetchCalls.filter(([url, options]) => - url.toString() === resourceBaseUrl.href && options?.method === "POST" - ); - expect(postCalls.length).toBeGreaterThan(0); - - // Verify auth provider was called to redirect to authorization - expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); - - // Global fetch should never have been called - expect(globalFetchSpy).not.toHaveBeenCalled(); - }); + // Call finishAuth with authorization code + await transport.finishAuth('test-auth-code'); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - it("uses custom fetch in finishAuth method - no global fetch fallback", async () => { - // Create mock auth provider that expects to save tokens - const authProviderWithCode = createMockAuthProvider({ - clientRegistered: true, - authorizationCode: "test-auth-code" - }); - - // Create transport with custom fetch and auth provider - transport = new SSEClientTransport(resourceBaseUrl, { - authProvider: authProviderWithCode, - fetch: customFetch, - }); - - // Call finishAuth with authorization code - await transport.finishAuth("test-auth-code"); - - // Verify custom fetch was used - expect(customFetch).toHaveBeenCalled(); - - // Verify specific OAuth endpoints were called with custom fetch - const customFetchCalls = customFetch.mock.calls; - const callUrls = customFetchCalls.map(([url]) => url.toString()); - - // Should have called resource metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - - // Should have called OAuth authorization server metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); - - // Should have called token endpoint for authorization code exchange - const tokenCalls = customFetchCalls.filter(([url, options]) => - url.toString().includes('/token') && options?.method === "POST" - ); - expect(tokenCalls.length).toBeGreaterThan(0); - - // Verify tokens were saved - expect(authProviderWithCode.saveTokens).toHaveBeenCalledWith({ - access_token: "new-access-token", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "new-refresh-token" - }); - - // Global fetch should never have been called - expect(globalFetchSpy).not.toHaveBeenCalled(); + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Should have called token endpoint for authorization code exchange + const tokenCalls = customFetchCalls.filter(([url, options]) => url.toString().includes('/token') && options?.method === 'POST'); + expect(tokenCalls.length).toBeGreaterThan(0); + + // Verify tokens were saved + expect(authProviderWithCode.saveTokens).toHaveBeenCalledWith({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }); + + // Global fetch should never have been called + expect(globalFetchSpy).not.toHaveBeenCalled(); + }); }); - }); }); diff --git a/src/client/sse.ts b/src/client/sse.ts index e1c86ccdb..aa4942444 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,57 +1,57 @@ -import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"; -import { Transport, FetchLike } from "../shared/transport.js"; -import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; -import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; +import { EventSource, type ErrorEvent, type EventSourceInit } from 'eventsource'; +import { Transport, FetchLike } from '../shared/transport.js'; +import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js'; +import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { - constructor( - public readonly code: number | undefined, - message: string | undefined, - public readonly event: ErrorEvent, - ) { - super(`SSE error: ${message}`); - } + constructor( + public readonly code: number | undefined, + message: string | undefined, + public readonly event: ErrorEvent + ) { + super(`SSE error: ${message}`); + } } /** * Configuration options for the `SSEClientTransport`. */ export type SSEClientTransportOptions = { - /** - * An OAuth client provider to use for authentication. - * - * When an `authProvider` is specified and the SSE connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. - * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection. - * - * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. - * - * `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. - */ - authProvider?: OAuthClientProvider; - - /** - * Customizes the initial SSE request to the server (the request that begins the stream). - * - * NOTE: Setting this property will prevent an `Authorization` header from - * being automatically attached to the SSE request, if an `authProvider` is - * also given. This can be worked around by setting the `Authorization` header - * manually. - */ - eventSourceInit?: EventSourceInit; - - /** - * Customizes recurring POST requests to the server. - */ - requestInit?: RequestInit; - - /** - * Custom fetch implementation used for all network requests. - */ - fetch?: FetchLike; + /** + * An OAuth client provider to use for authentication. + * + * When an `authProvider` is specified and the SSE connection is started: + * 1. The connection is attempted with any existing access token from the `authProvider`. + * 2. If the access token has expired, the `authProvider` is used to refresh the token. + * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. + * + * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection. + * + * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. + * + * `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. + */ + authProvider?: OAuthClientProvider; + + /** + * Customizes the initial SSE request to the server (the request that begins the stream). + * + * NOTE: Setting this property will prevent an `Authorization` header from + * being automatically attached to the SSE request, if an `authProvider` is + * also given. This can be worked around by setting the `Authorization` header + * manually. + */ + eventSourceInit?: EventSourceInit; + + /** + * Customizes recurring POST requests to the server. + */ + requestInit?: RequestInit; + + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; }; /** @@ -59,220 +59,217 @@ export type SSEClientTransportOptions = { * messages and make separate POST requests for sending messages. */ export class SSEClientTransport implements Transport { - private _eventSource?: EventSource; - private _endpoint?: URL; - private _abortController?: AbortController; - private _url: URL; - private _resourceMetadataUrl?: URL; - private _eventSourceInit?: EventSourceInit; - private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; - private _fetch?: FetchLike; - private _protocolVersion?: string; - - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; - - constructor( - url: URL, - opts?: SSEClientTransportOptions, - ) { - this._url = url; - this._resourceMetadataUrl = undefined; - this._eventSourceInit = opts?.eventSourceInit; - this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; - this._fetch = opts?.fetch; - } - - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError("No auth provider"); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); + private _eventSource?: EventSource; + private _endpoint?: URL; + private _abortController?: AbortController; + private _url: URL; + private _resourceMetadataUrl?: URL; + private _eventSourceInit?: EventSourceInit; + private _requestInit?: RequestInit; + private _authProvider?: OAuthClientProvider; + private _fetch?: FetchLike; + private _protocolVersion?: string; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(url: URL, opts?: SSEClientTransportOptions) { + this._url = url; + this._resourceMetadataUrl = undefined; + this._eventSourceInit = opts?.eventSourceInit; + this._requestInit = opts?.requestInit; + this._authProvider = opts?.authProvider; + this._fetch = opts?.fetch; } - return await this._startOrAuth(); - } - - private async _commonHeaders(): Promise { - const headers: HeadersInit = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers["Authorization"] = `Bearer ${tokens.access_token}`; - } - } - if (this._protocolVersion) { - headers["mcp-protocol-version"] = this._protocolVersion; - } - - return new Headers( - { ...headers, ...this._requestInit?.headers } - ); - } - - private _startOrAuth(): Promise { - const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch - return new Promise((resolve, reject) => { - this._eventSource = new EventSource( - this._url.href, - { - ...this._eventSourceInit, - fetch: async (url, init) => { - const headers = await this._commonHeaders(); - headers.set("Accept", "text/event-stream"); - const response = await fetchImpl(url, { - ...init, - headers, - }) - - if (response.status === 401 && response.headers.has('www-authenticate')) { - this._resourceMetadataUrl = extractResourceMetadataUrl(response); - } - - return response - }, - }, - ); - this._abortController = new AbortController(); - - this._eventSource.onerror = (event) => { - if (event.code === 401 && this._authProvider) { - - this._authThenStart().then(resolve, reject); - return; + private async _authThenStart(): Promise { + if (!this._authProvider) { + throw new UnauthorizedError('No auth provider'); } - const error = new SseError(event.code, event.message, event); - reject(error); - this.onerror?.(error); - }; - - this._eventSource.onopen = () => { - // The connection is open, but we need to wait for the endpoint to be received. - }; - - this._eventSource.addEventListener("endpoint", (event: Event) => { - const messageEvent = event as MessageEvent; - + let result: AuthResult; try { - this._endpoint = new URL(messageEvent.data, this._url); - if (this._endpoint.origin !== this._url.origin) { - throw new Error( - `Endpoint origin does not match connection origin: ${this._endpoint.origin}`, - ); - } + result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + fetchFn: this._fetch + }); } catch (error) { - reject(error); - this.onerror?.(error as Error); + this.onerror?.(error as Error); + throw error; + } - void this.close(); - return; + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); } - resolve(); - }); + return await this._startOrAuth(); + } - this._eventSource.onmessage = (event: Event) => { - const messageEvent = event as MessageEvent; - let message: JSONRPCMessage; - try { - message = JSONRPCMessageSchema.parse(JSON.parse(messageEvent.data)); - } catch (error) { - this.onerror?.(error as Error); - return; + private async _commonHeaders(): Promise { + const headers: HeadersInit = {}; + if (this._authProvider) { + const tokens = await this._authProvider.tokens(); + if (tokens) { + headers['Authorization'] = `Bearer ${tokens.access_token}`; + } + } + if (this._protocolVersion) { + headers['mcp-protocol-version'] = this._protocolVersion; } - this.onmessage?.(message); - }; - }); - } - - async start() { - if (this._eventSource) { - throw new Error( - "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.", - ); + return new Headers({ ...headers, ...this._requestInit?.headers }); } - return await this._startOrAuth(); - } - - /** - * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. - */ - async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError("No auth provider"); + private _startOrAuth(): Promise { + const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch; + return new Promise((resolve, reject) => { + this._eventSource = new EventSource(this._url.href, { + ...this._eventSourceInit, + fetch: async (url, init) => { + const headers = await this._commonHeaders(); + headers.set('Accept', 'text/event-stream'); + const response = await fetchImpl(url, { + ...init, + headers + }); + + if (response.status === 401 && response.headers.has('www-authenticate')) { + this._resourceMetadataUrl = extractResourceMetadataUrl(response); + } + + return response; + } + }); + this._abortController = new AbortController(); + + this._eventSource.onerror = event => { + if (event.code === 401 && this._authProvider) { + this._authThenStart().then(resolve, reject); + return; + } + + const error = new SseError(event.code, event.message, event); + reject(error); + this.onerror?.(error); + }; + + this._eventSource.onopen = () => { + // The connection is open, but we need to wait for the endpoint to be received. + }; + + this._eventSource.addEventListener('endpoint', (event: Event) => { + const messageEvent = event as MessageEvent; + + try { + this._endpoint = new URL(messageEvent.data, this._url); + if (this._endpoint.origin !== this._url.origin) { + throw new Error(`Endpoint origin does not match connection origin: ${this._endpoint.origin}`); + } + } catch (error) { + reject(error); + this.onerror?.(error as Error); + + void this.close(); + return; + } + + resolve(); + }); + + this._eventSource.onmessage = (event: Event) => { + const messageEvent = event as MessageEvent; + let message: JSONRPCMessage; + try { + message = JSONRPCMessageSchema.parse(JSON.parse(messageEvent.data)); + } catch (error) { + this.onerror?.(error as Error); + return; + } + + this.onmessage?.(message); + }; + }); } - const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError("Failed to authorize"); + async start() { + if (this._eventSource) { + throw new Error('SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.'); + } + + return await this._startOrAuth(); } - } - async close(): Promise { - this._abortController?.abort(); - this._eventSource?.close(); - this.onclose?.(); - } + /** + * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. + */ + async finishAuth(authorizationCode: string): Promise { + if (!this._authProvider) { + throw new UnauthorizedError('No auth provider'); + } - async send(message: JSONRPCMessage): Promise { - if (!this._endpoint) { - throw new Error("Not connected"); + const result = await auth(this._authProvider, { + serverUrl: this._url, + authorizationCode, + resourceMetadataUrl: this._resourceMetadataUrl, + fetchFn: this._fetch + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError('Failed to authorize'); + } } - try { - const headers = await this._commonHeaders(); - headers.set("content-type", "application/json"); - const init = { - ...this._requestInit, - method: "POST", - headers, - body: JSON.stringify(message), - signal: this._abortController?.signal, - }; - - const response = await (this._fetch ?? fetch)(this._endpoint, init); - if (!response.ok) { - if (response.status === 401 && this._authProvider) { - - this._resourceMetadataUrl = extractResourceMetadataUrl(response); - - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } + async close(): Promise { + this._abortController?.abort(); + this._eventSource?.close(); + this.onclose?.(); + } - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + async send(message: JSONRPCMessage): Promise { + if (!this._endpoint) { + throw new Error('Not connected'); } - const text = await response.text().catch(() => null); - throw new Error( - `Error POSTing to endpoint (HTTP ${response.status}): ${text}`, - ); - } - } catch (error) { - this.onerror?.(error as Error); - throw error; + try { + const headers = await this._commonHeaders(); + headers.set('content-type', 'application/json'); + const init = { + ...this._requestInit, + method: 'POST', + headers, + body: JSON.stringify(message), + signal: this._abortController?.signal + }; + + const response = await (this._fetch ?? fetch)(this._endpoint, init); + if (!response.ok) { + if (response.status === 401 && this._authProvider) { + this._resourceMetadataUrl = extractResourceMetadataUrl(response); + + const result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + fetchFn: this._fetch + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + + const text = await response.text().catch(() => null); + throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); + } + } catch (error) { + this.onerror?.(error as Error); + throw error; + } } - } - setProtocolVersion(version: string): void { - this._protocolVersion = version; - } + setProtocolVersion(version: string): void { + this._protocolVersion = version; + } } diff --git a/src/client/stdio.test.ts b/src/client/stdio.test.ts index 2e4d92c25..d2f5b5c41 100644 --- a/src/client/stdio.test.ts +++ b/src/client/stdio.test.ts @@ -1,77 +1,77 @@ -import { JSONRPCMessage } from "../types.js"; -import { StdioClientTransport, StdioServerParameters } from "./stdio.js"; +import { JSONRPCMessage } from '../types.js'; +import { StdioClientTransport, StdioServerParameters } from './stdio.js'; // Configure default server parameters based on OS // Uses 'more' command for Windows and 'tee' command for Unix/Linux const getDefaultServerParameters = (): StdioServerParameters => { - if (process.platform === "win32") { - return { command: "more" }; - } - return { command: "/usr/bin/tee" }; + if (process.platform === 'win32') { + return { command: 'more' }; + } + return { command: '/usr/bin/tee' }; }; const serverParameters = getDefaultServerParameters(); -test("should start then close cleanly", async () => { - const client = new StdioClientTransport(serverParameters); - client.onerror = (error) => { - throw error; - }; +test('should start then close cleanly', async () => { + const client = new StdioClientTransport(serverParameters); + client.onerror = error => { + throw error; + }; - let didClose = false; - client.onclose = () => { - didClose = true; - }; + let didClose = false; + client.onclose = () => { + didClose = true; + }; - await client.start(); - expect(didClose).toBeFalsy(); - await client.close(); - expect(didClose).toBeTruthy(); + await client.start(); + expect(didClose).toBeFalsy(); + await client.close(); + expect(didClose).toBeTruthy(); }); -test("should read messages", async () => { - const client = new StdioClientTransport(serverParameters); - client.onerror = (error) => { - throw error; - }; +test('should read messages', async () => { + const client = new StdioClientTransport(serverParameters); + client.onerror = error => { + throw error; + }; - const messages: JSONRPCMessage[] = [ - { - jsonrpc: "2.0", - id: 1, - method: "ping", - }, - { - jsonrpc: "2.0", - method: "notifications/initialized", - }, - ]; + const messages: JSONRPCMessage[] = [ + { + jsonrpc: '2.0', + id: 1, + method: 'ping' + }, + { + jsonrpc: '2.0', + method: 'notifications/initialized' + } + ]; - const readMessages: JSONRPCMessage[] = []; - const finished = new Promise((resolve) => { - client.onmessage = (message) => { - readMessages.push(message); + const readMessages: JSONRPCMessage[] = []; + const finished = new Promise(resolve => { + client.onmessage = message => { + readMessages.push(message); - if (JSON.stringify(message) === JSON.stringify(messages[1])) { - resolve(); - } - }; - }); + if (JSON.stringify(message) === JSON.stringify(messages[1])) { + resolve(); + } + }; + }); - await client.start(); - await client.send(messages[0]); - await client.send(messages[1]); - await finished; - expect(readMessages).toEqual(messages); + await client.start(); + await client.send(messages[0]); + await client.send(messages[1]); + await finished; + expect(readMessages).toEqual(messages); - await client.close(); + await client.close(); }); -test("should return child process pid", async () => { - const client = new StdioClientTransport(serverParameters); +test('should return child process pid', async () => { + const client = new StdioClientTransport(serverParameters); - await client.start(); - expect(client.pid).not.toBeNull(); - await client.close(); - expect(client.pid).toBeNull(); + await client.start(); + expect(client.pid).not.toBeNull(); + await client.close(); + expect(client.pid).toBeNull(); }); diff --git a/src/client/stdio.ts b/src/client/stdio.ts index 62292ce10..d62a3aeb6 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -1,87 +1,87 @@ -import { ChildProcess, IOType } from "node:child_process"; -import spawn from "cross-spawn"; -import process from "node:process"; -import { Stream, PassThrough } from "node:stream"; -import { ReadBuffer, serializeMessage } from "../shared/stdio.js"; -import { Transport } from "../shared/transport.js"; -import { JSONRPCMessage } from "../types.js"; +import { ChildProcess, IOType } from 'node:child_process'; +import spawn from 'cross-spawn'; +import process from 'node:process'; +import { Stream, PassThrough } from 'node:stream'; +import { ReadBuffer, serializeMessage } from '../shared/stdio.js'; +import { Transport } from '../shared/transport.js'; +import { JSONRPCMessage } from '../types.js'; export type StdioServerParameters = { - /** - * The executable to run to start the server. - */ - command: string; - - /** - * Command line arguments to pass to the executable. - */ - args?: string[]; - - /** - * The environment to use when spawning the process. - * - * If not specified, the result of getDefaultEnvironment() will be used. - */ - env?: Record; - - /** - * How to handle stderr of the child process. This matches the semantics of Node's `child_process.spawn`. - * - * The default is "inherit", meaning messages to stderr will be printed to the parent process's stderr. - */ - stderr?: IOType | Stream | number; - - /** - * The working directory to use when spawning the process. - * - * If not specified, the current working directory will be inherited. - */ - cwd?: string; + /** + * The executable to run to start the server. + */ + command: string; + + /** + * Command line arguments to pass to the executable. + */ + args?: string[]; + + /** + * The environment to use when spawning the process. + * + * If not specified, the result of getDefaultEnvironment() will be used. + */ + env?: Record; + + /** + * How to handle stderr of the child process. This matches the semantics of Node's `child_process.spawn`. + * + * The default is "inherit", meaning messages to stderr will be printed to the parent process's stderr. + */ + stderr?: IOType | Stream | number; + + /** + * The working directory to use when spawning the process. + * + * If not specified, the current working directory will be inherited. + */ + cwd?: string; }; /** * Environment variables to inherit by default, if an environment is not explicitly given. */ export const DEFAULT_INHERITED_ENV_VARS = - process.platform === "win32" - ? [ - "APPDATA", - "HOMEDRIVE", - "HOMEPATH", - "LOCALAPPDATA", - "PATH", - "PROCESSOR_ARCHITECTURE", - "SYSTEMDRIVE", - "SYSTEMROOT", - "TEMP", - "USERNAME", - "USERPROFILE", - "PROGRAMFILES", - ] - : /* list inspired by the default env inheritance of sudo */ - ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]; + process.platform === 'win32' + ? [ + 'APPDATA', + 'HOMEDRIVE', + 'HOMEPATH', + 'LOCALAPPDATA', + 'PATH', + 'PROCESSOR_ARCHITECTURE', + 'SYSTEMDRIVE', + 'SYSTEMROOT', + 'TEMP', + 'USERNAME', + 'USERPROFILE', + 'PROGRAMFILES' + ] + : /* list inspired by the default env inheritance of sudo */ + ['HOME', 'LOGNAME', 'PATH', 'SHELL', 'TERM', 'USER']; /** * Returns a default environment object including only environment variables deemed safe to inherit. */ export function getDefaultEnvironment(): Record { - const env: Record = {}; + const env: Record = {}; - for (const key of DEFAULT_INHERITED_ENV_VARS) { - const value = process.env[key]; - if (value === undefined) { - continue; - } + for (const key of DEFAULT_INHERITED_ENV_VARS) { + const value = process.env[key]; + if (value === undefined) { + continue; + } - if (value.startsWith("()")) { - // Skip functions, which are a security risk. - continue; - } + if (value.startsWith('()')) { + // Skip functions, which are a security risk. + continue; + } - env[key] = value; - } + env[key] = value; + } - return env; + return env; } /** @@ -90,151 +90,147 @@ export function getDefaultEnvironment(): Record { * This transport is only available in Node.js environments. */ export class StdioClientTransport implements Transport { - private _process?: ChildProcess; - private _abortController: AbortController = new AbortController(); - private _readBuffer: ReadBuffer = new ReadBuffer(); - private _serverParams: StdioServerParameters; - private _stderrStream: PassThrough | null = null; - - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; - - constructor(server: StdioServerParameters) { - this._serverParams = server; - if (server.stderr === "pipe" || server.stderr === "overlapped") { - this._stderrStream = new PassThrough(); - } - } - - /** - * Starts the server process and prepares to communicate with it. - */ - async start(): Promise { - if (this._process) { - throw new Error( - "StdioClientTransport already started! If using Client class, note that connect() calls start() automatically." - ); + private _process?: ChildProcess; + private _abortController: AbortController = new AbortController(); + private _readBuffer: ReadBuffer = new ReadBuffer(); + private _serverParams: StdioServerParameters; + private _stderrStream: PassThrough | null = null; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(server: StdioServerParameters) { + this._serverParams = server; + if (server.stderr === 'pipe' || server.stderr === 'overlapped') { + this._stderrStream = new PassThrough(); + } } - return new Promise((resolve, reject) => { - this._process = spawn( - this._serverParams.command, - this._serverParams.args ?? [], - { - // merge default env with server env because mcp server needs some env vars - env: { - ...getDefaultEnvironment(), - ...this._serverParams.env, - }, - stdio: ["pipe", "pipe", this._serverParams.stderr ?? "inherit"], - shell: false, - signal: this._abortController.signal, - windowsHide: process.platform === "win32" && isElectron(), - cwd: this._serverParams.cwd, + /** + * Starts the server process and prepares to communicate with it. + */ + async start(): Promise { + if (this._process) { + throw new Error( + 'StdioClientTransport already started! If using Client class, note that connect() calls start() automatically.' + ); } - ); - this._process.on("error", (error) => { - if (error.name === "AbortError") { - // Expected when close() is called. - this.onclose?.(); - return; - } + return new Promise((resolve, reject) => { + this._process = spawn(this._serverParams.command, this._serverParams.args ?? [], { + // merge default env with server env because mcp server needs some env vars + env: { + ...getDefaultEnvironment(), + ...this._serverParams.env + }, + stdio: ['pipe', 'pipe', this._serverParams.stderr ?? 'inherit'], + shell: false, + signal: this._abortController.signal, + windowsHide: process.platform === 'win32' && isElectron(), + cwd: this._serverParams.cwd + }); + + this._process.on('error', error => { + if (error.name === 'AbortError') { + // Expected when close() is called. + this.onclose?.(); + return; + } + + reject(error); + this.onerror?.(error); + }); + + this._process.on('spawn', () => { + resolve(); + }); + + this._process.on('close', _code => { + this._process = undefined; + this.onclose?.(); + }); + + this._process.stdin?.on('error', error => { + this.onerror?.(error); + }); + + this._process.stdout?.on('data', chunk => { + this._readBuffer.append(chunk); + this.processReadBuffer(); + }); + + this._process.stdout?.on('error', error => { + this.onerror?.(error); + }); + + if (this._stderrStream && this._process.stderr) { + this._process.stderr.pipe(this._stderrStream); + } + }); + } - reject(error); - this.onerror?.(error); - }); + /** + * The stderr stream of the child process, if `StdioServerParameters.stderr` was set to "pipe" or "overlapped". + * + * If stderr piping was requested, a PassThrough stream is returned _immediately_, allowing callers to + * attach listeners before the start method is invoked. This prevents loss of any early + * error output emitted by the child process. + */ + get stderr(): Stream | null { + if (this._stderrStream) { + return this._stderrStream; + } - this._process.on("spawn", () => { - resolve(); - }); + return this._process?.stderr ?? null; + } - this._process.on("close", (_code) => { - this._process = undefined; - this.onclose?.(); - }); - - this._process.stdin?.on("error", (error) => { - this.onerror?.(error); - }); - - this._process.stdout?.on("data", (chunk) => { - this._readBuffer.append(chunk); - this.processReadBuffer(); - }); - - this._process.stdout?.on("error", (error) => { - this.onerror?.(error); - }); - - if (this._stderrStream && this._process.stderr) { - this._process.stderr.pipe(this._stderrStream); - } - }); - } - - /** - * The stderr stream of the child process, if `StdioServerParameters.stderr` was set to "pipe" or "overlapped". - * - * If stderr piping was requested, a PassThrough stream is returned _immediately_, allowing callers to - * attach listeners before the start method is invoked. This prevents loss of any early - * error output emitted by the child process. - */ - get stderr(): Stream | null { - if (this._stderrStream) { - return this._stderrStream; + /** + * The child process pid spawned by this transport. + * + * This is only available after the transport has been started. + */ + get pid(): number | null { + return this._process?.pid ?? null; } - return this._process?.stderr ?? null; - } - - /** - * The child process pid spawned by this transport. - * - * This is only available after the transport has been started. - */ - get pid(): number | null { - return this._process?.pid ?? null; - } - - private processReadBuffer() { - while (true) { - try { - const message = this._readBuffer.readMessage(); - if (message === null) { - break; + private processReadBuffer() { + while (true) { + try { + const message = this._readBuffer.readMessage(); + if (message === null) { + break; + } + + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); + } } + } + + async close(): Promise { + this._abortController.abort(); + this._process = undefined; + this._readBuffer.clear(); + } - this.onmessage?.(message); - } catch (error) { - this.onerror?.(error as Error); - } + send(message: JSONRPCMessage): Promise { + return new Promise(resolve => { + if (!this._process?.stdin) { + throw new Error('Not connected'); + } + + const json = serializeMessage(message); + if (this._process.stdin.write(json)) { + resolve(); + } else { + this._process.stdin.once('drain', resolve); + } + }); } - } - - async close(): Promise { - this._abortController.abort(); - this._process = undefined; - this._readBuffer.clear(); - } - - send(message: JSONRPCMessage): Promise { - return new Promise((resolve) => { - if (!this._process?.stdin) { - throw new Error("Not connected"); - } - - const json = serializeMessage(message); - if (this._process.stdin.write(json)) { - resolve(); - } else { - this._process.stdin.once("drain", resolve); - } - }); - } } function isElectron() { - return "type" in process; + return 'type' in process; } diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index fdd35ed3f..52b6e70fc 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,1004 +1,1011 @@ -import { StartSSEOptions, StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; -import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { JSONRPCMessage, JSONRPCRequest } from "../types.js"; -import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; - - -describe("StreamableHTTPClientTransport", () => { - let transport: StreamableHTTPClientTransport; - let mockAuthProvider: jest.Mocked; - - beforeEach(() => { - mockAuthProvider = { - get redirectUrl() { return "http://localhost/callback"; }, - get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, - clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), - tokens: jest.fn(), - saveTokens: jest.fn(), - redirectToAuthorization: jest.fn(), - saveCodeVerifier: jest.fn(), - codeVerifier: jest.fn(), - invalidateCredentials: jest.fn(), - }; - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { authProvider: mockAuthProvider }); - jest.spyOn(global, "fetch"); - }); - - afterEach(async () => { - await transport.close().catch(() => { }); - jest.clearAllMocks(); - }); - - it("should send JSON-RPC messages via POST", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers(), - }); +import { StartSSEOptions, StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from './streamableHttp.js'; +import { OAuthClientProvider, UnauthorizedError } from './auth.js'; +import { JSONRPCMessage, JSONRPCRequest } from '../types.js'; +import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from '../server/auth/errors.js'; - await transport.send(message); - - expect(global.fetch).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - method: "POST", - headers: expect.any(Headers), - body: JSON.stringify(message) - }) - ); - }); - - it("should send batch messages", async () => { - const messages: JSONRPCMessage[] = [ - { jsonrpc: "2.0", method: "test1", params: {}, id: "id1" }, - { jsonrpc: "2.0", method: "test2", params: {}, id: "id2" } - ]; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: null +describe('StreamableHTTPClientTransport', () => { + let transport: StreamableHTTPClientTransport; + let mockAuthProvider: jest.Mocked; + + beforeEach(() => { + mockAuthProvider = { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { redirect_uris: ['http://localhost/callback'] }; + }, + clientInformation: jest.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + invalidateCredentials: jest.fn() + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); + jest.spyOn(global, 'fetch'); }); - await transport.send(messages); - - expect(global.fetch).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - method: "POST", - headers: expect.any(Headers), - body: JSON.stringify(messages) - }) - ); - }); - - it("should store session ID received during initialization", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client", version: "1.0" }, - protocolVersion: "2025-03-26" - }, - id: "init-id" - }; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), + afterEach(async () => { + await transport.close().catch(() => {}); + jest.clearAllMocks(); }); - await transport.send(message); + it('should send JSON-RPC messages via POST', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); - // Send a second message that should include the session ID - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + await transport.send(message); - await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); - - // Check that second request included session ID header - const calls = (global.fetch as jest.Mock).mock.calls; - const lastCall = calls[calls.length - 1]; - expect(lastCall[1].headers).toBeDefined(); - expect(lastCall[1].headers.get("mcp-session-id")).toBe("test-session-id"); - }); - - it("should terminate session with DELETE request", async () => { - // First, simulate getting a session ID - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client", version: "1.0" }, - protocolVersion: "2025-03-26" - }, - id: "init-id" - }; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), + expect(global.fetch).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + method: 'POST', + headers: expect.any(Headers), + body: JSON.stringify(message) + }) + ); }); - await transport.send(message); - expect(transport.sessionId).toBe("test-session-id"); + it('should send batch messages', async () => { + const messages: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'test1', params: {}, id: 'id1' }, + { jsonrpc: '2.0', method: 'test2', params: {}, id: 'id2' } + ]; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: null + }); - // Now terminate the session - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers() - }); + await transport.send(messages); - await transport.terminateSession(); - - // Verify the DELETE request was sent with the session ID - const calls = (global.fetch as jest.Mock).mock.calls; - const lastCall = calls[calls.length - 1]; - expect(lastCall[1].method).toBe("DELETE"); - expect(lastCall[1].headers.get("mcp-session-id")).toBe("test-session-id"); - - // The session ID should be cleared after successful termination - expect(transport.sessionId).toBeUndefined(); - }); - - it("should handle 405 response when server doesn't support session termination", async () => { - // First, simulate getting a session ID - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client", version: "1.0" }, - protocolVersion: "2025-03-26" - }, - id: "init-id" - }; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), + expect(global.fetch).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + method: 'POST', + headers: expect.any(Headers), + body: JSON.stringify(messages) + }) + ); }); - await transport.send(message); + it('should store session ID received during initialization', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26' + }, + id: 'init-id' + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) + }); - // Now terminate the session, but server responds with 405 - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: false, - status: 405, - statusText: "Method Not Allowed", - headers: new Headers() - }); + await transport.send(message); - await expect(transport.terminateSession()).resolves.not.toThrow(); - }); - - it("should handle 404 response when session expires", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: false, - status: 404, - statusText: "Not Found", - text: () => Promise.resolve("Session not found"), - headers: new Headers() - }); + // Send a second message that should include the session ID + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); - const errorSpy = jest.fn(); - transport.onerror = errorSpy; - - await expect(transport.send(message)).rejects.toThrow("Error POSTing to endpoint (HTTP 404)"); - expect(errorSpy).toHaveBeenCalled(); - }); - - it("should handle non-streaming JSON response", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - const responseMessage: JSONRPCMessage = { - jsonrpc: "2.0", - result: { success: true }, - id: "test-id" - }; - - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "application/json" }), - json: () => Promise.resolve(responseMessage) + // Check that second request included session ID header + const calls = (global.fetch as jest.Mock).mock.calls; + const lastCall = calls[calls.length - 1]; + expect(lastCall[1].headers).toBeDefined(); + expect(lastCall[1].headers.get('mcp-session-id')).toBe('test-session-id'); }); - const messageSpy = jest.fn(); - transport.onmessage = messageSpy; + it('should terminate session with DELETE request', async () => { + // First, simulate getting a session ID + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26' + }, + id: 'init-id' + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) + }); - await transport.send(message); + await transport.send(message); + expect(transport.sessionId).toBe('test-session-id'); + + // Now terminate the session + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers() + }); - expect(messageSpy).toHaveBeenCalledWith(responseMessage); - }); + await transport.terminateSession(); - it("should attempt initial GET connection and handle 405 gracefully", async () => { - // Mock the server not supporting GET for SSE (returning 405) - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: false, - status: 405, - statusText: "Method Not Allowed" + // Verify the DELETE request was sent with the session ID + const calls = (global.fetch as jest.Mock).mock.calls; + const lastCall = calls[calls.length - 1]; + expect(lastCall[1].method).toBe('DELETE'); + expect(lastCall[1].headers.get('mcp-session-id')).toBe('test-session-id'); + + // The session ID should be cleared after successful termination + expect(transport.sessionId).toBeUndefined(); }); - // We expect the 405 error to be caught and handled gracefully - // This should not throw an error that breaks the transport - await transport.start(); - await expect(transport["_startOrAuthSse"]({})).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed"); - // Check that GET was attempted - expect(global.fetch).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - method: "GET", - headers: expect.any(Headers) - }) - ); - - // Verify transport still works after 405 - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() + it("should handle 405 response when server doesn't support session termination", async () => { + // First, simulate getting a session ID + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26' + }, + id: 'init-id' + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) + }); + + await transport.send(message); + + // Now terminate the session, but server responds with 405 + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: false, + status: 405, + statusText: 'Method Not Allowed', + headers: new Headers() + }); + + await expect(transport.terminateSession()).resolves.not.toThrow(); }); - await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); - expect(global.fetch).toHaveBeenCalledTimes(2); - }); - - it("should handle successful initial GET connection for SSE", async () => { - // Set up readable stream for SSE events - const encoder = new TextEncoder(); - const stream = new ReadableStream({ - start(controller) { - // Send a server notification via SSE - const event = "event: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n"; - controller.enqueue(encoder.encode(event)); - } + it('should handle 404 response when session expires', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: false, + status: 404, + statusText: 'Not Found', + text: () => Promise.resolve('Session not found'), + headers: new Headers() + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + await expect(transport.send(message)).rejects.toThrow('Error POSTing to endpoint (HTTP 404)'); + expect(errorSpy).toHaveBeenCalled(); }); - // Mock successful GET connection - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: stream + it('should handle non-streaming JSON response', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + const responseMessage: JSONRPCMessage = { + jsonrpc: '2.0', + result: { success: true }, + id: 'test-id' + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'application/json' }), + json: () => Promise.resolve(responseMessage) + }); + + const messageSpy = jest.fn(); + transport.onmessage = messageSpy; + + await transport.send(message); + + expect(messageSpy).toHaveBeenCalledWith(responseMessage); }); - const messageSpy = jest.fn(); - transport.onmessage = messageSpy; - - await transport.start(); - await transport["_startOrAuthSse"]({}); - - // Give time for the SSE event to be processed - await new Promise(resolve => setTimeout(resolve, 50)); - - expect(messageSpy).toHaveBeenCalledWith( - expect.objectContaining({ - jsonrpc: "2.0", - method: "serverNotification", - params: {} - }) - ); - }); - - it("should handle multiple concurrent SSE streams", async () => { - // Mock two POST requests that return SSE streams - const makeStream = (id: string) => { - const encoder = new TextEncoder(); - return new ReadableStream({ - start(controller) { - const event = `event: message\ndata: {"jsonrpc": "2.0", "result": {"id": "${id}"}, "id": "${id}"}\n\n`; - controller.enqueue(encoder.encode(event)); - } - }); - }; - - (global.fetch as jest.Mock) - .mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: makeStream("request1") - }) - .mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: makeStream("request2") - }); - - const messageSpy = jest.fn(); - transport.onmessage = messageSpy; - - // Send two concurrent requests - await Promise.all([ - transport.send({ jsonrpc: "2.0", method: "test1", params: {}, id: "request1" }), - transport.send({ jsonrpc: "2.0", method: "test2", params: {}, id: "request2" }) - ]); - - // Give time for SSE processing - await new Promise(resolve => setTimeout(resolve, 100)); - - // Both streams should have delivered their messages - expect(messageSpy).toHaveBeenCalledTimes(2); - - // Verify received messages without assuming specific order - expect(messageSpy.mock.calls.some(call => { - const msg = call[0]; - return msg.id === "request1" && msg.result?.id === "request1"; - })).toBe(true); - - expect(messageSpy.mock.calls.some(call => { - const msg = call[0]; - return msg.id === "request2" && msg.result?.id === "request2"; - })).toBe(true); - }); - - it("should support custom reconnection options", () => { - // Create a transport with custom reconnection options - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - reconnectionOptions: { - initialReconnectionDelay: 500, - maxReconnectionDelay: 10000, - reconnectionDelayGrowFactor: 2, - maxRetries: 5, - } + it('should attempt initial GET connection and handle 405 gracefully', async () => { + // Mock the server not supporting GET for SSE (returning 405) + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: false, + status: 405, + statusText: 'Method Not Allowed' + }); + + // We expect the 405 error to be caught and handled gracefully + // This should not throw an error that breaks the transport + await transport.start(); + await expect(transport['_startOrAuthSse']({})).resolves.not.toThrow('Failed to open SSE stream: Method Not Allowed'); + // Check that GET was attempted + expect(global.fetch).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + method: 'GET', + headers: expect.any(Headers) + }) + ); + + // Verify transport still works after 405 + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); + expect(global.fetch).toHaveBeenCalledTimes(2); }); - // Verify options were set correctly (checking implementation details) - // Access private properties for testing - const transportInstance = transport as unknown as { - _reconnectionOptions: StreamableHTTPReconnectionOptions; - }; - expect(transportInstance._reconnectionOptions.initialReconnectionDelay).toBe(500); - expect(transportInstance._reconnectionOptions.maxRetries).toBe(5); - }); - - it("should pass lastEventId when reconnecting", async () => { - // Create a fresh transport - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp")); - - // Mock fetch to verify headers sent - const fetchSpy = global.fetch as jest.Mock; - fetchSpy.mockReset(); - fetchSpy.mockResolvedValue({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: new ReadableStream() + it('should handle successful initial GET connection for SSE', async () => { + // Set up readable stream for SSE events + const encoder = new TextEncoder(); + const stream = new ReadableStream({ + start(controller) { + // Send a server notification via SSE + const event = 'event: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n'; + controller.enqueue(encoder.encode(event)); + } + }); + + // Mock successful GET connection + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: stream + }); + + const messageSpy = jest.fn(); + transport.onmessage = messageSpy; + + await transport.start(); + await transport['_startOrAuthSse']({}); + + // Give time for the SSE event to be processed + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(messageSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + method: 'serverNotification', + params: {} + }) + ); }); - // Call the reconnect method directly with a lastEventId - await transport.start(); - // Type assertion to access private method - const transportWithPrivateMethods = transport as unknown as { - _startOrAuthSse: (options: { resumptionToken?: string }) => Promise - }; - await transportWithPrivateMethods._startOrAuthSse({ resumptionToken: "test-event-id" }); - - // Verify fetch was called with the lastEventId header - expect(fetchSpy).toHaveBeenCalled(); - const fetchCall = fetchSpy.mock.calls[0]; - const headers = fetchCall[1].headers; - expect(headers.get("last-event-id")).toBe("test-event-id"); - }); - - it("should throw error when invalid content-type is received", async () => { - // Clear any previous state from other tests - jest.clearAllMocks(); - - // Create a fresh transport instance - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp")); - - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - const stream = new ReadableStream({ - start(controller) { - controller.enqueue(new TextEncoder().encode("invalid text response")); - controller.close(); - } + it('should handle multiple concurrent SSE streams', async () => { + // Mock two POST requests that return SSE streams + const makeStream = (id: string) => { + const encoder = new TextEncoder(); + return new ReadableStream({ + start(controller) { + const event = `event: message\ndata: {"jsonrpc": "2.0", "result": {"id": "${id}"}, "id": "${id}"}\n\n`; + controller.enqueue(encoder.encode(event)); + } + }); + }; + + (global.fetch as jest.Mock) + .mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: makeStream('request1') + }) + .mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: makeStream('request2') + }); + + const messageSpy = jest.fn(); + transport.onmessage = messageSpy; + + // Send two concurrent requests + await Promise.all([ + transport.send({ jsonrpc: '2.0', method: 'test1', params: {}, id: 'request1' }), + transport.send({ jsonrpc: '2.0', method: 'test2', params: {}, id: 'request2' }) + ]); + + // Give time for SSE processing + await new Promise(resolve => setTimeout(resolve, 100)); + + // Both streams should have delivered their messages + expect(messageSpy).toHaveBeenCalledTimes(2); + + // Verify received messages without assuming specific order + expect( + messageSpy.mock.calls.some(call => { + const msg = call[0]; + return msg.id === 'request1' && msg.result?.id === 'request1'; + }) + ).toBe(true); + + expect( + messageSpy.mock.calls.some(call => { + const msg = call[0]; + return msg.id === 'request2' && msg.result?.id === 'request2'; + }) + ).toBe(true); }); - const errorSpy = jest.fn(); - transport.onerror = errorSpy; + it('should support custom reconnection options', () => { + // Create a transport with custom reconnection options + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions: { + initialReconnectionDelay: 500, + maxReconnectionDelay: 10000, + reconnectionDelayGrowFactor: 2, + maxRetries: 5 + } + }); - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/plain" }), - body: stream + // Verify options were set correctly (checking implementation details) + // Access private properties for testing + const transportInstance = transport as unknown as { + _reconnectionOptions: StreamableHTTPReconnectionOptions; + }; + expect(transportInstance._reconnectionOptions.initialReconnectionDelay).toBe(500); + expect(transportInstance._reconnectionOptions.maxRetries).toBe(5); }); - await transport.start(); - await expect(transport.send(message)).rejects.toThrow("Unexpected content type: text/plain"); - expect(errorSpy).toHaveBeenCalled(); - }); - - it("uses custom fetch implementation if provided", async () => { - // Create custom fetch - const customFetch = jest.fn() - .mockResolvedValueOnce( - new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }) - ) - .mockResolvedValueOnce(new Response(null, { status: 202 })); - - // Create transport instance - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - fetch: customFetch - }); + it('should pass lastEventId when reconnecting', async () => { + // Create a fresh transport + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + + // Mock fetch to verify headers sent + const fetchSpy = global.fetch as jest.Mock; + fetchSpy.mockReset(); + fetchSpy.mockResolvedValue({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: new ReadableStream() + }); - await transport.start(); - await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({}); + // Call the reconnect method directly with a lastEventId + await transport.start(); + // Type assertion to access private method + const transportWithPrivateMethods = transport as unknown as { + _startOrAuthSse: (options: { resumptionToken?: string }) => Promise; + }; + await transportWithPrivateMethods._startOrAuthSse({ resumptionToken: 'test-event-id' }); + + // Verify fetch was called with the lastEventId header + expect(fetchSpy).toHaveBeenCalled(); + const fetchCall = fetchSpy.mock.calls[0]; + const headers = fetchCall[1].headers; + expect(headers.get('last-event-id')).toBe('test-event-id'); + }); - await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage); + it('should throw error when invalid content-type is received', async () => { + // Clear any previous state from other tests + jest.clearAllMocks(); + + // Create a fresh transport instance + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode('invalid text response')); + controller.close(); + } + }); - // Verify custom fetch was used - expect(customFetch).toHaveBeenCalled(); + const errorSpy = jest.fn(); + transport.onerror = errorSpy; - // Global fetch should never have been called - expect(global.fetch).not.toHaveBeenCalled(); - }); + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/plain' }), + body: stream + }); - it("should always send specified custom headers", async () => { - const requestInit = { - headers: { - "X-Custom-Header": "CustomValue" - } - }; - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - requestInit: requestInit + await transport.start(); + await expect(transport.send(message)).rejects.toThrow('Unexpected content type: text/plain'); + expect(errorSpy).toHaveBeenCalled(); }); - let actualReqInit: RequestInit = {}; + it('uses custom fetch implementation if provided', async () => { + // Create custom fetch + const customFetch = jest + .fn() + .mockResolvedValueOnce(new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } })) + .mockResolvedValueOnce(new Response(null, { status: 202 })); - ((global.fetch as jest.Mock)).mockImplementation( - async (_url, reqInit) => { - actualReqInit = reqInit; - return new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }); - } - ); + // Create transport instance + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + fetch: customFetch + }); - await transport.start(); + await transport.start(); + await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({}); - await transport["_startOrAuthSse"]({}); - expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); + await transport.send({ jsonrpc: '2.0', method: 'test', params: {}, id: '1' } as JSONRPCMessage); - requestInit.headers["X-Custom-Header"] = "SecondCustomValue"; + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); - await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); - expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue"); + // Global fetch should never have been called + expect(global.fetch).not.toHaveBeenCalled(); + }); - expect(global.fetch).toHaveBeenCalledTimes(2); - }); + it('should always send specified custom headers', async () => { + const requestInit = { + headers: { + 'X-Custom-Header': 'CustomValue' + } + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + requestInit: requestInit + }); - it("should always send specified custom headers (Headers class)", async () => { - const requestInit = { - headers: new Headers({ - "X-Custom-Header": "CustomValue" - }) - }; - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - requestInit: requestInit - }); + let actualReqInit: RequestInit = {}; - let actualReqInit: RequestInit = {}; + (global.fetch as jest.Mock).mockImplementation(async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); + }); - ((global.fetch as jest.Mock)).mockImplementation( - async (_url, reqInit) => { - actualReqInit = reqInit; - return new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }); - } - ); + await transport.start(); - await transport.start(); + await transport['_startOrAuthSse']({}); + expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue'); - await transport["_startOrAuthSse"]({}); - expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); + requestInit.headers['X-Custom-Header'] = 'SecondCustomValue'; - (requestInit.headers as Headers).set("X-Custom-Header","SecondCustomValue"); + await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); + expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('SecondCustomValue'); - await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); - expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue"); + expect(global.fetch).toHaveBeenCalledTimes(2); + }); - expect(global.fetch).toHaveBeenCalledTimes(2); - }); + it('should always send specified custom headers (Headers class)', async () => { + const requestInit = { + headers: new Headers({ + 'X-Custom-Header': 'CustomValue' + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + requestInit: requestInit + }); - it("should have exponential backoff with configurable maxRetries", () => { - // This test verifies the maxRetries and backoff calculation directly + let actualReqInit: RequestInit = {}; - // Create transport with specific options for testing - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - reconnectionOptions: { - initialReconnectionDelay: 100, - maxReconnectionDelay: 5000, - reconnectionDelayGrowFactor: 2, - maxRetries: 3, - } - }); + (global.fetch as jest.Mock).mockImplementation(async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); + }); - // Get access to the internal implementation - const getDelay = transport["_getNextReconnectionDelay"].bind(transport); - - // First retry - should use initial delay - expect(getDelay(0)).toBe(100); - - // Second retry - should double (2^1 * 100 = 200) - expect(getDelay(1)).toBe(200); - - // Third retry - should double again (2^2 * 100 = 400) - expect(getDelay(2)).toBe(400); - - // Fourth retry - should double again (2^3 * 100 = 800) - expect(getDelay(3)).toBe(800); - - // Tenth retry - should be capped at maxReconnectionDelay - expect(getDelay(10)).toBe(5000); - }); - - it("attempts auth flow on 401 during POST request", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - (global.fetch as jest.Mock) - .mockResolvedValueOnce({ - ok: false, - status: 401, - statusText: "Unauthorized", - headers: new Headers() - }) - .mockResolvedValue({ - ok: false, - status: 404 - }); - - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); - }); - - describe('Reconnection Logic', () => { - let transport: StreamableHTTPClientTransport; + await transport.start(); - // Use fake timers to control setTimeout and make the test instant. - beforeEach(() => jest.useFakeTimers()); - afterEach(() => jest.useRealTimers()); - - it('should reconnect a GET-initiated notification stream that fails', async () => { - // ARRANGE - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - reconnectionOptions: { - initialReconnectionDelay: 10, - maxRetries: 1, - maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely - reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity - } - }); - - const errorSpy = jest.fn(); - transport.onerror = errorSpy; - - const failingStream = new ReadableStream({ - start(controller) { controller.error(new Error("Network failure")); } - }); - - const fetchMock = global.fetch as jest.Mock; - // Mock the initial GET request, which will fail. - fetchMock.mockResolvedValueOnce({ - ok: true, status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: failingStream, - }); - // Mock the reconnection GET request, which will succeed. - fetchMock.mockResolvedValueOnce({ - ok: true, status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: new ReadableStream(), - }); - - // ACT - await transport.start(); - // Trigger the GET stream directly using the internal method for a clean test. - await transport["_startOrAuthSse"]({}); - await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout - - // ASSERT - expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ - message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), - })); - // THE KEY ASSERTION: A second fetch call proves reconnection was attempted. - expect(fetchMock).toHaveBeenCalledTimes(2); - expect(fetchMock.mock.calls[0][1]?.method).toBe('GET'); - expect(fetchMock.mock.calls[1][1]?.method).toBe('GET'); - }); + await transport['_startOrAuthSse']({}); + expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue'); - it('should NOT reconnect a POST-initiated stream that fails', async () => { - // ARRANGE - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - reconnectionOptions: { - initialReconnectionDelay: 10, - maxRetries: 1, - maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely - reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity - } - }); - - const errorSpy = jest.fn(); - transport.onerror = errorSpy; - - const failingStream = new ReadableStream({ - start(controller) { controller.error(new Error("Network failure")); } - }); - - const fetchMock = global.fetch as jest.Mock; - // Mock the POST request. It returns a streaming content-type but a failing body. - fetchMock.mockResolvedValueOnce({ - ok: true, status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: failingStream, - }); - - // A dummy request message to trigger the `send` logic. - const requestMessage: JSONRPCRequest = { - jsonrpc: '2.0', - method: 'long_running_tool', - id: 'request-1', - params: {}, - }; - - // ACT - await transport.start(); - // Use the public `send` method to initiate a POST that gets a stream response. - await transport.send(requestMessage); - await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections - - // ASSERT - expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ - message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), - })); - // THE KEY ASSERTION: Fetch was only called ONCE. No reconnection was attempted. - expect(fetchMock).toHaveBeenCalledTimes(1); - expect(fetchMock.mock.calls[0][1]?.method).toBe('POST'); - }); - }); - - it("invalidates all credentials on InvalidClientError during auth", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - refresh_token: "test-refresh" - }); + (requestInit.headers as Headers).set('X-Custom-Header', 'SecondCustomValue'); - const unauthedResponse = { - ok: false, - status: 401, - statusText: "Unauthorized", - headers: new Headers() - }; - (global.fetch as jest.Mock) - // Initial connection - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery, path aware - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery, root - .mockResolvedValueOnce(unauthedResponse) - // OAuth metadata discovery - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - issuer: "http://localhost:1234", - authorization_endpoint: "http://localhost:1234/authorize", - token_endpoint: "http://localhost:1234/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }) - // Token refresh fails with InvalidClientError - .mockResolvedValueOnce(Response.json( - new InvalidClientError("Client authentication failed").toResponseObject(), - { status: 400 } - )) - // Fallback should fail to complete the flow - .mockResolvedValue({ - ok: false, - status: 404 - }); - - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); - }); - - it("invalidates all credentials on UnauthorizedClientError during auth", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - refresh_token: "test-refresh" - }); + await transport.send({ jsonrpc: '2.0', method: 'test', params: {} } as JSONRPCMessage); + expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('SecondCustomValue'); - const unauthedResponse = { - ok: false, - status: 401, - statusText: "Unauthorized", - headers: new Headers() - }; - (global.fetch as jest.Mock) - // Initial connection - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery, path aware - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery, root - .mockResolvedValueOnce(unauthedResponse) - // OAuth metadata discovery - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - issuer: "http://localhost:1234", - authorization_endpoint: "http://localhost:1234/authorize", - token_endpoint: "http://localhost:1234/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }) - // Token refresh fails with UnauthorizedClientError - .mockResolvedValueOnce(Response.json( - new UnauthorizedClientError("Client not authorized").toResponseObject(), - { status: 400 } - )) - // Fallback should fail to complete the flow - .mockResolvedValue({ - ok: false, - status: 404 - }); - - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); - }); - - it("invalidates tokens on InvalidGrantError during auth", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-id" - }; - - mockAuthProvider.tokens.mockResolvedValue({ - access_token: "test-token", - token_type: "Bearer", - refresh_token: "test-refresh" + expect(global.fetch).toHaveBeenCalledTimes(2); }); - const unauthedResponse = { - ok: false, - status: 401, - statusText: "Unauthorized", - headers: new Headers() - }; - (global.fetch as jest.Mock) - // Initial connection - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery, path aware - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery, root - .mockResolvedValueOnce(unauthedResponse) - // OAuth metadata discovery - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - issuer: "http://localhost:1234", - authorization_endpoint: "http://localhost:1234/authorize", - token_endpoint: "http://localhost:1234/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }) - // Token refresh fails with InvalidGrantError - .mockResolvedValueOnce(Response.json( - new InvalidGrantError("Invalid refresh token").toResponseObject(), - { status: 400 } - )) - // Fallback should fail to complete the flow - .mockResolvedValue({ - ok: false, - status: 404 - }); - - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); - }); - - describe("custom fetch in auth code paths", () => { - it("uses custom fetch during auth flow on 401 - no global fetch fallback", async () => { - const unauthedResponse = { - ok: false, - status: 401, - statusText: "Unauthorized", - headers: new Headers() - }; - - // Create custom fetch - const customFetch = jest.fn() - // Initial connection - .mockResolvedValueOnce(unauthedResponse) - // Resource discovery - .mockResolvedValueOnce(unauthedResponse) - // OAuth metadata discovery - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - issuer: "http://localhost:1234", - authorization_endpoint: "http://localhost:1234/authorize", - token_endpoint: "http://localhost:1234/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }) - // Token refresh fails with InvalidClientError - .mockResolvedValueOnce(Response.json( - new InvalidClientError("Client authentication failed").toResponseObject(), - { status: 400 } - )) - // Fallback should fail to complete the flow - .mockResolvedValue({ - ok: false, - status: 404 + it('should have exponential backoff with configurable maxRetries', () => { + // This test verifies the maxRetries and backoff calculation directly + + // Create transport with specific options for testing + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions: { + initialReconnectionDelay: 100, + maxReconnectionDelay: 5000, + reconnectionDelayGrowFactor: 2, + maxRetries: 3 + } }); - // Create transport instance - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - authProvider: mockAuthProvider, - fetch: customFetch - }); + // Get access to the internal implementation + const getDelay = transport['_getNextReconnectionDelay'].bind(transport); - // Attempt to start - should trigger auth flow and eventually fail with UnauthorizedError - await transport.start(); - await expect((transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({})).rejects.toThrow(UnauthorizedError); + // First retry - should use initial delay + expect(getDelay(0)).toBe(100); - // Verify custom fetch was used - expect(customFetch).toHaveBeenCalled(); + // Second retry - should double (2^1 * 100 = 200) + expect(getDelay(1)).toBe(200); - // Verify specific OAuth endpoints were called with custom fetch - const customFetchCalls = customFetch.mock.calls; - const callUrls = customFetchCalls.map(([url]) => url.toString()); + // Third retry - should double again (2^2 * 100 = 400) + expect(getDelay(2)).toBe(400); - // Should have called resource metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + // Fourth retry - should double again (2^3 * 100 = 800) + expect(getDelay(3)).toBe(800); - // Should have called OAuth authorization server metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); - - // Verify auth provider was called to redirect to authorization - expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + // Tenth retry - should be capped at maxReconnectionDelay + expect(getDelay(10)).toBe(5000); + }); - // Global fetch should never have been called - expect(global.fetch).not.toHaveBeenCalled(); + it('attempts auth flow on 401 during POST request', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (global.fetch as jest.Mock) + .mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: 'Unauthorized', + headers: new Headers() + }) + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); - it("uses custom fetch in finishAuth method - no global fetch fallback", async () => { - // Create custom fetch - const customFetch = jest.fn() - // Protected resource metadata discovery - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - authorization_servers: ["http://localhost:1234"], - resource: "http://localhost:1234/mcp" - }), - }) - // OAuth metadata discovery - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - issuer: "http://localhost:1234", - authorization_endpoint: "http://localhost:1234/authorize", - token_endpoint: "http://localhost:1234/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - }), - }) - // Code exchange - .mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - access_token: "new-access-token", - refresh_token: "new-refresh-token", - token_type: "Bearer", - expires_in: 3600, - }), + describe('Reconnection Logic', () => { + let transport: StreamableHTTPClientTransport; + + // Use fake timers to control setTimeout and make the test instant. + beforeEach(() => jest.useFakeTimers()); + afterEach(() => jest.useRealTimers()); + + it('should reconnect a GET-initiated notification stream that fails', async () => { + // ARRANGE + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions: { + initialReconnectionDelay: 10, + maxRetries: 1, + maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely + reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + const failingStream = new ReadableStream({ + start(controller) { + controller.error(new Error('Network failure')); + } + }); + + const fetchMock = global.fetch as jest.Mock; + // Mock the initial GET request, which will fail. + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: failingStream + }); + // Mock the reconnection GET request, which will succeed. + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: new ReadableStream() + }); + + // ACT + await transport.start(); + // Trigger the GET stream directly using the internal method for a clean test. + await transport['_startOrAuthSse']({}); + await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout + + // ASSERT + expect(errorSpy).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network failure') + }) + ); + // THE KEY ASSERTION: A second fetch call proves reconnection was attempted. + expect(fetchMock).toHaveBeenCalledTimes(2); + expect(fetchMock.mock.calls[0][1]?.method).toBe('GET'); + expect(fetchMock.mock.calls[1][1]?.method).toBe('GET'); }); - // Create transport instance - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - authProvider: mockAuthProvider, - fetch: customFetch - }); + it('should NOT reconnect a POST-initiated stream that fails', async () => { + // ARRANGE + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + reconnectionOptions: { + initialReconnectionDelay: 10, + maxRetries: 1, + maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely + reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + const failingStream = new ReadableStream({ + start(controller) { + controller.error(new Error('Network failure')); + } + }); + + const fetchMock = global.fetch as jest.Mock; + // Mock the POST request. It returns a streaming content-type but a failing body. + fetchMock.mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'text/event-stream' }), + body: failingStream + }); + + // A dummy request message to trigger the `send` logic. + const requestMessage: JSONRPCRequest = { + jsonrpc: '2.0', + method: 'long_running_tool', + id: 'request-1', + params: {} + }; + + // ACT + await transport.start(); + // Use the public `send` method to initiate a POST that gets a stream response. + await transport.send(requestMessage); + await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections + + // ASSERT + expect(errorSpy).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('SSE stream disconnected: Error: Network failure') + }) + ); + // THE KEY ASSERTION: Fetch was only called ONCE. No reconnection was attempted. + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock.mock.calls[0][1]?.method).toBe('POST'); + }); + }); - // Call finishAuth with authorization code - await transport.finishAuth("test-auth-code"); + it('invalidates all credentials on InvalidClientError during auth', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + refresh_token: 'test-refresh' + }); - // Verify custom fetch was used - expect(customFetch).toHaveBeenCalled(); + const unauthedResponse = { + ok: false, + status: 401, + statusText: 'Unauthorized', + headers: new Headers() + }; + (global.fetch as jest.Mock) + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, path aware + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, root + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'http://localhost:1234', + authorization_endpoint: 'http://localhost:1234/authorize', + token_endpoint: 'http://localhost:1234/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }) + // Token refresh fails with InvalidClientError + .mockResolvedValueOnce( + Response.json(new InvalidClientError('Client authentication failed').toResponseObject(), { status: 400 }) + ) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); - // Verify specific OAuth endpoints were called with custom fetch - const customFetchCalls = customFetch.mock.calls; - const callUrls = customFetchCalls.map(([url]) => url.toString()); + it('invalidates all credentials on UnauthorizedClientError during auth', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + refresh_token: 'test-refresh' + }); - // Should have called resource metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + const unauthedResponse = { + ok: false, + status: 401, + statusText: 'Unauthorized', + headers: new Headers() + }; + (global.fetch as jest.Mock) + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, path aware + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, root + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'http://localhost:1234', + authorization_endpoint: 'http://localhost:1234/authorize', + token_endpoint: 'http://localhost:1234/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }) + // Token refresh fails with UnauthorizedClientError + .mockResolvedValueOnce(Response.json(new UnauthorizedClientError('Client not authorized').toResponseObject(), { status: 400 })) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); - // Should have called OAuth authorization server metadata discovery - expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + it('invalidates tokens on InvalidGrantError during auth', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + mockAuthProvider.tokens.mockResolvedValue({ + access_token: 'test-token', + token_type: 'Bearer', + refresh_token: 'test-refresh' + }); - // Should have called token endpoint for authorization code exchange - const tokenCalls = customFetchCalls.filter(([url, options]) => - url.toString().includes('/token') && options?.method === "POST" - ); - expect(tokenCalls.length).toBeGreaterThan(0); + const unauthedResponse = { + ok: false, + status: 401, + statusText: 'Unauthorized', + headers: new Headers() + }; + (global.fetch as jest.Mock) + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, path aware + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, root + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'http://localhost:1234', + authorization_endpoint: 'http://localhost:1234/authorize', + token_endpoint: 'http://localhost:1234/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }) + // Token refresh fails with InvalidGrantError + .mockResolvedValueOnce(Response.json(new InvalidGrantError('Invalid refresh token').toResponseObject(), { status: 400 })) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); + }); - // Verify tokens were saved - expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ - access_token: "new-access-token", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "new-refresh-token" - }); + describe('custom fetch in auth code paths', () => { + it('uses custom fetch during auth flow on 401 - no global fetch fallback', async () => { + const unauthedResponse = { + ok: false, + status: 401, + statusText: 'Unauthorized', + headers: new Headers() + }; + + // Create custom fetch + const customFetch = jest + .fn() + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'http://localhost:1234', + authorization_endpoint: 'http://localhost:1234/authorize', + token_endpoint: 'http://localhost:1234/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }) + // Token refresh fails with InvalidClientError + .mockResolvedValueOnce( + Response.json(new InvalidClientError('Client authentication failed').toResponseObject(), { status: 400 }) + ) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + // Create transport instance + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + authProvider: mockAuthProvider, + fetch: customFetch + }); + + // Attempt to start - should trigger auth flow and eventually fail with UnauthorizedError + await transport.start(); + await expect( + (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({}) + ).rejects.toThrow(UnauthorizedError); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Verify auth provider was called to redirect to authorization + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + + // Global fetch should never have been called + expect(global.fetch).not.toHaveBeenCalled(); + }); - // Global fetch should never have been called - expect(global.fetch).not.toHaveBeenCalled(); + it('uses custom fetch in finishAuth method - no global fetch fallback', async () => { + // Create custom fetch + const customFetch = jest + .fn() + // Protected resource metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + authorization_servers: ['http://localhost:1234'], + resource: 'http://localhost:1234/mcp' + }) + }) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'http://localhost:1234', + authorization_endpoint: 'http://localhost:1234/authorize', + token_endpoint: 'http://localhost:1234/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }) + // Code exchange + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + token_type: 'Bearer', + expires_in: 3600 + }) + }); + + // Create transport instance + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + authProvider: mockAuthProvider, + fetch: customFetch + }); + + // Call finishAuth with authorization code + await transport.finishAuth('test-auth-code'); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Should have called token endpoint for authorization code exchange + const tokenCalls = customFetchCalls.filter(([url, options]) => url.toString().includes('/token') && options?.method === 'POST'); + expect(tokenCalls.length).toBeGreaterThan(0); + + // Verify tokens were saved + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }); + + // Global fetch should never have been called + expect(global.fetch).not.toHaveBeenCalled(); + }); }); - }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 12714ea44..fc35590bc 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,119 +1,119 @@ -import { Transport, FetchLike } from "../shared/transport.js"; -import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; -import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { EventSourceParserStream } from "eventsource-parser/stream"; +import { Transport, FetchLike } from '../shared/transport.js'; +import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from '../types.js'; +import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from './auth.js'; +import { EventSourceParserStream } from 'eventsource-parser/stream'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { - initialReconnectionDelay: 1000, - maxReconnectionDelay: 30000, - reconnectionDelayGrowFactor: 1.5, - maxRetries: 2, + initialReconnectionDelay: 1000, + maxReconnectionDelay: 30000, + reconnectionDelayGrowFactor: 1.5, + maxRetries: 2 }; export class StreamableHTTPError extends Error { - constructor( - public readonly code: number | undefined, - message: string | undefined, - ) { - super(`Streamable HTTP error: ${message}`); - } + constructor( + public readonly code: number | undefined, + message: string | undefined + ) { + super(`Streamable HTTP error: ${message}`); + } } /** * Options for starting or authenticating an SSE connection */ export interface StartSSEOptions { - /** - * The resumption token used to continue long-running requests that were interrupted. - * - * This allows clients to reconnect and continue from where they left off. - */ - resumptionToken?: string; - - /** - * A callback that is invoked when the resumption token changes. - * - * This allows clients to persist the latest token for potential reconnection. - */ - onresumptiontoken?: (token: string) => void; - - /** - * Override Message ID to associate with the replay message - * so that response can be associate with the new resumed request. - */ - replayMessageId?: string | number; + /** + * The resumption token used to continue long-running requests that were interrupted. + * + * This allows clients to reconnect and continue from where they left off. + */ + resumptionToken?: string; + + /** + * A callback that is invoked when the resumption token changes. + * + * This allows clients to persist the latest token for potential reconnection. + */ + onresumptiontoken?: (token: string) => void; + + /** + * Override Message ID to associate with the replay message + * so that response can be associate with the new resumed request. + */ + replayMessageId?: string | number; } /** * Configuration options for reconnection behavior of the StreamableHTTPClientTransport. */ export interface StreamableHTTPReconnectionOptions { - /** - * Maximum backoff time between reconnection attempts in milliseconds. - * Default is 30000 (30 seconds). - */ - maxReconnectionDelay: number; - - /** - * Initial backoff time between reconnection attempts in milliseconds. - * Default is 1000 (1 second). - */ - initialReconnectionDelay: number; - - /** - * The factor by which the reconnection delay increases after each attempt. - * Default is 1.5. - */ - reconnectionDelayGrowFactor: number; - - /** - * Maximum number of reconnection attempts before giving up. - * Default is 2. - */ - maxRetries: number; + /** + * Maximum backoff time between reconnection attempts in milliseconds. + * Default is 30000 (30 seconds). + */ + maxReconnectionDelay: number; + + /** + * Initial backoff time between reconnection attempts in milliseconds. + * Default is 1000 (1 second). + */ + initialReconnectionDelay: number; + + /** + * The factor by which the reconnection delay increases after each attempt. + * Default is 1.5. + */ + reconnectionDelayGrowFactor: number; + + /** + * Maximum number of reconnection attempts before giving up. + * Default is 2. + */ + maxRetries: number; } /** * Configuration options for the `StreamableHTTPClientTransport`. */ export type StreamableHTTPClientTransportOptions = { - /** - * An OAuth client provider to use for authentication. - * - * When an `authProvider` is specified and the connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. - * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection. - * - * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. - * - * `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. - */ - authProvider?: OAuthClientProvider; - - /** - * Customizes HTTP requests to the server. - */ - requestInit?: RequestInit; - - /** - * Custom fetch implementation used for all network requests. - */ - fetch?: FetchLike; - - /** - * Options to configure the reconnection behavior. - */ - reconnectionOptions?: StreamableHTTPReconnectionOptions; - - /** - * Session ID for the connection. This is used to identify the session on the server. - * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. - */ - sessionId?: string; + /** + * An OAuth client provider to use for authentication. + * + * When an `authProvider` is specified and the connection is started: + * 1. The connection is attempted with any existing access token from the `authProvider`. + * 2. If the access token has expired, the `authProvider` is used to refresh the token. + * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. + * + * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection. + * + * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. + * + * `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. + */ + authProvider?: OAuthClientProvider; + + /** + * Customizes HTTP requests to the server. + */ + requestInit?: RequestInit; + + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; + + /** + * Options to configure the reconnection behavior. + */ + reconnectionOptions?: StreamableHTTPReconnectionOptions; + + /** + * Session ID for the connection. This is used to identify the session on the server. + * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. + */ + sessionId?: string; }; /** @@ -122,439 +122,428 @@ export type StreamableHTTPClientTransportOptions = { * for receiving messages. */ export class StreamableHTTPClientTransport implements Transport { - private _abortController?: AbortController; - private _url: URL; - private _resourceMetadataUrl?: URL; - private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; - private _fetch?: FetchLike; - private _sessionId?: string; - private _reconnectionOptions: StreamableHTTPReconnectionOptions; - private _protocolVersion?: string; - - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; - - constructor( - url: URL, - opts?: StreamableHTTPClientTransportOptions, - ) { - this._url = url; - this._resourceMetadataUrl = undefined; - this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; - this._fetch = opts?.fetch; - this._sessionId = opts?.sessionId; - this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; - } - - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError("No auth provider"); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); - } catch (error) { - this.onerror?.(error as Error); - throw error; + private _abortController?: AbortController; + private _url: URL; + private _resourceMetadataUrl?: URL; + private _requestInit?: RequestInit; + private _authProvider?: OAuthClientProvider; + private _fetch?: FetchLike; + private _sessionId?: string; + private _reconnectionOptions: StreamableHTTPReconnectionOptions; + private _protocolVersion?: string; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) { + this._url = url; + this._resourceMetadataUrl = undefined; + this._requestInit = opts?.requestInit; + this._authProvider = opts?.authProvider; + this._fetch = opts?.fetch; + this._sessionId = opts?.sessionId; + this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } + private async _authThenStart(): Promise { + if (!this._authProvider) { + throw new UnauthorizedError('No auth provider'); + } - return await this._startOrAuthSse({ resumptionToken: undefined }); - } + let result: AuthResult; + try { + result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + fetchFn: this._fetch + }); + } catch (error) { + this.onerror?.(error as Error); + throw error; + } - private async _commonHeaders(): Promise { - const headers: HeadersInit & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers["Authorization"] = `Bearer ${tokens.access_token}`; - } - } + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } - if (this._sessionId) { - headers["mcp-session-id"] = this._sessionId; - } - if (this._protocolVersion) { - headers["mcp-protocol-version"] = this._protocolVersion; + return await this._startOrAuthSse({ resumptionToken: undefined }); } - const extraHeaders = this._normalizeHeaders(this._requestInit?.headers); - - return new Headers({ - ...headers, - ...extraHeaders, - }); - } - - - private async _startOrAuthSse(options: StartSSEOptions): Promise { - const { resumptionToken } = options; - try { - // Try to open an initial SSE stream with GET to listen for server messages - // This is optional according to the spec - server may not support it - const headers = await this._commonHeaders(); - headers.set("Accept", "text/event-stream"); - - // Include Last-Event-ID header for resumable streams if provided - if (resumptionToken) { - headers.set("last-event-id", resumptionToken); - } - - const response = await (this._fetch ?? fetch)(this._url, { - method: "GET", - headers, - signal: this._abortController?.signal, - }); - - if (!response.ok) { - if (response.status === 401 && this._authProvider) { - // Need to authenticate - return await this._authThenStart(); + private async _commonHeaders(): Promise { + const headers: HeadersInit & Record = {}; + if (this._authProvider) { + const tokens = await this._authProvider.tokens(); + if (tokens) { + headers['Authorization'] = `Bearer ${tokens.access_token}`; + } } - // 405 indicates that the server does not offer an SSE stream at GET endpoint - // This is an expected case that should not trigger an error - if (response.status === 405) { - return; + if (this._sessionId) { + headers['mcp-session-id'] = this._sessionId; + } + if (this._protocolVersion) { + headers['mcp-protocol-version'] = this._protocolVersion; } - throw new StreamableHTTPError( - response.status, - `Failed to open SSE stream: ${response.statusText}`, - ); - } + const extraHeaders = this._normalizeHeaders(this._requestInit?.headers); - this._handleSseStream(response.body, options, true); - } catch (error) { - this.onerror?.(error as Error); - throw error; + return new Headers({ + ...headers, + ...extraHeaders + }); } - } + private async _startOrAuthSse(options: StartSSEOptions): Promise { + const { resumptionToken } = options; + try { + // Try to open an initial SSE stream with GET to listen for server messages + // This is optional according to the spec - server may not support it + const headers = await this._commonHeaders(); + headers.set('Accept', 'text/event-stream'); + + // Include Last-Event-ID header for resumable streams if provided + if (resumptionToken) { + headers.set('last-event-id', resumptionToken); + } - /** - * Calculates the next reconnection delay using backoff algorithm - * - * @param attempt Current reconnection attempt count for the specific stream - * @returns Time to wait in milliseconds before next reconnection attempt - */ - private _getNextReconnectionDelay(attempt: number): number { - // Access default values directly, ensuring they're never undefined - const initialDelay = this._reconnectionOptions.initialReconnectionDelay; - const growFactor = this._reconnectionOptions.reconnectionDelayGrowFactor; - const maxDelay = this._reconnectionOptions.maxReconnectionDelay; + const response = await (this._fetch ?? fetch)(this._url, { + method: 'GET', + headers, + signal: this._abortController?.signal + }); + + if (!response.ok) { + if (response.status === 401 && this._authProvider) { + // Need to authenticate + return await this._authThenStart(); + } + + // 405 indicates that the server does not offer an SSE stream at GET endpoint + // This is an expected case that should not trigger an error + if (response.status === 405) { + return; + } + + throw new StreamableHTTPError(response.status, `Failed to open SSE stream: ${response.statusText}`); + } - // Cap at maximum delay - return Math.min(initialDelay * Math.pow(growFactor, attempt), maxDelay); + this._handleSseStream(response.body, options, true); + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + } - } + /** + * Calculates the next reconnection delay using backoff algorithm + * + * @param attempt Current reconnection attempt count for the specific stream + * @returns Time to wait in milliseconds before next reconnection attempt + */ + private _getNextReconnectionDelay(attempt: number): number { + // Access default values directly, ensuring they're never undefined + const initialDelay = this._reconnectionOptions.initialReconnectionDelay; + const growFactor = this._reconnectionOptions.reconnectionDelayGrowFactor; + const maxDelay = this._reconnectionOptions.maxReconnectionDelay; + + // Cap at maximum delay + return Math.min(initialDelay * Math.pow(growFactor, attempt), maxDelay); + } private _normalizeHeaders(headers: HeadersInit | undefined): Record { - if (!headers) return {}; + if (!headers) return {}; - if (headers instanceof Headers) { - return Object.fromEntries(headers.entries()); - } + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } - if (Array.isArray(headers)) { - return Object.fromEntries(headers); - } + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } - return { ...headers as Record }; - } - - /** - * Schedule a reconnection attempt with exponential backoff - * - * @param lastEventId The ID of the last received event for resumability - * @param attemptCount Current reconnection attempt count for this specific stream - */ - private _scheduleReconnection(options: StartSSEOptions, attemptCount = 0): void { - // Use provided options or default options - const maxRetries = this._reconnectionOptions.maxRetries; - - // Check if we've exceeded maximum retry attempts - if (maxRetries > 0 && attemptCount >= maxRetries) { - this.onerror?.(new Error(`Maximum reconnection attempts (${maxRetries}) exceeded.`)); - return; + return { ...(headers as Record) }; } - // Calculate next delay based on current attempt count - const delay = this._getNextReconnectionDelay(attemptCount); - - // Schedule the reconnection - setTimeout(() => { - // Use the last event ID to resume where we left off - this._startOrAuthSse(options).catch(error => { - this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); - // Schedule another attempt if this one failed, incrementing the attempt counter - this._scheduleReconnection(options, attemptCount + 1); - }); - }, delay); - } - - private _handleSseStream( - stream: ReadableStream | null, - options: StartSSEOptions, - isReconnectable: boolean, - ): void { - if (!stream) { - return; + /** + * Schedule a reconnection attempt with exponential backoff + * + * @param lastEventId The ID of the last received event for resumability + * @param attemptCount Current reconnection attempt count for this specific stream + */ + private _scheduleReconnection(options: StartSSEOptions, attemptCount = 0): void { + // Use provided options or default options + const maxRetries = this._reconnectionOptions.maxRetries; + + // Check if we've exceeded maximum retry attempts + if (maxRetries > 0 && attemptCount >= maxRetries) { + this.onerror?.(new Error(`Maximum reconnection attempts (${maxRetries}) exceeded.`)); + return; + } + + // Calculate next delay based on current attempt count + const delay = this._getNextReconnectionDelay(attemptCount); + + // Schedule the reconnection + setTimeout(() => { + // Use the last event ID to resume where we left off + this._startOrAuthSse(options).catch(error => { + this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); + // Schedule another attempt if this one failed, incrementing the attempt counter + this._scheduleReconnection(options, attemptCount + 1); + }); + }, delay); } - const { onresumptiontoken, replayMessageId } = options; - - let lastEventId: string | undefined; - const processStream = async () => { - // this is the closest we can get to trying to catch network errors - // if something happens reader will throw - try { - // Create a pipeline: binary stream -> text decoder -> SSE parser - const reader = stream - .pipeThrough(new TextDecoderStream()) - .pipeThrough(new EventSourceParserStream()) - .getReader(); - - - while (true) { - const { value: event, done } = await reader.read(); - if (done) { - break; - } - - // Update last event ID if provided - if (event.id) { - lastEventId = event.id; - onresumptiontoken?.(event.id); - } - - if (!event.event || event.event === "message") { + + private _handleSseStream(stream: ReadableStream | null, options: StartSSEOptions, isReconnectable: boolean): void { + if (!stream) { + return; + } + const { onresumptiontoken, replayMessageId } = options; + + let lastEventId: string | undefined; + const processStream = async () => { + // this is the closest we can get to trying to catch network errors + // if something happens reader will throw try { - const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - if (replayMessageId !== undefined && isJSONRPCResponse(message)) { - message.id = replayMessageId; - } - this.onmessage?.(message); + // Create a pipeline: binary stream -> text decoder -> SSE parser + const reader = stream.pipeThrough(new TextDecoderStream()).pipeThrough(new EventSourceParserStream()).getReader(); + + while (true) { + const { value: event, done } = await reader.read(); + if (done) { + break; + } + + // Update last event ID if provided + if (event.id) { + lastEventId = event.id; + onresumptiontoken?.(event.id); + } + + if (!event.event || event.event === 'message') { + try { + const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); + if (replayMessageId !== undefined && isJSONRPCResponse(message)) { + message.id = replayMessageId; + } + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); + } + } + } } catch (error) { - this.onerror?.(error as Error); + // Handle stream errors - likely a network disconnect + this.onerror?.(new Error(`SSE stream disconnected: ${error}`)); + + // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing + if (isReconnectable && this._abortController && !this._abortController.signal.aborted) { + // Use the exponential backoff reconnection strategy + try { + this._scheduleReconnection( + { + resumptionToken: lastEventId, + onresumptiontoken, + replayMessageId + }, + 0 + ); + } catch (error) { + this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); + } + } } - } - } - } catch (error) { - // Handle stream errors - likely a network disconnect - this.onerror?.(new Error(`SSE stream disconnected: ${error}`)); - - // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing - if ( - isReconnectable && - this._abortController && - !this._abortController.signal.aborted - ) { - // Use the exponential backoff reconnection strategy - try { - this._scheduleReconnection({ - resumptionToken: lastEventId, - onresumptiontoken, - replayMessageId - }, 0); - } - catch (error) { - this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); - - } - } - } - }; - processStream(); - } - - async start() { - if (this._abortController) { - throw new Error( - "StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.", - ); + }; + processStream(); } - this._abortController = new AbortController(); - } + async start() { + if (this._abortController) { + throw new Error( + 'StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.' + ); + } - /** - * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. - */ - async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError("No auth provider"); + this._abortController = new AbortController(); } - const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError("Failed to authorize"); + /** + * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. + */ + async finishAuth(authorizationCode: string): Promise { + if (!this._authProvider) { + throw new UnauthorizedError('No auth provider'); + } + + const result = await auth(this._authProvider, { + serverUrl: this._url, + authorizationCode, + resourceMetadataUrl: this._resourceMetadataUrl, + fetchFn: this._fetch + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError('Failed to authorize'); + } } - } - async close(): Promise { - // Abort any pending requests - this._abortController?.abort(); + async close(): Promise { + // Abort any pending requests + this._abortController?.abort(); - this.onclose?.(); - } + this.onclose?.(); + } - async send(message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string, onresumptiontoken?: (token: string) => void }): Promise { - try { - const { resumptionToken, onresumptiontoken } = options || {}; + async send( + message: JSONRPCMessage | JSONRPCMessage[], + options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } + ): Promise { + try { + const { resumptionToken, onresumptiontoken } = options || {}; + + if (resumptionToken) { + // If we have at last event ID, we need to reconnect the SSE stream + this._startOrAuthSse({ resumptionToken, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch(err => + this.onerror?.(err) + ); + return; + } - if (resumptionToken) { - // If we have at last event ID, we need to reconnect the SSE stream - this._startOrAuthSse({ resumptionToken, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch(err => this.onerror?.(err)); - return; - } + const headers = await this._commonHeaders(); + headers.set('content-type', 'application/json'); + headers.set('accept', 'application/json, text/event-stream'); - const headers = await this._commonHeaders(); - headers.set("content-type", "application/json"); - headers.set("accept", "application/json, text/event-stream"); + const init = { + ...this._requestInit, + method: 'POST', + headers, + body: JSON.stringify(message), + signal: this._abortController?.signal + }; - const init = { - ...this._requestInit, - method: "POST", - headers, - body: JSON.stringify(message), - signal: this._abortController?.signal, - }; + const response = await (this._fetch ?? fetch)(this._url, init); - const response = await (this._fetch ?? fetch)(this._url, init); + // Handle session ID received during initialization + const sessionId = response.headers.get('mcp-session-id'); + if (sessionId) { + this._sessionId = sessionId; + } - // Handle session ID received during initialization - const sessionId = response.headers.get("mcp-session-id"); - if (sessionId) { - this._sessionId = sessionId; - } + if (!response.ok) { + if (response.status === 401 && this._authProvider) { + this._resourceMetadataUrl = extractResourceMetadataUrl(response); + + const result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + fetchFn: this._fetch + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + + const text = await response.text().catch(() => null); + throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); + } - if (!response.ok) { - if (response.status === 401 && this._authProvider) { + // If the response is 202 Accepted, there's no body to process + if (response.status === 202) { + // if the accepted notification is initialized, we start the SSE stream + // if it's supported by the server + if (isInitializedNotification(message)) { + // Start without a lastEventId since this is a fresh connection + this._startOrAuthSse({ resumptionToken: undefined }).catch(err => this.onerror?.(err)); + } + return; + } - this._resourceMetadataUrl = extractResourceMetadataUrl(response); + // Get original message(s) for detecting request IDs + const messages = Array.isArray(message) ? message : [message]; + + const hasRequests = messages.filter(msg => 'method' in msg && 'id' in msg && msg.id !== undefined).length > 0; + + // Check the response type + const contentType = response.headers.get('content-type'); + + if (hasRequests) { + if (contentType?.includes('text/event-stream')) { + // Handle SSE stream responses for requests + // We use the same handler as standalone streams, which now supports + // reconnection with the last event ID + this._handleSseStream(response.body, { onresumptiontoken }, false); + } else if (contentType?.includes('application/json')) { + // For non-streaming servers, we might get direct JSON responses + const data = await response.json(); + const responseMessages = Array.isArray(data) + ? data.map(msg => JSONRPCMessageSchema.parse(msg)) + : [JSONRPCMessageSchema.parse(data)]; + + for (const msg of responseMessages) { + this.onmessage?.(msg); + } + } else { + throw new StreamableHTTPError(-1, `Unexpected content type: ${contentType}`); + } + } + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + } - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } + get sessionId(): string | undefined { + return this._sessionId; + } - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + /** + * Terminates the current session by sending a DELETE request to the server. + * + * Clients that no longer need a particular session + * (e.g., because the user is leaving the client application) SHOULD send an + * HTTP DELETE to the MCP endpoint with the Mcp-Session-Id header to explicitly + * terminate the session. + * + * The server MAY respond with HTTP 405 Method Not Allowed, indicating that + * the server does not allow clients to terminate sessions. + */ + async terminateSession(): Promise { + if (!this._sessionId) { + return; // No session to terminate } - const text = await response.text().catch(() => null); - throw new Error( - `Error POSTing to endpoint (HTTP ${response.status}): ${text}`, - ); - } - - // If the response is 202 Accepted, there's no body to process - if (response.status === 202) { - // if the accepted notification is initialized, we start the SSE stream - // if it's supported by the server - if (isInitializedNotification(message)) { - // Start without a lastEventId since this is a fresh connection - this._startOrAuthSse({ resumptionToken: undefined }).catch(err => this.onerror?.(err)); - } - return; - } - - // Get original message(s) for detecting request IDs - const messages = Array.isArray(message) ? message : [message]; - - const hasRequests = messages.filter(msg => "method" in msg && "id" in msg && msg.id !== undefined).length > 0; - - // Check the response type - const contentType = response.headers.get("content-type"); - - if (hasRequests) { - if (contentType?.includes("text/event-stream")) { - // Handle SSE stream responses for requests - // We use the same handler as standalone streams, which now supports - // reconnection with the last event ID - this._handleSseStream(response.body, { onresumptiontoken }, false); - } else if (contentType?.includes("application/json")) { - // For non-streaming servers, we might get direct JSON responses - const data = await response.json(); - const responseMessages = Array.isArray(data) - ? data.map(msg => JSONRPCMessageSchema.parse(msg)) - : [JSONRPCMessageSchema.parse(data)]; - - for (const msg of responseMessages) { - this.onmessage?.(msg); - } - } else { - throw new StreamableHTTPError( - -1, - `Unexpected content type: ${contentType}`, - ); + try { + const headers = await this._commonHeaders(); + + const init = { + ...this._requestInit, + method: 'DELETE', + headers, + signal: this._abortController?.signal + }; + + const response = await (this._fetch ?? fetch)(this._url, init); + + // We specifically handle 405 as a valid response according to the spec, + // meaning the server does not support explicit session termination + if (!response.ok && response.status !== 405) { + throw new StreamableHTTPError(response.status, `Failed to terminate session: ${response.statusText}`); + } + + this._sessionId = undefined; + } catch (error) { + this.onerror?.(error as Error); + throw error; } - } - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - } - - get sessionId(): string | undefined { - return this._sessionId; - } - - /** - * Terminates the current session by sending a DELETE request to the server. - * - * Clients that no longer need a particular session - * (e.g., because the user is leaving the client application) SHOULD send an - * HTTP DELETE to the MCP endpoint with the Mcp-Session-Id header to explicitly - * terminate the session. - * - * The server MAY respond with HTTP 405 Method Not Allowed, indicating that - * the server does not allow clients to terminate sessions. - */ - async terminateSession(): Promise { - if (!this._sessionId) { - return; // No session to terminate } - try { - const headers = await this._commonHeaders(); - - const init = { - ...this._requestInit, - method: "DELETE", - headers, - signal: this._abortController?.signal, - }; - - const response = await (this._fetch ?? fetch)(this._url, init); - - // We specifically handle 405 as a valid response according to the spec, - // meaning the server does not support explicit session termination - if (!response.ok && response.status !== 405) { - throw new StreamableHTTPError( - response.status, - `Failed to terminate session: ${response.statusText}` - ); - } - - this._sessionId = undefined; - } catch (error) { - this.onerror?.(error as Error); - throw error; + setProtocolVersion(version: string): void { + this._protocolVersion = version; + } + get protocolVersion(): string | undefined { + return this._protocolVersion; } - } - - setProtocolVersion(version: string): void { - this._protocolVersion = version; - } - get protocolVersion(): string | undefined { - return this._protocolVersion; - } } diff --git a/src/client/websocket.ts b/src/client/websocket.ts index 3ca760820..aed766caf 100644 --- a/src/client/websocket.ts +++ b/src/client/websocket.ts @@ -1,77 +1,74 @@ -import { Transport } from "../shared/transport.js"; -import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { Transport } from '../shared/transport.js'; +import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js'; -const SUBPROTOCOL = "mcp"; +const SUBPROTOCOL = 'mcp'; /** * Client transport for WebSocket: this will connect to a server over the WebSocket protocol. */ export class WebSocketClientTransport implements Transport { - private _socket?: WebSocket; - private _url: URL; + private _socket?: WebSocket; + private _url: URL; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; - constructor(url: URL) { - this._url = url; - } - - start(): Promise { - if (this._socket) { - throw new Error( - "WebSocketClientTransport already started! If using Client class, note that connect() calls start() automatically.", - ); + constructor(url: URL) { + this._url = url; } - return new Promise((resolve, reject) => { - this._socket = new WebSocket(this._url, SUBPROTOCOL); + start(): Promise { + if (this._socket) { + throw new Error( + 'WebSocketClientTransport already started! If using Client class, note that connect() calls start() automatically.' + ); + } - this._socket.onerror = (event) => { - const error = - "error" in event - ? (event.error as Error) - : new Error(`WebSocket error: ${JSON.stringify(event)}`); - reject(error); - this.onerror?.(error); - }; + return new Promise((resolve, reject) => { + this._socket = new WebSocket(this._url, SUBPROTOCOL); - this._socket.onopen = () => { - resolve(); - }; + this._socket.onerror = event => { + const error = 'error' in event ? (event.error as Error) : new Error(`WebSocket error: ${JSON.stringify(event)}`); + reject(error); + this.onerror?.(error); + }; - this._socket.onclose = () => { - this.onclose?.(); - }; + this._socket.onopen = () => { + resolve(); + }; - this._socket.onmessage = (event: MessageEvent) => { - let message: JSONRPCMessage; - try { - message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - } catch (error) { - this.onerror?.(error as Error); - return; - } + this._socket.onclose = () => { + this.onclose?.(); + }; - this.onmessage?.(message); - }; - }); - } + this._socket.onmessage = (event: MessageEvent) => { + let message: JSONRPCMessage; + try { + message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); + } catch (error) { + this.onerror?.(error as Error); + return; + } - async close(): Promise { - this._socket?.close(); - } + this.onmessage?.(message); + }; + }); + } - send(message: JSONRPCMessage): Promise { - return new Promise((resolve, reject) => { - if (!this._socket) { - reject(new Error("Not connected")); - return; - } + async close(): Promise { + this._socket?.close(); + } + + send(message: JSONRPCMessage): Promise { + return new Promise((resolve, reject) => { + if (!this._socket) { + reject(new Error('Not connected')); + return; + } - this._socket?.send(JSON.stringify(message)); - resolve(); - }); - } + this._socket?.send(JSON.stringify(message)); + resolve(); + }); + } } diff --git a/src/examples/README.md b/src/examples/README.md index ac92e8ded..1c30b8dde 100644 --- a/src/examples/README.md +++ b/src/examples/README.md @@ -5,14 +5,14 @@ This directory contains example implementations of MCP clients and servers using ## Table of Contents - [Client Implementations](#client-implementations) - - [Streamable HTTP Client](#streamable-http-client) - - [Backwards Compatible Client](#backwards-compatible-client) + - [Streamable HTTP Client](#streamable-http-client) + - [Backwards Compatible Client](#backwards-compatible-client) - [Server Implementations](#server-implementations) - - [Single Node Deployment](#single-node-deployment) - - [Streamable HTTP Transport](#streamable-http-transport) - - [Deprecated SSE Transport](#deprecated-sse-transport) - - [Backwards Compatible Server](#streamable-http-backwards-compatible-server-with-sse) - - [Multi-Node Deployment](#multi-node-deployment) + - [Single Node Deployment](#single-node-deployment) + - [Streamable HTTP Transport](#streamable-http-transport) + - [Deprecated SSE Transport](#deprecated-sse-transport) + - [Backwards Compatible Server](#streamable-http-backwards-compatible-server-with-sse) + - [Multi-Node Deployment](#multi-node-deployment) - [Backwards Compatibility](#testing-streamable-http-backwards-compatibility-with-sse) ## Client Implementations @@ -44,8 +44,8 @@ npx tsx src/examples/client/simpleOAuthClient.js A client that implements backwards compatibility according to the [MCP specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#backwards-compatibility), allowing it to work with both new and legacy servers. This client demonstrates: - The client first POSTs an initialize request to the server URL: - - If successful, it uses the Streamable HTTP transport - - If it fails with a 4xx status, it attempts a GET request to establish an SSE stream + - If successful, it uses the Streamable HTTP transport + - If it fails with a 4xx status, it attempts a GET request to establish an SSE stream ```bash npx tsx src/examples/client/streamableHttpWithSseFallbackClient.ts @@ -61,7 +61,7 @@ These examples demonstrate how to set up an MCP server on a single node with dif ##### Simple Streamable HTTP Server -A server that implements the Streamable HTTP transport (protocol version 2025-03-26). +A server that implements the Streamable HTTP transport (protocol version 2025-03-26). - Basic server setup with Express and the Streamable HTTP transport - Session management with an in-memory event store for resumability @@ -83,7 +83,7 @@ npx tsx src/examples/server/simpleStreamableHttp.ts --oauth --oauth-strict ##### JSON Response Mode Server -A server that uses Streamable HTTP transport with JSON response mode enabled (no SSE). +A server that uses Streamable HTTP transport with JSON response mode enabled (no SSE). - Streamable HTTP with JSON response mode, which returns responses directly in the response body - Limited support for notifications (since SSE is disabled) @@ -96,12 +96,11 @@ npx tsx src/examples/server/jsonResponseStreamableHttp.ts ##### Streamable HTTP with server notifications -A server that demonstrates server notifications using Streamable HTTP. +A server that demonstrates server notifications using Streamable HTTP. - Resource list change notifications with dynamically added resources - Automatic resource creation on a timed interval - ```bash npx tsx src/examples/server/standaloneSseWithGetStreamableHttp.ts ``` @@ -117,9 +116,9 @@ A server that implements the deprecated HTTP+SSE transport (protocol version 202 npx tsx src/examples/server/simpleSseServer.ts ``` -#### Streamable Http Backwards Compatible Server with SSE +#### Streamable Http Backwards Compatible Server with SSE -A server that supports both Streamable HTTP and SSE transports, adhering to the [MCP specification for backwards compatibility](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#backwards-compatibility). +A server that supports both Streamable HTTP and SSE transports, adhering to the [MCP specification for backwards compatibility](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#backwards-compatibility). - Single MCP server instance with multiple transport options - Support for Streamable HTTP requests at `/mcp` endpoint (GET/POST/DELETE) @@ -134,6 +133,7 @@ npx tsx src/examples/server/sseAndStreamableHttpCompatibleServer.ts ### Multi-Node Deployment When deploying MCP servers in a horizontally scaled environment (multiple server instances), there are a few different options that can be useful for different use cases: + - **Stateless mode** - No need to maintain state between calls to MCP servers. Useful for simple API wrapper servers. - **Persistent storage mode** - No local state needed, but session data is stored in a database. Example: an MCP server for online ordering where the shopping cart is stored in a database. - **Local state with message routing** - Local state is needed, and all requests for a session must be routed to the correct node. This can be done with a message queue and pub/sub system. @@ -145,8 +145,9 @@ The Streamable HTTP transport can be configured to operate without tracking sess ##### Implementation To enable stateless mode, configure the `StreamableHTTPServerTransport` with: + ```typescript -sessionIdGenerator: undefined +sessionIdGenerator: undefined; ``` This disables session management entirely, and the server won't generate or expect session IDs. @@ -174,8 +175,6 @@ This disables session management entirely, and the server won't generate or expe └─────────────────┘ └─────────────────────┘ ``` - - #### Persistent Storage Mode For cases where you need session continuity but don't need to maintain in-memory state on specific nodes, you can use a database to persist session data while still allowing any node to handle requests. @@ -199,7 +198,6 @@ All session state is stored in the database, and any node can serve any client b - Good for applications where state can be fully externalized - Somewhat higher latency due to database access for each request - ``` ┌─────────────────────────────────────────────┐ │ Client │ @@ -226,30 +224,28 @@ All session state is stored in the database, and any node can serve any client b └─────────────────────────────────────────────┘ ``` - - #### Streamable HTTP with Distributed Message Routing For scenarios where local in-memory state must be maintained on specific nodes (such as Computer Use or complex session state), the Streamable HTTP transport can be combined with a pub/sub system to route messages to the correct node handling each session. 1. **Bidirectional Message Queue Integration**: - - All nodes both publish to and subscribe from the message queue - - Each node registers the sessions it's actively handling - - Messages are routed based on session ownership + - All nodes both publish to and subscribe from the message queue + - Each node registers the sessions it's actively handling + - Messages are routed based on session ownership 2. **Request Handling Flow**: - - When a client connects to Node A with an existing `mcp-session-id` - - If Node A doesn't own this session, it: - - Establishes and maintains the SSE connection with the client - - Publishes the request to the message queue with the session ID - - Node B (which owns the session) receives the request from the queue - - Node B processes the request with its local session state - - Node B publishes responses/notifications back to the queue - - Node A subscribes to the response channel and forwards to the client + - When a client connects to Node A with an existing `mcp-session-id` + - If Node A doesn't own this session, it: + - Establishes and maintains the SSE connection with the client + - Publishes the request to the message queue with the session ID + - Node B (which owns the session) receives the request from the queue + - Node B processes the request with its local session state + - Node B publishes responses/notifications back to the queue + - Node A subscribes to the response channel and forwards to the client 3. **Channel Identification**: - - Each message channel combines both `mcp-session-id` and `stream-id` - - This ensures responses are correctly routed back to the originating connection + - Each message channel combines both `mcp-session-id` and `stream-id` + - This ensures responses are correctly routed back to the originating connection ``` ┌─────────────────────────────────────────────┐ @@ -277,12 +273,10 @@ For scenarios where local in-memory state must be maintained on specific nodes ( └─────────────────────────────────────────────┘ ``` - - Maintains session affinity for stateful operations without client redirection - Enables horizontal scaling while preserving complex in-memory state - Provides fault tolerance through the message queue as intermediary - ## Backwards Compatibility ### Testing Streamable HTTP Backwards Compatibility with SSE @@ -290,20 +284,21 @@ For scenarios where local in-memory state must be maintained on specific nodes ( To test the backwards compatibility features: 1. Start one of the server implementations: - ```bash - # Legacy SSE server (protocol version 2024-11-05) - npx tsx src/examples/server/simpleSseServer.ts - - # Streamable HTTP server (protocol version 2025-03-26) - npx tsx src/examples/server/simpleStreamableHttp.ts - - # Backwards compatible server (supports both protocols) - npx tsx src/examples/server/sseAndStreamableHttpCompatibleServer.ts - ``` + + ```bash + # Legacy SSE server (protocol version 2024-11-05) + npx tsx src/examples/server/simpleSseServer.ts + + # Streamable HTTP server (protocol version 2025-03-26) + npx tsx src/examples/server/simpleStreamableHttp.ts + + # Backwards compatible server (supports both protocols) + npx tsx src/examples/server/sseAndStreamableHttpCompatibleServer.ts + ``` 2. Then run the backwards compatible client: - ```bash - npx tsx src/examples/client/streamableHttpWithSseFallbackClient.ts - ``` + ```bash + npx tsx src/examples/client/streamableHttpWithSseFallbackClient.ts + ``` -This demonstrates how the MCP ecosystem ensures interoperability between clients and servers regardless of which protocol version they were built for. \ No newline at end of file +This demonstrates how the MCP ecosystem ensures interoperability between clients and servers regardless of which protocol version they were built for. diff --git a/src/examples/client/multipleClientsParallel.ts b/src/examples/client/multipleClientsParallel.ts index cc01fc06e..492235cdd 100644 --- a/src/examples/client/multipleClientsParallel.ts +++ b/src/examples/client/multipleClientsParallel.ts @@ -1,18 +1,13 @@ import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; -import { - CallToolRequest, - CallToolResultSchema, - LoggingMessageNotificationSchema, - CallToolResult, -} from '../../types.js'; +import { CallToolRequest, CallToolResultSchema, LoggingMessageNotificationSchema, CallToolResult } from '../../types.js'; /** * Multiple Clients MCP Example - * + * * This client demonstrates how to: * 1. Create multiple MCP clients in parallel - * 2. Each client calls a single tool + * 2. Each client calls a single tool * 3. Track notifications from each client independently */ @@ -21,140 +16,139 @@ const args = process.argv.slice(2); const serverUrl = args[0] || 'http://localhost:3000/mcp'; interface ClientConfig { - id: string; - name: string; - toolName: string; - toolArguments: Record; + id: string; + name: string; + toolName: string; + toolArguments: Record; } async function createAndRunClient(config: ClientConfig): Promise<{ id: string; result: CallToolResult }> { - console.log(`[${config.id}] Creating client: ${config.name}`); - - const client = new Client({ - name: config.name, - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(new URL(serverUrl)); - - // Set up client-specific error handler - client.onerror = (error) => { - console.error(`[${config.id}] Client error:`, error); - }; - - // Set up client-specific notification handler - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - console.log(`[${config.id}] Notification: ${notification.params.data}`); - }); - - try { - // Connect to the server - await client.connect(transport); - console.log(`[${config.id}] Connected to MCP server`); - - // Call the specified tool - console.log(`[${config.id}] Calling tool: ${config.toolName}`); - const toolRequest: CallToolRequest = { - method: 'tools/call', - params: { - name: config.toolName, - arguments: { - ...config.toolArguments, - // Add client ID to arguments for identification in notifications - caller: config.id - } - } - }; + console.log(`[${config.id}] Creating client: ${config.name}`); - const result = await client.request(toolRequest, CallToolResultSchema); - console.log(`[${config.id}] Tool call completed`); + const client = new Client({ + name: config.name, + version: '1.0.0' + }); - // Keep the connection open for a bit to receive notifications - await new Promise(resolve => setTimeout(resolve, 5000)); + const transport = new StreamableHTTPClientTransport(new URL(serverUrl)); - // Disconnect - await transport.close(); - console.log(`[${config.id}] Disconnected from MCP server`); + // Set up client-specific error handler + client.onerror = error => { + console.error(`[${config.id}] Client error:`, error); + }; + + // Set up client-specific notification handler + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + console.log(`[${config.id}] Notification: ${notification.params.data}`); + }); - return { id: config.id, result }; - } catch (error) { - console.error(`[${config.id}] Error:`, error); - throw error; - } + try { + // Connect to the server + await client.connect(transport); + console.log(`[${config.id}] Connected to MCP server`); + + // Call the specified tool + console.log(`[${config.id}] Calling tool: ${config.toolName}`); + const toolRequest: CallToolRequest = { + method: 'tools/call', + params: { + name: config.toolName, + arguments: { + ...config.toolArguments, + // Add client ID to arguments for identification in notifications + caller: config.id + } + } + }; + + const result = await client.request(toolRequest, CallToolResultSchema); + console.log(`[${config.id}] Tool call completed`); + + // Keep the connection open for a bit to receive notifications + await new Promise(resolve => setTimeout(resolve, 5000)); + + // Disconnect + await transport.close(); + console.log(`[${config.id}] Disconnected from MCP server`); + + return { id: config.id, result }; + } catch (error) { + console.error(`[${config.id}] Error:`, error); + throw error; + } } async function main(): Promise { - console.log('MCP Multiple Clients Example'); - console.log('============================'); - console.log(`Server URL: ${serverUrl}`); - console.log(''); - - try { - // Define client configurations - const clientConfigs: ClientConfig[] = [ - { - id: 'client1', - name: 'basic-client-1', - toolName: 'start-notification-stream', - toolArguments: { - interval: 3, // 1 second between notifications - count: 5 // Send 5 notifications - } - }, - { - id: 'client2', - name: 'basic-client-2', - toolName: 'start-notification-stream', - toolArguments: { - interval: 2, // 2 seconds between notifications - count: 3 // Send 3 notifications - } - }, - { - id: 'client3', - name: 'basic-client-3', - toolName: 'start-notification-stream', - toolArguments: { - interval: 1, // 0.5 second between notifications - count: 8 // Send 8 notifications - } - } - ]; - - // Start all clients in parallel - console.log(`Starting ${clientConfigs.length} clients in parallel...`); + console.log('MCP Multiple Clients Example'); + console.log('============================'); + console.log(`Server URL: ${serverUrl}`); console.log(''); - const clientPromises = clientConfigs.map(config => createAndRunClient(config)); - const results = await Promise.all(clientPromises); - - // Display results from all clients - console.log('\n=== Final Results ==='); - results.forEach(({ id, result }) => { - console.log(`\n[${id}] Tool result:`); - if (Array.isArray(result.content)) { - result.content.forEach((item: { type: string; text?: string }) => { - if (item.type === 'text' && item.text) { - console.log(` ${item.text}`); - } else { - console.log(` ${item.type} content:`, item); - } + try { + // Define client configurations + const clientConfigs: ClientConfig[] = [ + { + id: 'client1', + name: 'basic-client-1', + toolName: 'start-notification-stream', + toolArguments: { + interval: 3, // 1 second between notifications + count: 5 // Send 5 notifications + } + }, + { + id: 'client2', + name: 'basic-client-2', + toolName: 'start-notification-stream', + toolArguments: { + interval: 2, // 2 seconds between notifications + count: 3 // Send 3 notifications + } + }, + { + id: 'client3', + name: 'basic-client-3', + toolName: 'start-notification-stream', + toolArguments: { + interval: 1, // 0.5 second between notifications + count: 8 // Send 8 notifications + } + } + ]; + + // Start all clients in parallel + console.log(`Starting ${clientConfigs.length} clients in parallel...`); + console.log(''); + + const clientPromises = clientConfigs.map(config => createAndRunClient(config)); + const results = await Promise.all(clientPromises); + + // Display results from all clients + console.log('\n=== Final Results ==='); + results.forEach(({ id, result }) => { + console.log(`\n[${id}] Tool result:`); + if (Array.isArray(result.content)) { + result.content.forEach((item: { type: string; text?: string }) => { + if (item.type === 'text' && item.text) { + console.log(` ${item.text}`); + } else { + console.log(` ${item.type} content:`, item); + } + }); + } else { + console.log(` Unexpected result format:`, result); + } }); - } else { - console.log(` Unexpected result format:`, result); - } - }); - - console.log('\n=== All clients completed successfully ==='); - } catch (error) { - console.error('Error running multiple clients:', error); - process.exit(1); - } + console.log('\n=== All clients completed successfully ==='); + } catch (error) { + console.error('Error running multiple clients:', error); + process.exit(1); + } } // Start the example main().catch((error: unknown) => { - console.error('Error running MCP multiple clients example:', error); - process.exit(1); -}); \ No newline at end of file + console.error('Error running MCP multiple clients example:', error); + process.exit(1); +}); diff --git a/src/examples/client/parallelToolCallsClient.ts b/src/examples/client/parallelToolCallsClient.ts index 3783992d6..2ad249de7 100644 --- a/src/examples/client/parallelToolCallsClient.ts +++ b/src/examples/client/parallelToolCallsClient.ts @@ -1,16 +1,16 @@ import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; import { - ListToolsRequest, - ListToolsResultSchema, - CallToolResultSchema, - LoggingMessageNotificationSchema, - CallToolResult, + ListToolsRequest, + ListToolsResultSchema, + CallToolResultSchema, + LoggingMessageNotificationSchema, + CallToolResult } from '../../types.js'; /** * Parallel Tool Calls MCP Client - * + * * This client demonstrates how to: * 1. Start multiple tool calls in parallel * 2. Track notifications from each tool call using a caller parameter @@ -21,92 +21,90 @@ const args = process.argv.slice(2); const serverUrl = args[0] || 'http://localhost:3000/mcp'; async function main(): Promise { - console.log('MCP Parallel Tool Calls Client'); - console.log('=============================='); - console.log(`Connecting to server at: ${serverUrl}`); - - let client: Client; - let transport: StreamableHTTPClientTransport; - - try { - // Create client with streamable HTTP transport - client = new Client({ - name: 'parallel-tool-calls-client', - version: '1.0.0' - }); - - client.onerror = (error) => { - console.error('Client error:', error); - }; - - // Connect to the server - transport = new StreamableHTTPClientTransport(new URL(serverUrl)); - await client.connect(transport); - console.log('Successfully connected to MCP server'); - - // Set up notification handler with caller identification - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - console.log(`Notification: ${notification.params.data}`); - }); - - console.log("List tools") - const toolsRequest = await listTools(client); - console.log("Tools: ", toolsRequest) - - - // 2. Start multiple notification tools in parallel - console.log('\n=== Starting Multiple Notification Streams in Parallel ==='); - const toolResults = await startParallelNotificationTools(client); - - // Log the results from each tool call - for (const [caller, result] of Object.entries(toolResults)) { - console.log(`\n=== Tool result for ${caller} ===`); - result.content.forEach((item: { type: string; text?: string; }) => { - if (item.type === 'text') { - console.log(` ${item.text}`); - } else { - console.log(` ${item.type} content:`, item); - } - }); - } + console.log('MCP Parallel Tool Calls Client'); + console.log('=============================='); + console.log(`Connecting to server at: ${serverUrl}`); + + let client: Client; + let transport: StreamableHTTPClientTransport; + + try { + // Create client with streamable HTTP transport + client = new Client({ + name: 'parallel-tool-calls-client', + version: '1.0.0' + }); - // 3. Wait for all notifications (10 seconds) - console.log('\n=== Waiting for all notifications ==='); - await new Promise(resolve => setTimeout(resolve, 10000)); + client.onerror = error => { + console.error('Client error:', error); + }; - // 4. Disconnect - console.log('\n=== Disconnecting ==='); - await transport.close(); - console.log('Disconnected from MCP server'); + // Connect to the server + transport = new StreamableHTTPClientTransport(new URL(serverUrl)); + await client.connect(transport); + console.log('Successfully connected to MCP server'); - } catch (error) { - console.error('Error running client:', error); - process.exit(1); - } + // Set up notification handler with caller identification + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + console.log(`Notification: ${notification.params.data}`); + }); + + console.log('List tools'); + const toolsRequest = await listTools(client); + console.log('Tools: ', toolsRequest); + + // 2. Start multiple notification tools in parallel + console.log('\n=== Starting Multiple Notification Streams in Parallel ==='); + const toolResults = await startParallelNotificationTools(client); + + // Log the results from each tool call + for (const [caller, result] of Object.entries(toolResults)) { + console.log(`\n=== Tool result for ${caller} ===`); + result.content.forEach((item: { type: string; text?: string }) => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } else { + console.log(` ${item.type} content:`, item); + } + }); + } + + // 3. Wait for all notifications (10 seconds) + console.log('\n=== Waiting for all notifications ==='); + await new Promise(resolve => setTimeout(resolve, 10000)); + + // 4. Disconnect + console.log('\n=== Disconnecting ==='); + await transport.close(); + console.log('Disconnected from MCP server'); + } catch (error) { + console.error('Error running client:', error); + process.exit(1); + } } /** * List available tools on the server */ async function listTools(client: Client): Promise { - try { - const toolsRequest: ListToolsRequest = { - method: 'tools/list', - params: {} - }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); - - console.log('Available tools:'); - if (toolsResult.tools.length === 0) { - console.log(' No tools available'); - } else { - for (const tool of toolsResult.tools) { - console.log(` - ${tool.name}: ${tool.description}`); - } + try { + const toolsRequest: ListToolsRequest = { + method: 'tools/list', + params: {} + }; + const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + + console.log('Available tools:'); + if (toolsResult.tools.length === 0) { + console.log(' No tools available'); + } else { + for (const tool of toolsResult.tools) { + console.log(` - ${tool.name}: ${tool.description}`); + } + } + } catch (error) { + console.log(`Tools not supported by this server: ${error}`); } - } catch (error) { - console.log(`Tools not supported by this server: ${error}`); - } } /** @@ -114,84 +112,85 @@ async function listTools(client: Client): Promise { * Each tool call includes a caller parameter to identify its notifications */ async function startParallelNotificationTools(client: Client): Promise> { - try { - // Define multiple tool calls with different configurations - const toolCalls = [ - { - caller: 'fast-notifier', - request: { - method: 'tools/call', - params: { - name: 'start-notification-stream', - arguments: { - interval: 2, // 0.5 second between notifications - count: 10, // Send 10 notifications - caller: 'fast-notifier' // Identify this tool call - } - } - } - }, - { - caller: 'slow-notifier', - request: { - method: 'tools/call', - params: { - name: 'start-notification-stream', - arguments: { - interval: 5, // 2 seconds between notifications - count: 5, // Send 5 notifications - caller: 'slow-notifier' // Identify this tool call + try { + // Define multiple tool calls with different configurations + const toolCalls = [ + { + caller: 'fast-notifier', + request: { + method: 'tools/call', + params: { + name: 'start-notification-stream', + arguments: { + interval: 2, // 0.5 second between notifications + count: 10, // Send 10 notifications + caller: 'fast-notifier' // Identify this tool call + } + } + } + }, + { + caller: 'slow-notifier', + request: { + method: 'tools/call', + params: { + name: 'start-notification-stream', + arguments: { + interval: 5, // 2 seconds between notifications + count: 5, // Send 5 notifications + caller: 'slow-notifier' // Identify this tool call + } + } + } + }, + { + caller: 'burst-notifier', + request: { + method: 'tools/call', + params: { + name: 'start-notification-stream', + arguments: { + interval: 1, // 0.1 second between notifications + count: 3, // Send just 3 notifications + caller: 'burst-notifier' // Identify this tool call + } + } + } } - } - } - }, - { - caller: 'burst-notifier', - request: { - method: 'tools/call', - params: { - name: 'start-notification-stream', - arguments: { - interval: 1, // 0.1 second between notifications - count: 3, // Send just 3 notifications - caller: 'burst-notifier' // Identify this tool call - } - } - } - } - ]; - - console.log(`Starting ${toolCalls.length} notification tools in parallel...`); - - // Start all tool calls in parallel - const toolPromises = toolCalls.map(({ caller, request }) => { - console.log(`Starting tool call for ${caller}...`); - return client.request(request, CallToolResultSchema) - .then(result => ({ caller, result })) - .catch(error => { - console.error(`Error in tool call for ${caller}:`, error); - throw error; + ]; + + console.log(`Starting ${toolCalls.length} notification tools in parallel...`); + + // Start all tool calls in parallel + const toolPromises = toolCalls.map(({ caller, request }) => { + console.log(`Starting tool call for ${caller}...`); + return client + .request(request, CallToolResultSchema) + .then(result => ({ caller, result })) + .catch(error => { + console.error(`Error in tool call for ${caller}:`, error); + throw error; + }); }); - }); - - // Wait for all tool calls to complete - const results = await Promise.all(toolPromises); - - // Organize results by caller - const resultsByTool: Record = {}; - results.forEach(({ caller, result }) => { - resultsByTool[caller] = result; - }); - - return resultsByTool; - } catch (error) { - console.error(`Error starting parallel notification tools:`, error); - throw error; - } + + // Wait for all tool calls to complete + const results = await Promise.all(toolPromises); + + // Organize results by caller + const resultsByTool: Record = {}; + results.forEach(({ caller, result }) => { + resultsByTool[caller] = result; + }); + + return resultsByTool; + } catch (error) { + console.error(`Error starting parallel notification tools:`, error); + throw error; + } } // Start the client main().catch((error: unknown) => { - console.error('Error running MCP client:', error); - process.exit(1); -}); \ No newline at end of file + console.error('Error running MCP client:', error); + process.exit(1); +}); diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index b7388384a..354886050 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -7,12 +7,7 @@ import { exec } from 'node:child_process'; import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; -import { - CallToolRequest, - ListToolsRequest, - CallToolResultSchema, - ListToolsResultSchema -} from '../../types.js'; +import { CallToolRequest, ListToolsRequest, CallToolResultSchema, ListToolsResultSchema } from '../../types.js'; import { OAuthClientProvider, UnauthorizedError } from '../../client/auth.js'; // Configuration @@ -25,124 +20,126 @@ const CALLBACK_URL = `http://localhost:${CALLBACK_PORT}/callback`; * In production, you should persist tokens securely */ class InMemoryOAuthClientProvider implements OAuthClientProvider { - private _clientInformation?: OAuthClientInformationFull; - private _tokens?: OAuthTokens; - private _codeVerifier?: string; - - constructor( - private readonly _redirectUrl: string | URL, - private readonly _clientMetadata: OAuthClientMetadata, - onRedirect?: (url: URL) => void - ) { - this._onRedirect = onRedirect || ((url) => { - console.log(`Redirect to: ${url.toString()}`); - }); - } + private _clientInformation?: OAuthClientInformationFull; + private _tokens?: OAuthTokens; + private _codeVerifier?: string; + + constructor( + private readonly _redirectUrl: string | URL, + private readonly _clientMetadata: OAuthClientMetadata, + onRedirect?: (url: URL) => void + ) { + this._onRedirect = + onRedirect || + (url => { + console.log(`Redirect to: ${url.toString()}`); + }); + } - private _onRedirect: (url: URL) => void; + private _onRedirect: (url: URL) => void; - get redirectUrl(): string | URL { - return this._redirectUrl; - } + get redirectUrl(): string | URL { + return this._redirectUrl; + } - get clientMetadata(): OAuthClientMetadata { - return this._clientMetadata; - } + get clientMetadata(): OAuthClientMetadata { + return this._clientMetadata; + } - clientInformation(): OAuthClientInformation | undefined { - return this._clientInformation; - } + clientInformation(): OAuthClientInformation | undefined { + return this._clientInformation; + } - saveClientInformation(clientInformation: OAuthClientInformationFull): void { - this._clientInformation = clientInformation; - } + saveClientInformation(clientInformation: OAuthClientInformationFull): void { + this._clientInformation = clientInformation; + } - tokens(): OAuthTokens | undefined { - return this._tokens; - } + tokens(): OAuthTokens | undefined { + return this._tokens; + } - saveTokens(tokens: OAuthTokens): void { - this._tokens = tokens; - } + saveTokens(tokens: OAuthTokens): void { + this._tokens = tokens; + } - redirectToAuthorization(authorizationUrl: URL): void { - this._onRedirect(authorizationUrl); - } + redirectToAuthorization(authorizationUrl: URL): void { + this._onRedirect(authorizationUrl); + } - saveCodeVerifier(codeVerifier: string): void { - this._codeVerifier = codeVerifier; - } + saveCodeVerifier(codeVerifier: string): void { + this._codeVerifier = codeVerifier; + } - codeVerifier(): string { - if (!this._codeVerifier) { - throw new Error('No code verifier saved'); + codeVerifier(): string { + if (!this._codeVerifier) { + throw new Error('No code verifier saved'); + } + return this._codeVerifier; } - return this._codeVerifier; - } } /** * Interactive MCP client with OAuth authentication * Demonstrates the complete OAuth flow with browser-based authorization */ class InteractiveOAuthClient { - private client: Client | null = null; - private readonly rl = createInterface({ - input: process.stdin, - output: process.stdout, - }); - - constructor(private serverUrl: string) { } - - /** - * Prompts user for input via readline - */ - private async question(query: string): Promise { - return new Promise((resolve) => { - this.rl.question(query, resolve); + private client: Client | null = null; + private readonly rl = createInterface({ + input: process.stdin, + output: process.stdout }); - } - /** - * Opens the authorization URL in the user's default browser - */ - private async openBrowser(url: string): Promise { - console.log(`🌐 Opening browser for authorization: ${url}`); + constructor(private serverUrl: string) {} - const command = `open "${url}"`; + /** + * Prompts user for input via readline + */ + private async question(query: string): Promise { + return new Promise(resolve => { + this.rl.question(query, resolve); + }); + } - exec(command, (error) => { - if (error) { - console.error(`Failed to open browser: ${error.message}`); - console.log(`Please manually open: ${url}`); - } - }); - } - /** - * Example OAuth callback handler - in production, use a more robust approach - * for handling callbacks and storing tokens - */ - /** - * Starts a temporary HTTP server to receive the OAuth callback - */ - private async waitForOAuthCallback(): Promise { - return new Promise((resolve, reject) => { - const server = createServer((req, res) => { - // Ignore favicon requests - if (req.url === '/favicon.ico') { - res.writeHead(404); - res.end(); - return; - } + /** + * Opens the authorization URL in the user's default browser + */ + private async openBrowser(url: string): Promise { + console.log(`🌐 Opening browser for authorization: ${url}`); - console.log(`📥 Received callback: ${req.url}`); - const parsedUrl = new URL(req.url || '', 'http://localhost'); - const code = parsedUrl.searchParams.get('code'); - const error = parsedUrl.searchParams.get('error'); + const command = `open "${url}"`; - if (code) { - console.log(`✅ Authorization code received: ${code?.substring(0, 10)}...`); - res.writeHead(200, { 'Content-Type': 'text/html' }); - res.end(` + exec(command, error => { + if (error) { + console.error(`Failed to open browser: ${error.message}`); + console.log(`Please manually open: ${url}`); + } + }); + } + /** + * Example OAuth callback handler - in production, use a more robust approach + * for handling callbacks and storing tokens + */ + /** + * Starts a temporary HTTP server to receive the OAuth callback + */ + private async waitForOAuthCallback(): Promise { + return new Promise((resolve, reject) => { + const server = createServer((req, res) => { + // Ignore favicon requests + if (req.url === '/favicon.ico') { + res.writeHead(404); + res.end(); + return; + } + + console.log(`📥 Received callback: ${req.url}`); + const parsedUrl = new URL(req.url || '', 'http://localhost'); + const code = parsedUrl.searchParams.get('code'); + const error = parsedUrl.searchParams.get('error'); + + if (code) { + console.log(`✅ Authorization code received: ${code?.substring(0, 10)}...`); + res.writeHead(200, { 'Content-Type': 'text/html' }); + res.end(`

Authorization Successful!

@@ -152,12 +149,12 @@ class InteractiveOAuthClient { `); - resolve(code); - setTimeout(() => server.close(), 3000); - } else if (error) { - console.log(`❌ Authorization error: ${error}`); - res.writeHead(400, { 'Content-Type': 'text/html' }); - res.end(` + resolve(code); + setTimeout(() => server.close(), 3000); + } else if (error) { + console.log(`❌ Authorization error: ${error}`); + res.writeHead(400, { 'Content-Type': 'text/html' }); + res.end(`

Authorization Failed

@@ -165,260 +162,259 @@ class InteractiveOAuthClient { `); - reject(new Error(`OAuth authorization failed: ${error}`)); - } else { - console.log(`❌ No authorization code or error in callback`); - res.writeHead(400); - res.end('Bad request'); - reject(new Error('No authorization code provided')); + reject(new Error(`OAuth authorization failed: ${error}`)); + } else { + console.log(`❌ No authorization code or error in callback`); + res.writeHead(400); + res.end('Bad request'); + reject(new Error('No authorization code provided')); + } + }); + + server.listen(CALLBACK_PORT, () => { + console.log(`OAuth callback server started on http://localhost:${CALLBACK_PORT}`); + }); + }); + } + + private async attemptConnection(oauthProvider: InMemoryOAuthClientProvider): Promise { + console.log('🚢 Creating transport with OAuth provider...'); + const baseUrl = new URL(this.serverUrl); + const transport = new StreamableHTTPClientTransport(baseUrl, { + authProvider: oauthProvider + }); + console.log('🚢 Transport created'); + + try { + console.log('🔌 Attempting connection (this will trigger OAuth redirect)...'); + await this.client!.connect(transport); + console.log('✅ Connected successfully'); + } catch (error) { + if (error instanceof UnauthorizedError) { + console.log('🔐 OAuth required - waiting for authorization...'); + const callbackPromise = this.waitForOAuthCallback(); + const authCode = await callbackPromise; + await transport.finishAuth(authCode); + console.log('🔐 Authorization code received:', authCode); + console.log('🔌 Reconnecting with authenticated transport...'); + await this.attemptConnection(oauthProvider); + } else { + console.error('❌ Connection failed with non-auth error:', error); + throw error; + } } - }); + } - server.listen(CALLBACK_PORT, () => { - console.log(`OAuth callback server started on http://localhost:${CALLBACK_PORT}`); - }); - }); - } + /** + * Establishes connection to the MCP server with OAuth authentication + */ + async connect(): Promise { + console.log(`🔗 Attempting to connect to ${this.serverUrl}...`); + + const clientMetadata: OAuthClientMetadata = { + client_name: 'Simple OAuth MCP Client', + redirect_uris: [CALLBACK_URL], + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'client_secret_post', + scope: 'mcp:tools' + }; + + console.log('🔐 Creating OAuth provider...'); + const oauthProvider = new InMemoryOAuthClientProvider(CALLBACK_URL, clientMetadata, (redirectUrl: URL) => { + console.log(`📌 OAuth redirect handler called - opening browser`); + console.log(`Opening browser to: ${redirectUrl.toString()}`); + this.openBrowser(redirectUrl.toString()); + }); + console.log('🔐 OAuth provider created'); - private async attemptConnection(oauthProvider: InMemoryOAuthClientProvider): Promise { - console.log('🚢 Creating transport with OAuth provider...'); - const baseUrl = new URL(this.serverUrl); - const transport = new StreamableHTTPClientTransport(baseUrl, { - authProvider: oauthProvider - }); - console.log('🚢 Transport created'); + console.log('👤 Creating MCP client...'); + this.client = new Client( + { + name: 'simple-oauth-client', + version: '1.0.0' + }, + { capabilities: {} } + ); + console.log('👤 Client created'); + + console.log('🔐 Starting OAuth flow...'); - try { - console.log('🔌 Attempting connection (this will trigger OAuth redirect)...'); - await this.client!.connect(transport); - console.log('✅ Connected successfully'); - } catch (error) { - if (error instanceof UnauthorizedError) { - console.log('🔐 OAuth required - waiting for authorization...'); - const callbackPromise = this.waitForOAuthCallback(); - const authCode = await callbackPromise; - await transport.finishAuth(authCode); - console.log('🔐 Authorization code received:', authCode); - console.log('🔌 Reconnecting with authenticated transport...'); await this.attemptConnection(oauthProvider); - } else { - console.error('❌ Connection failed with non-auth error:', error); - throw error; - } - } - } - - /** - * Establishes connection to the MCP server with OAuth authentication - */ - async connect(): Promise { - console.log(`🔗 Attempting to connect to ${this.serverUrl}...`); - - const clientMetadata: OAuthClientMetadata = { - client_name: 'Simple OAuth MCP Client', - redirect_uris: [CALLBACK_URL], - grant_types: ['authorization_code', 'refresh_token'], - response_types: ['code'], - token_endpoint_auth_method: 'client_secret_post', - scope: 'mcp:tools' - }; - - console.log('🔐 Creating OAuth provider...'); - const oauthProvider = new InMemoryOAuthClientProvider( - CALLBACK_URL, - clientMetadata, - (redirectUrl: URL) => { - console.log(`📌 OAuth redirect handler called - opening browser`); - console.log(`Opening browser to: ${redirectUrl.toString()}`); - this.openBrowser(redirectUrl.toString()); - } - ); - console.log('🔐 OAuth provider created'); - - console.log('👤 Creating MCP client...'); - this.client = new Client({ - name: 'simple-oauth-client', - version: '1.0.0', - }, { capabilities: {} }); - console.log('👤 Client created'); - - console.log('🔐 Starting OAuth flow...'); - - await this.attemptConnection(oauthProvider); - - // Start interactive loop - await this.interactiveLoop(); - } - - /** - * Main interactive loop for user commands - */ - async interactiveLoop(): Promise { - console.log('\n🎯 Interactive MCP Client with OAuth'); - console.log('Commands:'); - console.log(' list - List available tools'); - console.log(' call [args] - Call a tool'); - console.log(' quit - Exit the client'); - console.log(); - while (true) { - try { - const command = await this.question('mcp> '); + // Start interactive loop + await this.interactiveLoop(); + } - if (!command.trim()) { - continue; + /** + * Main interactive loop for user commands + */ + async interactiveLoop(): Promise { + console.log('\n🎯 Interactive MCP Client with OAuth'); + console.log('Commands:'); + console.log(' list - List available tools'); + console.log(' call [args] - Call a tool'); + console.log(' quit - Exit the client'); + console.log(); + + while (true) { + try { + const command = await this.question('mcp> '); + + if (!command.trim()) { + continue; + } + + if (command === 'quit') { + console.log('\n👋 Goodbye!'); + this.close(); + process.exit(0); + } else if (command === 'list') { + await this.listTools(); + } else if (command.startsWith('call ')) { + await this.handleCallTool(command); + } else { + console.log("❌ Unknown command. Try 'list', 'call ', or 'quit'"); + } + } catch (error) { + if (error instanceof Error && error.message === 'SIGINT') { + console.log('\n\n👋 Goodbye!'); + break; + } + console.error('❌ Error:', error); + } } + } - if (command === 'quit') { - console.log('\n👋 Goodbye!'); - this.close(); - process.exit(0); - } else if (command === 'list') { - await this.listTools(); - } else if (command.startsWith('call ')) { - await this.handleCallTool(command); - } else { - console.log('❌ Unknown command. Try \'list\', \'call \', or \'quit\''); - } - } catch (error) { - if (error instanceof Error && error.message === 'SIGINT') { - console.log('\n\n👋 Goodbye!'); - break; + private async listTools(): Promise { + if (!this.client) { + console.log('❌ Not connected to server'); + return; } - console.error('❌ Error:', error); - } - } - } - private async listTools(): Promise { - if (!this.client) { - console.log('❌ Not connected to server'); - return; + try { + const request: ListToolsRequest = { + method: 'tools/list', + params: {} + }; + + const result = await this.client.request(request, ListToolsResultSchema); + + if (result.tools && result.tools.length > 0) { + console.log('\n📋 Available tools:'); + result.tools.forEach((tool, index) => { + console.log(`${index + 1}. ${tool.name}`); + if (tool.description) { + console.log(` Description: ${tool.description}`); + } + console.log(); + }); + } else { + console.log('No tools available'); + } + } catch (error) { + console.error('❌ Failed to list tools:', error); + } } - try { - const request: ListToolsRequest = { - method: 'tools/list', - params: {}, - }; - - const result = await this.client.request(request, ListToolsResultSchema); - - if (result.tools && result.tools.length > 0) { - console.log('\n📋 Available tools:'); - result.tools.forEach((tool, index) => { - console.log(`${index + 1}. ${tool.name}`); - if (tool.description) { - console.log(` Description: ${tool.description}`); - } - console.log(); - }); - } else { - console.log('No tools available'); - } - } catch (error) { - console.error('❌ Failed to list tools:', error); - } - } + private async handleCallTool(command: string): Promise { + const parts = command.split(/\s+/); + const toolName = parts[1]; - private async handleCallTool(command: string): Promise { - const parts = command.split(/\s+/); - const toolName = parts[1]; + if (!toolName) { + console.log('❌ Please specify a tool name'); + return; + } - if (!toolName) { - console.log('❌ Please specify a tool name'); - return; - } + // Parse arguments (simple JSON-like format) + let toolArgs: Record = {}; + if (parts.length > 2) { + const argsString = parts.slice(2).join(' '); + try { + toolArgs = JSON.parse(argsString); + } catch { + console.log('❌ Invalid arguments format (expected JSON)'); + return; + } + } - // Parse arguments (simple JSON-like format) - let toolArgs: Record = {}; - if (parts.length > 2) { - const argsString = parts.slice(2).join(' '); - try { - toolArgs = JSON.parse(argsString); - } catch { - console.log('❌ Invalid arguments format (expected JSON)'); - return; - } + await this.callTool(toolName, toolArgs); } - await this.callTool(toolName, toolArgs); - } - - private async callTool(toolName: string, toolArgs: Record): Promise { - if (!this.client) { - console.log('❌ Not connected to server'); - return; - } + private async callTool(toolName: string, toolArgs: Record): Promise { + if (!this.client) { + console.log('❌ Not connected to server'); + return; + } - try { - const request: CallToolRequest = { - method: 'tools/call', - params: { - name: toolName, - arguments: toolArgs, - }, - }; - - const result = await this.client.request(request, CallToolResultSchema); - - console.log(`\n🔧 Tool '${toolName}' result:`); - if (result.content) { - result.content.forEach((content) => { - if (content.type === 'text') { - console.log(content.text); - } else { - console.log(content); - } - }); - } else { - console.log(result); - } - } catch (error) { - console.error(`❌ Failed to call tool '${toolName}':`, error); + try { + const request: CallToolRequest = { + method: 'tools/call', + params: { + name: toolName, + arguments: toolArgs + } + }; + + const result = await this.client.request(request, CallToolResultSchema); + + console.log(`\n🔧 Tool '${toolName}' result:`); + if (result.content) { + result.content.forEach(content => { + if (content.type === 'text') { + console.log(content.text); + } else { + console.log(content); + } + }); + } else { + console.log(result); + } + } catch (error) { + console.error(`❌ Failed to call tool '${toolName}':`, error); + } } - } - close(): void { - this.rl.close(); - if (this.client) { - // Note: Client doesn't have a close method in the current implementation - // This would typically close the transport connection + close(): void { + this.rl.close(); + if (this.client) { + // Note: Client doesn't have a close method in the current implementation + // This would typically close the transport connection + } } - } } /** * Main entry point */ async function main(): Promise { - const serverUrl = process.env.MCP_SERVER_URL || DEFAULT_SERVER_URL; + const serverUrl = process.env.MCP_SERVER_URL || DEFAULT_SERVER_URL; - console.log('🚀 Simple MCP OAuth Client'); - console.log(`Connecting to: ${serverUrl}`); - console.log(); + console.log('🚀 Simple MCP OAuth Client'); + console.log(`Connecting to: ${serverUrl}`); + console.log(); - const client = new InteractiveOAuthClient(serverUrl); + const client = new InteractiveOAuthClient(serverUrl); - // Handle graceful shutdown - process.on('SIGINT', () => { - console.log('\n\n👋 Goodbye!'); - client.close(); - process.exit(0); - }); + // Handle graceful shutdown + process.on('SIGINT', () => { + console.log('\n\n👋 Goodbye!'); + client.close(); + process.exit(0); + }); - try { - await client.connect(); - } catch (error) { - console.error('Failed to start client:', error); - process.exit(1); - } finally { - client.close(); - } + try { + await client.connect(); + } catch (error) { + console.error('Failed to start client:', error); + process.exit(1); + } finally { + client.close(); + } } // Run if this file is executed directly -main().catch((error) => { - console.error('Unhandled error:', error); - process.exit(1); -}); \ No newline at end of file +main().catch(error => { + console.error('Unhandled error:', error); + process.exit(1); +}); diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index ddb274196..10f6afcbe 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -2,30 +2,30 @@ import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; import { createInterface } from 'node:readline'; import { - ListToolsRequest, - ListToolsResultSchema, - CallToolRequest, - CallToolResultSchema, - ListPromptsRequest, - ListPromptsResultSchema, - GetPromptRequest, - GetPromptResultSchema, - ListResourcesRequest, - ListResourcesResultSchema, - LoggingMessageNotificationSchema, - ResourceListChangedNotificationSchema, - ElicitRequestSchema, - ResourceLink, - ReadResourceRequest, - ReadResourceResultSchema, + ListToolsRequest, + ListToolsResultSchema, + CallToolRequest, + CallToolResultSchema, + ListPromptsRequest, + ListPromptsResultSchema, + GetPromptRequest, + GetPromptResultSchema, + ListResourcesRequest, + ListResourcesResultSchema, + LoggingMessageNotificationSchema, + ResourceListChangedNotificationSchema, + ElicitRequestSchema, + ResourceLink, + ReadResourceRequest, + ReadResourceResultSchema } from '../../types.js'; import { getDisplayName } from '../../shared/metadataUtils.js'; -import Ajv from "ajv"; +import Ajv from 'ajv'; // Create readline interface for user input const readline = createInterface({ - input: process.stdin, - output: process.stdout + input: process.stdin, + output: process.stdout }); // Track received notifications for debugging resumability @@ -39,792 +39,799 @@ let notificationsToolLastEventId: string | undefined = undefined; let sessionId: string | undefined = undefined; async function main(): Promise { - console.log('MCP Interactive Client'); - console.log('====================='); + console.log('MCP Interactive Client'); + console.log('====================='); - // Connect to server immediately with default settings - await connect(); + // Connect to server immediately with default settings + await connect(); - // Print help and start the command loop - printHelp(); - commandLoop(); + // Print help and start the command loop + printHelp(); + commandLoop(); } function printHelp(): void { - console.log('\nAvailable commands:'); - console.log(' connect [url] - Connect to MCP server (default: http://localhost:3000/mcp)'); - console.log(' disconnect - Disconnect from server'); - console.log(' terminate-session - Terminate the current session'); - console.log(' reconnect - Reconnect to the server'); - console.log(' list-tools - List available tools'); - console.log(' call-tool [args] - Call a tool with optional JSON arguments'); - console.log(' greet [name] - Call the greet tool'); - console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); - console.log(' collect-info [type] - Test elicitation with collect-user-info tool (contact/preferences/feedback)'); - console.log(' start-notifications [interval] [count] - Start periodic notifications'); - console.log(' run-notifications-tool-with-resumability [interval] [count] - Run notification tool with resumability'); - console.log(' list-prompts - List available prompts'); - console.log(' get-prompt [name] [args] - Get a prompt with optional JSON arguments'); - console.log(' list-resources - List available resources'); - console.log(' read-resource - Read a specific resource by URI'); - console.log(' help - Show this help'); - console.log(' quit - Exit the program'); + console.log('\nAvailable commands:'); + console.log(' connect [url] - Connect to MCP server (default: http://localhost:3000/mcp)'); + console.log(' disconnect - Disconnect from server'); + console.log(' terminate-session - Terminate the current session'); + console.log(' reconnect - Reconnect to the server'); + console.log(' list-tools - List available tools'); + console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' greet [name] - Call the greet tool'); + console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); + console.log(' collect-info [type] - Test elicitation with collect-user-info tool (contact/preferences/feedback)'); + console.log(' start-notifications [interval] [count] - Start periodic notifications'); + console.log(' run-notifications-tool-with-resumability [interval] [count] - Run notification tool with resumability'); + console.log(' list-prompts - List available prompts'); + console.log(' get-prompt [name] [args] - Get a prompt with optional JSON arguments'); + console.log(' list-resources - List available resources'); + console.log(' read-resource - Read a specific resource by URI'); + console.log(' help - Show this help'); + console.log(' quit - Exit the program'); } function commandLoop(): void { - readline.question('\n> ', async (input) => { - const args = input.trim().split(/\s+/); - const command = args[0]?.toLowerCase(); + readline.question('\n> ', async input => { + const args = input.trim().split(/\s+/); + const command = args[0]?.toLowerCase(); - try { - switch (command) { - case 'connect': - await connect(args[1]); - break; - - case 'disconnect': - await disconnect(); - break; - - case 'terminate-session': - await terminateSession(); - break; - - case 'reconnect': - await reconnect(); - break; - - case 'list-tools': - await listTools(); - break; - - case 'call-tool': - if (args.length < 2) { - console.log('Usage: call-tool [args]'); - } else { - const toolName = args[1]; - let toolArgs = {}; - if (args.length > 2) { - try { - toolArgs = JSON.parse(args.slice(2).join(' ')); - } catch { - console.log('Invalid JSON arguments. Using empty args.'); - } - } - await callTool(toolName, toolArgs); - } - break; - - case 'greet': - await callGreetTool(args[1] || 'MCP User'); - break; - - case 'multi-greet': - await callMultiGreetTool(args[1] || 'MCP User'); - break; - - case 'collect-info': - await callCollectInfoTool(args[1] || 'contact'); - break; - - case 'start-notifications': { - const interval = args[1] ? parseInt(args[1], 10) : 2000; - const count = args[2] ? parseInt(args[2], 10) : 10; - await startNotifications(interval, count); - break; - } + try { + switch (command) { + case 'connect': + await connect(args[1]); + break; + + case 'disconnect': + await disconnect(); + break; + + case 'terminate-session': + await terminateSession(); + break; + + case 'reconnect': + await reconnect(); + break; + + case 'list-tools': + await listTools(); + break; + + case 'call-tool': + if (args.length < 2) { + console.log('Usage: call-tool [args]'); + } else { + const toolName = args[1]; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callTool(toolName, toolArgs); + } + break; + + case 'greet': + await callGreetTool(args[1] || 'MCP User'); + break; + + case 'multi-greet': + await callMultiGreetTool(args[1] || 'MCP User'); + break; + + case 'collect-info': + await callCollectInfoTool(args[1] || 'contact'); + break; + + case 'start-notifications': { + const interval = args[1] ? parseInt(args[1], 10) : 2000; + const count = args[2] ? parseInt(args[2], 10) : 10; + await startNotifications(interval, count); + break; + } - case 'run-notifications-tool-with-resumability': { - const interval = args[1] ? parseInt(args[1], 10) : 2000; - const count = args[2] ? parseInt(args[2], 10) : 10; - await runNotificationsToolWithResumability(interval, count); - break; - } + case 'run-notifications-tool-with-resumability': { + const interval = args[1] ? parseInt(args[1], 10) : 2000; + const count = args[2] ? parseInt(args[2], 10) : 10; + await runNotificationsToolWithResumability(interval, count); + break; + } - case 'list-prompts': - await listPrompts(); - break; - - case 'get-prompt': - if (args.length < 2) { - console.log('Usage: get-prompt [args]'); - } else { - const promptName = args[1]; - let promptArgs = {}; - if (args.length > 2) { - try { - promptArgs = JSON.parse(args.slice(2).join(' ')); - } catch { - console.log('Invalid JSON arguments. Using empty args.'); - } + case 'list-prompts': + await listPrompts(); + break; + + case 'get-prompt': + if (args.length < 2) { + console.log('Usage: get-prompt [args]'); + } else { + const promptName = args[1]; + let promptArgs = {}; + if (args.length > 2) { + try { + promptArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await getPrompt(promptName, promptArgs); + } + break; + + case 'list-resources': + await listResources(); + break; + + case 'read-resource': + if (args.length < 2) { + console.log('Usage: read-resource '); + } else { + await readResource(args[1]); + } + break; + + case 'help': + printHelp(); + break; + + case 'quit': + case 'exit': + await cleanup(); + return; + + default: + if (command) { + console.log(`Unknown command: ${command}`); + } + break; } - await getPrompt(promptName, promptArgs); - } - break; - - case 'list-resources': - await listResources(); - break; - - case 'read-resource': - if (args.length < 2) { - console.log('Usage: read-resource '); - } else { - await readResource(args[1]); - } - break; - - case 'help': - printHelp(); - break; - - case 'quit': - case 'exit': - await cleanup(); - return; - - default: - if (command) { - console.log(`Unknown command: ${command}`); - } - break; - } - } catch (error) { - console.error(`Error executing command: ${error}`); - } + } catch (error) { + console.error(`Error executing command: ${error}`); + } - // Continue the command loop - commandLoop(); - }); + // Continue the command loop + commandLoop(); + }); } async function connect(url?: string): Promise { - if (client) { - console.log('Already connected. Disconnect first.'); - return; - } - - if (url) { - serverUrl = url; - } - - console.log(`Connecting to ${serverUrl}...`); - - try { - // Create a new client with elicitation capability - client = new Client({ - name: 'example-client', - version: '1.0.0' - }, { - capabilities: { - elicitation: {}, - }, - }); - client.onerror = (error) => { - console.error('\x1b[31mClient error:', error, '\x1b[0m'); + if (client) { + console.log('Already connected. Disconnect first.'); + return; } - // Set up elicitation request handler with proper validation - client.setRequestHandler(ElicitRequestSchema, async (request) => { - console.log('\n🔔 Elicitation Request Received:'); - console.log(`Message: ${request.params.message}`); - console.log('Requested Schema:'); - console.log(JSON.stringify(request.params.requestedSchema, null, 2)); - - const schema = request.params.requestedSchema; - const properties = schema.properties; - const required = schema.required || []; - - // Set up AJV validator for the requested schema - const ajv = new Ajv(); - const validate = ajv.compile(schema); - - let attempts = 0; - const maxAttempts = 3; - - while (attempts < maxAttempts) { - attempts++; - console.log(`\nPlease provide the following information (attempt ${attempts}/${maxAttempts}):`); - - const content: Record = {}; - let inputCancelled = false; - - // Collect input for each field - for (const [fieldName, fieldSchema] of Object.entries(properties)) { - const field = fieldSchema as { - type?: string; - title?: string; - description?: string; - default?: unknown; - enum?: string[]; - minimum?: number; - maximum?: number; - minLength?: number; - maxLength?: number; - format?: string; - }; - - const isRequired = required.includes(fieldName); - let prompt = `${field.title || fieldName}`; - - // Add helpful information to the prompt - if (field.description) { - prompt += ` (${field.description})`; - } - if (field.enum) { - prompt += ` [options: ${field.enum.join(', ')}]`; - } - if (field.type === 'number' || field.type === 'integer') { - if (field.minimum !== undefined && field.maximum !== undefined) { - prompt += ` [${field.minimum}-${field.maximum}]`; - } else if (field.minimum !== undefined) { - prompt += ` [min: ${field.minimum}]`; - } else if (field.maximum !== undefined) { - prompt += ` [max: ${field.maximum}]`; - } - } - if (field.type === 'string' && field.format) { - prompt += ` [format: ${field.format}]`; - } - if (isRequired) { - prompt += ' *required*'; - } - if (field.default !== undefined) { - prompt += ` [default: ${field.default}]`; - } - - prompt += ': '; - - const answer = await new Promise((resolve) => { - readline.question(prompt, (input) => { - resolve(input.trim()); - }); - }); - - // Check for cancellation - if (answer.toLowerCase() === 'cancel' || answer.toLowerCase() === 'c') { - inputCancelled = true; - break; - } - - // Parse and validate the input - try { - if (answer === '' && field.default !== undefined) { - content[fieldName] = field.default; - } else if (answer === '' && !isRequired) { - // Skip optional empty fields - continue; - } else if (answer === '') { - throw new Error(`${fieldName} is required`); - } else { - // Parse the value based on type - let parsedValue: unknown; - - if (field.type === 'boolean') { - parsedValue = answer.toLowerCase() === 'true' || answer.toLowerCase() === 'yes' || answer === '1'; - } else if (field.type === 'number') { - parsedValue = parseFloat(answer); - if (isNaN(parsedValue as number)) { - throw new Error(`${fieldName} must be a valid number`); - } - } else if (field.type === 'integer') { - parsedValue = parseInt(answer, 10); - if (isNaN(parsedValue as number)) { - throw new Error(`${fieldName} must be a valid integer`); - } - } else if (field.enum) { - if (!field.enum.includes(answer)) { - throw new Error(`${fieldName} must be one of: ${field.enum.join(', ')}`); - } - parsedValue = answer; - } else { - parsedValue = answer; - } + if (url) { + serverUrl = url; + } - content[fieldName] = parsedValue; + console.log(`Connecting to ${serverUrl}...`); + + try { + // Create a new client with elicitation capability + client = new Client( + { + name: 'example-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } } - } catch (error) { - console.log(`❌ Error: ${error}`); - // Continue to next attempt - break; - } - } + ); + client.onerror = error => { + console.error('\x1b[31mClient error:', error, '\x1b[0m'); + }; + + // Set up elicitation request handler with proper validation + client.setRequestHandler(ElicitRequestSchema, async request => { + console.log('\n🔔 Elicitation Request Received:'); + console.log(`Message: ${request.params.message}`); + console.log('Requested Schema:'); + console.log(JSON.stringify(request.params.requestedSchema, null, 2)); + + const schema = request.params.requestedSchema; + const properties = schema.properties; + const required = schema.required || []; + + // Set up AJV validator for the requested schema + const ajv = new Ajv(); + const validate = ajv.compile(schema); + + let attempts = 0; + const maxAttempts = 3; + + while (attempts < maxAttempts) { + attempts++; + console.log(`\nPlease provide the following information (attempt ${attempts}/${maxAttempts}):`); + + const content: Record = {}; + let inputCancelled = false; + + // Collect input for each field + for (const [fieldName, fieldSchema] of Object.entries(properties)) { + const field = fieldSchema as { + type?: string; + title?: string; + description?: string; + default?: unknown; + enum?: string[]; + minimum?: number; + maximum?: number; + minLength?: number; + maxLength?: number; + format?: string; + }; + + const isRequired = required.includes(fieldName); + let prompt = `${field.title || fieldName}`; + + // Add helpful information to the prompt + if (field.description) { + prompt += ` (${field.description})`; + } + if (field.enum) { + prompt += ` [options: ${field.enum.join(', ')}]`; + } + if (field.type === 'number' || field.type === 'integer') { + if (field.minimum !== undefined && field.maximum !== undefined) { + prompt += ` [${field.minimum}-${field.maximum}]`; + } else if (field.minimum !== undefined) { + prompt += ` [min: ${field.minimum}]`; + } else if (field.maximum !== undefined) { + prompt += ` [max: ${field.maximum}]`; + } + } + if (field.type === 'string' && field.format) { + prompt += ` [format: ${field.format}]`; + } + if (isRequired) { + prompt += ' *required*'; + } + if (field.default !== undefined) { + prompt += ` [default: ${field.default}]`; + } + + prompt += ': '; + + const answer = await new Promise(resolve => { + readline.question(prompt, input => { + resolve(input.trim()); + }); + }); + + // Check for cancellation + if (answer.toLowerCase() === 'cancel' || answer.toLowerCase() === 'c') { + inputCancelled = true; + break; + } + + // Parse and validate the input + try { + if (answer === '' && field.default !== undefined) { + content[fieldName] = field.default; + } else if (answer === '' && !isRequired) { + // Skip optional empty fields + continue; + } else if (answer === '') { + throw new Error(`${fieldName} is required`); + } else { + // Parse the value based on type + let parsedValue: unknown; + + if (field.type === 'boolean') { + parsedValue = answer.toLowerCase() === 'true' || answer.toLowerCase() === 'yes' || answer === '1'; + } else if (field.type === 'number') { + parsedValue = parseFloat(answer); + if (isNaN(parsedValue as number)) { + throw new Error(`${fieldName} must be a valid number`); + } + } else if (field.type === 'integer') { + parsedValue = parseInt(answer, 10); + if (isNaN(parsedValue as number)) { + throw new Error(`${fieldName} must be a valid integer`); + } + } else if (field.enum) { + if (!field.enum.includes(answer)) { + throw new Error(`${fieldName} must be one of: ${field.enum.join(', ')}`); + } + parsedValue = answer; + } else { + parsedValue = answer; + } + + content[fieldName] = parsedValue; + } + } catch (error) { + console.log(`❌ Error: ${error}`); + // Continue to next attempt + break; + } + } - if (inputCancelled) { - return { action: 'cancel' }; - } + if (inputCancelled) { + return { action: 'cancel' }; + } - // If we didn't complete all fields due to an error, try again - if (Object.keys(content).length !== Object.keys(properties).filter(name => - required.includes(name) || content[name] !== undefined - ).length) { - if (attempts < maxAttempts) { - console.log('Please try again...'); - continue; - } else { - console.log('Maximum attempts reached. Declining request.'); - return { action: 'decline' }; - } - } + // If we didn't complete all fields due to an error, try again + if ( + Object.keys(content).length !== + Object.keys(properties).filter(name => required.includes(name) || content[name] !== undefined).length + ) { + if (attempts < maxAttempts) { + console.log('Please try again...'); + continue; + } else { + console.log('Maximum attempts reached. Declining request.'); + return { action: 'decline' }; + } + } - // Validate the complete object against the schema - const isValid = validate(content); + // Validate the complete object against the schema + const isValid = validate(content); + + if (!isValid) { + console.log('❌ Validation errors:'); + validate.errors?.forEach(error => { + console.log(` - ${error.dataPath || 'root'}: ${error.message}`); + }); + + if (attempts < maxAttempts) { + console.log('Please correct the errors and try again...'); + continue; + } else { + console.log('Maximum attempts reached. Declining request.'); + return { action: 'decline' }; + } + } - if (!isValid) { - console.log('❌ Validation errors:'); - validate.errors?.forEach(error => { - console.log(` - ${error.dataPath || 'root'}: ${error.message}`); - }); + // Show the collected data and ask for confirmation + console.log('\n✅ Collected data:'); + console.log(JSON.stringify(content, null, 2)); + + const confirmAnswer = await new Promise(resolve => { + readline.question('\nSubmit this information? (yes/no/cancel): ', input => { + resolve(input.trim().toLowerCase()); + }); + }); + + if (confirmAnswer === 'yes' || confirmAnswer === 'y') { + return { + action: 'accept', + content + }; + } else if (confirmAnswer === 'cancel' || confirmAnswer === 'c') { + return { action: 'cancel' }; + } else if (confirmAnswer === 'no' || confirmAnswer === 'n') { + if (attempts < maxAttempts) { + console.log('Please re-enter the information...'); + continue; + } else { + return { action: 'decline' }; + } + } + } - if (attempts < maxAttempts) { - console.log('Please correct the errors and try again...'); - continue; - } else { console.log('Maximum attempts reached. Declining request.'); return { action: 'decline' }; - } - } - - // Show the collected data and ask for confirmation - console.log('\n✅ Collected data:'); - console.log(JSON.stringify(content, null, 2)); - - const confirmAnswer = await new Promise((resolve) => { - readline.question('\nSubmit this information? (yes/no/cancel): ', (input) => { - resolve(input.trim().toLowerCase()); - }); }); + transport = new StreamableHTTPClientTransport(new URL(serverUrl), { + sessionId: sessionId + }); - if (confirmAnswer === 'yes' || confirmAnswer === 'y') { - return { - action: 'accept', - content, - }; - } else if (confirmAnswer === 'cancel' || confirmAnswer === 'c') { - return { action: 'cancel' }; - } else if (confirmAnswer === 'no' || confirmAnswer === 'n') { - if (attempts < maxAttempts) { - console.log('Please re-enter the information...'); - continue; - } else { - return { action: 'decline' }; - } - } - } - - console.log('Maximum attempts reached. Declining request.'); - return { action: 'decline' }; - }); - - transport = new StreamableHTTPClientTransport( - new URL(serverUrl), - { - sessionId: sessionId - } - ); - - // Set up notification handlers - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - notificationCount++; - console.log(`\nNotification #${notificationCount}: ${notification.params.level} - ${notification.params.data}`); - // Re-display the prompt - process.stdout.write('> '); - }); + // Set up notification handlers + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + notificationCount++; + console.log(`\nNotification #${notificationCount}: ${notification.params.level} - ${notification.params.data}`); + // Re-display the prompt + process.stdout.write('> '); + }); - client.setNotificationHandler(ResourceListChangedNotificationSchema, async (_) => { - console.log(`\nResource list changed notification received!`); - try { - if (!client) { - console.log('Client disconnected, cannot fetch resources'); - return; - } - const resourcesResult = await client.request({ - method: 'resources/list', - params: {} - }, ListResourcesResultSchema); - console.log('Available resources count:', resourcesResult.resources.length); - } catch { - console.log('Failed to list resources after change notification'); - } - // Re-display the prompt - process.stdout.write('> '); - }); + client.setNotificationHandler(ResourceListChangedNotificationSchema, async _ => { + console.log(`\nResource list changed notification received!`); + try { + if (!client) { + console.log('Client disconnected, cannot fetch resources'); + return; + } + const resourcesResult = await client.request( + { + method: 'resources/list', + params: {} + }, + ListResourcesResultSchema + ); + console.log('Available resources count:', resourcesResult.resources.length); + } catch { + console.log('Failed to list resources after change notification'); + } + // Re-display the prompt + process.stdout.write('> '); + }); - // Connect the client - await client.connect(transport); - sessionId = transport.sessionId - console.log('Transport created with session ID:', sessionId); - console.log('Connected to MCP server'); - } catch (error) { - console.error('Failed to connect:', error); - client = null; - transport = null; - } + // Connect the client + await client.connect(transport); + sessionId = transport.sessionId; + console.log('Transport created with session ID:', sessionId); + console.log('Connected to MCP server'); + } catch (error) { + console.error('Failed to connect:', error); + client = null; + transport = null; + } } async function disconnect(): Promise { - if (!client || !transport) { - console.log('Not connected.'); - return; - } - - try { - await transport.close(); - console.log('Disconnected from MCP server'); - client = null; - transport = null; - } catch (error) { - console.error('Error disconnecting:', error); - } + if (!client || !transport) { + console.log('Not connected.'); + return; + } + + try { + await transport.close(); + console.log('Disconnected from MCP server'); + client = null; + transport = null; + } catch (error) { + console.error('Error disconnecting:', error); + } } async function terminateSession(): Promise { - if (!client || !transport) { - console.log('Not connected.'); - return; - } - - try { - console.log('Terminating session with ID:', transport.sessionId); - await transport.terminateSession(); - console.log('Session terminated successfully'); - - // Check if sessionId was cleared after termination - if (!transport.sessionId) { - console.log('Session ID has been cleared'); - sessionId = undefined; - - // Also close the transport and clear client objects - await transport.close(); - console.log('Transport closed after session termination'); - client = null; - transport = null; - } else { - console.log('Server responded with 405 Method Not Allowed (session termination not supported)'); - console.log('Session ID is still active:', transport.sessionId); + if (!client || !transport) { + console.log('Not connected.'); + return; + } + + try { + console.log('Terminating session with ID:', transport.sessionId); + await transport.terminateSession(); + console.log('Session terminated successfully'); + + // Check if sessionId was cleared after termination + if (!transport.sessionId) { + console.log('Session ID has been cleared'); + sessionId = undefined; + + // Also close the transport and clear client objects + await transport.close(); + console.log('Transport closed after session termination'); + client = null; + transport = null; + } else { + console.log('Server responded with 405 Method Not Allowed (session termination not supported)'); + console.log('Session ID is still active:', transport.sessionId); + } + } catch (error) { + console.error('Error terminating session:', error); } - } catch (error) { - console.error('Error terminating session:', error); - } } async function reconnect(): Promise { - if (client) { - await disconnect(); - } - await connect(); + if (client) { + await disconnect(); + } + await connect(); } async function listTools(): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - const toolsRequest: ListToolsRequest = { - method: 'tools/list', - params: {} - }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); - - console.log('Available tools:'); - if (toolsResult.tools.length === 0) { - console.log(' No tools available'); - } else { - for (const tool of toolsResult.tools) { - console.log(` - id: ${tool.name}, name: ${getDisplayName(tool)}, description: ${tool.description}`); - } + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const toolsRequest: ListToolsRequest = { + method: 'tools/list', + params: {} + }; + const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + + console.log('Available tools:'); + if (toolsResult.tools.length === 0) { + console.log(' No tools available'); + } else { + for (const tool of toolsResult.tools) { + console.log(` - id: ${tool.name}, name: ${getDisplayName(tool)}, description: ${tool.description}`); + } + } + } catch (error) { + console.log(`Tools not supported by this server (${error})`); } - } catch (error) { - console.log(`Tools not supported by this server (${error})`); - } } async function callTool(name: string, args: Record): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - const request: CallToolRequest = { - method: 'tools/call', - params: { - name, - arguments: args - } - }; - - console.log(`Calling tool '${name}' with args:`, args); - const result = await client.request(request, CallToolResultSchema); - - console.log('Tool result:'); - const resourceLinks: ResourceLink[] = []; - - result.content.forEach(item => { - if (item.type === 'text') { - console.log(` ${item.text}`); - } else if (item.type === 'resource_link') { - const resourceLink = item as ResourceLink; - resourceLinks.push(resourceLink); - console.log(` 📁 Resource Link: ${resourceLink.name}`); - console.log(` URI: ${resourceLink.uri}`); - if (resourceLink.mimeType) { - console.log(` Type: ${resourceLink.mimeType}`); - } - if (resourceLink.description) { - console.log(` Description: ${resourceLink.description}`); - } - } else if (item.type === 'resource') { - console.log(` [Embedded Resource: ${item.resource.uri}]`); - } else if (item.type === 'image') { - console.log(` [Image: ${item.mimeType}]`); - } else if (item.type === 'audio') { - console.log(` [Audio: ${item.mimeType}]`); - } else { - console.log(` [Unknown content type]:`, item); - } - }); + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: CallToolRequest = { + method: 'tools/call', + params: { + name, + arguments: args + } + }; + + console.log(`Calling tool '${name}' with args:`, args); + const result = await client.request(request, CallToolResultSchema); + + console.log('Tool result:'); + const resourceLinks: ResourceLink[] = []; + + result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } else if (item.type === 'resource_link') { + const resourceLink = item as ResourceLink; + resourceLinks.push(resourceLink); + console.log(` 📁 Resource Link: ${resourceLink.name}`); + console.log(` URI: ${resourceLink.uri}`); + if (resourceLink.mimeType) { + console.log(` Type: ${resourceLink.mimeType}`); + } + if (resourceLink.description) { + console.log(` Description: ${resourceLink.description}`); + } + } else if (item.type === 'resource') { + console.log(` [Embedded Resource: ${item.resource.uri}]`); + } else if (item.type === 'image') { + console.log(` [Image: ${item.mimeType}]`); + } else if (item.type === 'audio') { + console.log(` [Audio: ${item.mimeType}]`); + } else { + console.log(` [Unknown content type]:`, item); + } + }); - // Offer to read resource links - if (resourceLinks.length > 0) { - console.log(`\nFound ${resourceLinks.length} resource link(s). Use 'read-resource ' to read their content.`); + // Offer to read resource links + if (resourceLinks.length > 0) { + console.log(`\nFound ${resourceLinks.length} resource link(s). Use 'read-resource ' to read their content.`); + } + } catch (error) { + console.log(`Error calling tool ${name}: ${error}`); } - } catch (error) { - console.log(`Error calling tool ${name}: ${error}`); - } } - async function callGreetTool(name: string): Promise { - await callTool('greet', { name }); + await callTool('greet', { name }); } async function callMultiGreetTool(name: string): Promise { - console.log('Calling multi-greet tool with notifications...'); - await callTool('multi-greet', { name }); + console.log('Calling multi-greet tool with notifications...'); + await callTool('multi-greet', { name }); } async function callCollectInfoTool(infoType: string): Promise { - console.log(`Testing elicitation with collect-user-info tool (${infoType})...`); - await callTool('collect-user-info', { infoType }); + console.log(`Testing elicitation with collect-user-info tool (${infoType})...`); + await callTool('collect-user-info', { infoType }); } async function startNotifications(interval: number, count: number): Promise { - console.log(`Starting notification stream: interval=${interval}ms, count=${count || 'unlimited'}`); - await callTool('start-notification-stream', { interval, count }); + console.log(`Starting notification stream: interval=${interval}ms, count=${count || 'unlimited'}`); + await callTool('start-notification-stream', { interval, count }); } async function runNotificationsToolWithResumability(interval: number, count: number): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - console.log(`Starting notification stream with resumability: interval=${interval}ms, count=${count || 'unlimited'}`); - console.log(`Using resumption token: ${notificationsToolLastEventId || 'none'}`); - - const request: CallToolRequest = { - method: 'tools/call', - params: { - name: 'start-notification-stream', - arguments: { interval, count } - } - }; - - const onLastEventIdUpdate = (event: string) => { - notificationsToolLastEventId = event; - console.log(`Updated resumption token: ${event}`); - }; - - const result = await client.request(request, CallToolResultSchema, { - resumptionToken: notificationsToolLastEventId, - onresumptiontoken: onLastEventIdUpdate - }); + if (!client) { + console.log('Not connected to server.'); + return; + } - console.log('Tool result:'); - result.content.forEach(item => { - if (item.type === 'text') { - console.log(` ${item.text}`); - } else { - console.log(` ${item.type} content:`, item); - } - }); - } catch (error) { - console.log(`Error starting notification stream: ${error}`); - } + try { + console.log(`Starting notification stream with resumability: interval=${interval}ms, count=${count || 'unlimited'}`); + console.log(`Using resumption token: ${notificationsToolLastEventId || 'none'}`); + + const request: CallToolRequest = { + method: 'tools/call', + params: { + name: 'start-notification-stream', + arguments: { interval, count } + } + }; + + const onLastEventIdUpdate = (event: string) => { + notificationsToolLastEventId = event; + console.log(`Updated resumption token: ${event}`); + }; + + const result = await client.request(request, CallToolResultSchema, { + resumptionToken: notificationsToolLastEventId, + onresumptiontoken: onLastEventIdUpdate + }); + + console.log('Tool result:'); + result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } else { + console.log(` ${item.type} content:`, item); + } + }); + } catch (error) { + console.log(`Error starting notification stream: ${error}`); + } } async function listPrompts(): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - const promptsRequest: ListPromptsRequest = { - method: 'prompts/list', - params: {} - }; - const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); - console.log('Available prompts:'); - if (promptsResult.prompts.length === 0) { - console.log(' No prompts available'); - } else { - for (const prompt of promptsResult.prompts) { - console.log(` - id: ${prompt.name}, name: ${getDisplayName(prompt)}, description: ${prompt.description}`); - } + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const promptsRequest: ListPromptsRequest = { + method: 'prompts/list', + params: {} + }; + const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); + console.log('Available prompts:'); + if (promptsResult.prompts.length === 0) { + console.log(' No prompts available'); + } else { + for (const prompt of promptsResult.prompts) { + console.log(` - id: ${prompt.name}, name: ${getDisplayName(prompt)}, description: ${prompt.description}`); + } + } + } catch (error) { + console.log(`Prompts not supported by this server (${error})`); } - } catch (error) { - console.log(`Prompts not supported by this server (${error})`); - } } async function getPrompt(name: string, args: Record): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - const promptRequest: GetPromptRequest = { - method: 'prompts/get', - params: { - name, - arguments: args as Record - } - }; - - const promptResult = await client.request(promptRequest, GetPromptResultSchema); - console.log('Prompt template:'); - promptResult.messages.forEach((msg, index) => { - console.log(` [${index + 1}] ${msg.role}: ${msg.content.text}`); - }); - } catch (error) { - console.log(`Error getting prompt ${name}: ${error}`); - } + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const promptRequest: GetPromptRequest = { + method: 'prompts/get', + params: { + name, + arguments: args as Record + } + }; + + const promptResult = await client.request(promptRequest, GetPromptResultSchema); + console.log('Prompt template:'); + promptResult.messages.forEach((msg, index) => { + console.log(` [${index + 1}] ${msg.role}: ${msg.content.text}`); + }); + } catch (error) { + console.log(`Error getting prompt ${name}: ${error}`); + } } async function listResources(): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - const resourcesRequest: ListResourcesRequest = { - method: 'resources/list', - params: {} - }; - const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); - - console.log('Available resources:'); - if (resourcesResult.resources.length === 0) { - console.log(' No resources available'); - } else { - for (const resource of resourcesResult.resources) { - console.log(` - id: ${resource.name}, name: ${getDisplayName(resource)}, description: ${resource.uri}`); - } + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const resourcesRequest: ListResourcesRequest = { + method: 'resources/list', + params: {} + }; + const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); + + console.log('Available resources:'); + if (resourcesResult.resources.length === 0) { + console.log(' No resources available'); + } else { + for (const resource of resourcesResult.resources) { + console.log(` - id: ${resource.name}, name: ${getDisplayName(resource)}, description: ${resource.uri}`); + } + } + } catch (error) { + console.log(`Resources not supported by this server (${error})`); } - } catch (error) { - console.log(`Resources not supported by this server (${error})`); - } } async function readResource(uri: string): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - try { - const request: ReadResourceRequest = { - method: 'resources/read', - params: { uri } - }; - - console.log(`Reading resource: ${uri}`); - const result = await client.request(request, ReadResourceResultSchema); - - console.log('Resource contents:'); - for (const content of result.contents) { - console.log(` URI: ${content.uri}`); - if (content.mimeType) { - console.log(` Type: ${content.mimeType}`); - } - - if ('text' in content && typeof content.text === 'string') { - console.log(' Content:'); - console.log(' ---'); - console.log(content.text.split('\n').map((line: string) => ' ' + line).join('\n')); - console.log(' ---'); - } else if ('blob' in content && typeof content.blob === 'string') { - console.log(` [Binary data: ${content.blob.length} bytes]`); - } + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const request: ReadResourceRequest = { + method: 'resources/read', + params: { uri } + }; + + console.log(`Reading resource: ${uri}`); + const result = await client.request(request, ReadResourceResultSchema); + + console.log('Resource contents:'); + for (const content of result.contents) { + console.log(` URI: ${content.uri}`); + if (content.mimeType) { + console.log(` Type: ${content.mimeType}`); + } + + if ('text' in content && typeof content.text === 'string') { + console.log(' Content:'); + console.log(' ---'); + console.log( + content.text + .split('\n') + .map((line: string) => ' ' + line) + .join('\n') + ); + console.log(' ---'); + } else if ('blob' in content && typeof content.blob === 'string') { + console.log(` [Binary data: ${content.blob.length} bytes]`); + } + } + } catch (error) { + console.log(`Error reading resource ${uri}: ${error}`); } - } catch (error) { - console.log(`Error reading resource ${uri}: ${error}`); - } } async function cleanup(): Promise { - if (client && transport) { - try { - // First try to terminate the session gracefully - if (transport.sessionId) { + if (client && transport) { try { - console.log('Terminating session before exit...'); - await transport.terminateSession(); - console.log('Session terminated successfully'); + // First try to terminate the session gracefully + if (transport.sessionId) { + try { + console.log('Terminating session before exit...'); + await transport.terminateSession(); + console.log('Session terminated successfully'); + } catch (error) { + console.error('Error terminating session:', error); + } + } + + // Then close the transport + await transport.close(); } catch (error) { - console.error('Error terminating session:', error); + console.error('Error closing transport:', error); } - } - - // Then close the transport - await transport.close(); - } catch (error) { - console.error('Error closing transport:', error); } - } - process.stdin.setRawMode(false); - readline.close(); - console.log('\nGoodbye!'); - process.exit(0); + process.stdin.setRawMode(false); + readline.close(); + console.log('\nGoodbye!'); + process.exit(0); } // Set up raw mode for keyboard input to capture Escape key process.stdin.setRawMode(true); -process.stdin.on('data', async (data) => { - // Check for Escape key (27) - if (data.length === 1 && data[0] === 27) { - console.log('\nESC key pressed. Disconnecting from server...'); +process.stdin.on('data', async data => { + // Check for Escape key (27) + if (data.length === 1 && data[0] === 27) { + console.log('\nESC key pressed. Disconnecting from server...'); + + // Abort current operation and disconnect from server + if (client && transport) { + await disconnect(); + console.log('Disconnected. Press Enter to continue.'); + } else { + console.log('Not connected to server.'); + } - // Abort current operation and disconnect from server - if (client && transport) { - await disconnect(); - console.log('Disconnected. Press Enter to continue.'); - } else { - console.log('Not connected to server.'); + // Re-display the prompt + process.stdout.write('> '); } - - // Re-display the prompt - process.stdout.write('> '); - } }); // Handle Ctrl+C process.on('SIGINT', async () => { - console.log('\nReceived SIGINT. Cleaning up...'); - await cleanup(); + console.log('\nReceived SIGINT. Cleaning up...'); + await cleanup(); }); // Start the interactive client main().catch((error: unknown) => { - console.error('Error running MCP client:', error); - process.exit(1); -}); \ No newline at end of file + console.error('Error running MCP client:', error); + process.exit(1); +}); diff --git a/src/examples/client/streamableHttpWithSseFallbackClient.ts b/src/examples/client/streamableHttpWithSseFallbackClient.ts index 7646f0f78..657f48953 100644 --- a/src/examples/client/streamableHttpWithSseFallbackClient.ts +++ b/src/examples/client/streamableHttpWithSseFallbackClient.ts @@ -2,20 +2,20 @@ import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; import { SSEClientTransport } from '../../client/sse.js'; import { - ListToolsRequest, - ListToolsResultSchema, - CallToolRequest, - CallToolResultSchema, - LoggingMessageNotificationSchema, + ListToolsRequest, + ListToolsResultSchema, + CallToolRequest, + CallToolResultSchema, + LoggingMessageNotificationSchema } from '../../types.js'; /** * Simplified Backwards Compatible MCP Client - * + * * This client demonstrates backward compatibility with both: * 1. Modern servers using Streamable HTTP transport (protocol version 2025-03-26) * 2. Older servers using HTTP+SSE transport (protocol version 2024-11-05) - * + * * Following the MCP specification for backwards compatibility: * - Attempts to POST an initialize request to the server URL first (modern transport) * - If that fails with 4xx status, falls back to GET request for SSE stream (older transport) @@ -26,46 +26,45 @@ const args = process.argv.slice(2); const serverUrl = args[0] || 'http://localhost:3000/mcp'; async function main(): Promise { - console.log('MCP Backwards Compatible Client'); - console.log('==============================='); - console.log(`Connecting to server at: ${serverUrl}`); - - let client: Client; - let transport: StreamableHTTPClientTransport | SSEClientTransport; - - try { - // Try connecting with automatic transport detection - const connection = await connectWithBackwardsCompatibility(serverUrl); - client = connection.client; - transport = connection.transport; - - // Set up notification handler - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - console.log(`Notification: ${notification.params.level} - ${notification.params.data}`); - }); - - // DEMO WORKFLOW: - // 1. List available tools - console.log('\n=== Listing Available Tools ==='); - await listTools(client); - - // 2. Call the notification tool - console.log('\n=== Starting Notification Stream ==='); - await startNotificationTool(client); - - // 3. Wait for all notifications (5 seconds) - console.log('\n=== Waiting for all notifications ==='); - await new Promise(resolve => setTimeout(resolve, 5000)); + console.log('MCP Backwards Compatible Client'); + console.log('==============================='); + console.log(`Connecting to server at: ${serverUrl}`); - // 4. Disconnect - console.log('\n=== Disconnecting ==='); - await transport.close(); - console.log('Disconnected from MCP server'); + let client: Client; + let transport: StreamableHTTPClientTransport | SSEClientTransport; - } catch (error) { - console.error('Error running client:', error); - process.exit(1); - } + try { + // Try connecting with automatic transport detection + const connection = await connectWithBackwardsCompatibility(serverUrl); + client = connection.client; + transport = connection.transport; + + // Set up notification handler + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + console.log(`Notification: ${notification.params.level} - ${notification.params.data}`); + }); + + // DEMO WORKFLOW: + // 1. List available tools + console.log('\n=== Listing Available Tools ==='); + await listTools(client); + + // 2. Call the notification tool + console.log('\n=== Starting Notification Stream ==='); + await startNotificationTool(client); + + // 3. Wait for all notifications (5 seconds) + console.log('\n=== Waiting for all notifications ==='); + await new Promise(resolve => setTimeout(resolve, 5000)); + + // 4. Disconnect + console.log('\n=== Disconnecting ==='); + await transport.close(); + console.log('Disconnected from MCP server'); + } catch (error) { + console.error('Error running client:', error); + process.exit(1); + } } /** @@ -73,120 +72,120 @@ async function main(): Promise { * Following the spec for client backward compatibility */ async function connectWithBackwardsCompatibility(url: string): Promise<{ - client: Client, - transport: StreamableHTTPClientTransport | SSEClientTransport, - transportType: 'streamable-http' | 'sse' + client: Client; + transport: StreamableHTTPClientTransport | SSEClientTransport; + transportType: 'streamable-http' | 'sse'; }> { - console.log('1. Trying Streamable HTTP transport first...'); - - // Step 1: Try Streamable HTTP transport first - const client = new Client({ - name: 'backwards-compatible-client', - version: '1.0.0' - }); - - client.onerror = (error) => { - console.error('Client error:', error); - }; - const baseUrl = new URL(url); - - try { - // Create modern transport - const streamableTransport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(streamableTransport); - - console.log('Successfully connected using modern Streamable HTTP transport.'); - return { - client, - transport: streamableTransport, - transportType: 'streamable-http' - }; - } catch (error) { - // Step 2: If transport fails, try the older SSE transport - console.log(`StreamableHttp transport connection failed: ${error}`); - console.log('2. Falling back to deprecated HTTP+SSE transport...'); + console.log('1. Trying Streamable HTTP transport first...'); - try { - // Create SSE transport pointing to /sse endpoint - const sseTransport = new SSEClientTransport(baseUrl); - const sseClient = new Client({ + // Step 1: Try Streamable HTTP transport first + const client = new Client({ name: 'backwards-compatible-client', version: '1.0.0' - }); - await sseClient.connect(sseTransport); - - console.log('Successfully connected using deprecated HTTP+SSE transport.'); - return { - client: sseClient, - transport: sseTransport, - transportType: 'sse' - }; - } catch (sseError) { - console.error(`Failed to connect with either transport method:\n1. Streamable HTTP error: ${error}\n2. SSE error: ${sseError}`); - throw new Error('Could not connect to server with any available transport'); + }); + + client.onerror = error => { + console.error('Client error:', error); + }; + const baseUrl = new URL(url); + + try { + // Create modern transport + const streamableTransport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(streamableTransport); + + console.log('Successfully connected using modern Streamable HTTP transport.'); + return { + client, + transport: streamableTransport, + transportType: 'streamable-http' + }; + } catch (error) { + // Step 2: If transport fails, try the older SSE transport + console.log(`StreamableHttp transport connection failed: ${error}`); + console.log('2. Falling back to deprecated HTTP+SSE transport...'); + + try { + // Create SSE transport pointing to /sse endpoint + const sseTransport = new SSEClientTransport(baseUrl); + const sseClient = new Client({ + name: 'backwards-compatible-client', + version: '1.0.0' + }); + await sseClient.connect(sseTransport); + + console.log('Successfully connected using deprecated HTTP+SSE transport.'); + return { + client: sseClient, + transport: sseTransport, + transportType: 'sse' + }; + } catch (sseError) { + console.error(`Failed to connect with either transport method:\n1. Streamable HTTP error: ${error}\n2. SSE error: ${sseError}`); + throw new Error('Could not connect to server with any available transport'); + } } - } } /** * List available tools on the server */ async function listTools(client: Client): Promise { - try { - const toolsRequest: ListToolsRequest = { - method: 'tools/list', - params: {} - }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); - - console.log('Available tools:'); - if (toolsResult.tools.length === 0) { - console.log(' No tools available'); - } else { - for (const tool of toolsResult.tools) { - console.log(` - ${tool.name}: ${tool.description}`); - } + try { + const toolsRequest: ListToolsRequest = { + method: 'tools/list', + params: {} + }; + const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + + console.log('Available tools:'); + if (toolsResult.tools.length === 0) { + console.log(' No tools available'); + } else { + for (const tool of toolsResult.tools) { + console.log(` - ${tool.name}: ${tool.description}`); + } + } + } catch (error) { + console.log(`Tools not supported by this server: ${error}`); } - } catch (error) { - console.log(`Tools not supported by this server: ${error}`); - } } /** * Start a notification stream by calling the notification tool */ async function startNotificationTool(client: Client): Promise { - try { - // Call the notification tool using reasonable defaults - const request: CallToolRequest = { - method: 'tools/call', - params: { - name: 'start-notification-stream', - arguments: { - interval: 1000, // 1 second between notifications - count: 5 // Send 5 notifications - } - } - }; - - console.log('Calling notification tool...'); - const result = await client.request(request, CallToolResultSchema); - - console.log('Tool result:'); - result.content.forEach(item => { - if (item.type === 'text') { - console.log(` ${item.text}`); - } else { - console.log(` ${item.type} content:`, item); - } - }); - } catch (error) { - console.log(`Error calling notification tool: ${error}`); - } + try { + // Call the notification tool using reasonable defaults + const request: CallToolRequest = { + method: 'tools/call', + params: { + name: 'start-notification-stream', + arguments: { + interval: 1000, // 1 second between notifications + count: 5 // Send 5 notifications + } + } + }; + + console.log('Calling notification tool...'); + const result = await client.request(request, CallToolResultSchema); + + console.log('Tool result:'); + result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } else { + console.log(` ${item.type} content:`, item); + } + }); + } catch (error) { + console.log(`Error calling notification tool: ${error}`); + } } // Start the client main().catch((error: unknown) => { - console.error('Error running MCP client:', error); - process.exit(1); -}); \ No newline at end of file + console.error('Error running MCP client:', error); + process.exit(1); +}); diff --git a/src/examples/server/demoInMemoryOAuthProvider.ts b/src/examples/server/demoInMemoryOAuthProvider.ts index c83748d35..780770bad 100644 --- a/src/examples/server/demoInMemoryOAuthProvider.ts +++ b/src/examples/server/demoInMemoryOAuthProvider.ts @@ -2,23 +2,22 @@ import { randomUUID } from 'node:crypto'; import { AuthorizationParams, OAuthServerProvider } from '../../server/auth/provider.js'; import { OAuthRegisteredClientsStore } from '../../server/auth/clients.js'; import { OAuthClientInformationFull, OAuthMetadata, OAuthTokens } from '../../shared/auth.js'; -import express, { Request, Response } from "express"; +import express, { Request, Response } from 'express'; import { AuthInfo } from '../../server/auth/types.js'; import { createOAuthMetadata, mcpAuthRouter } from '../../server/auth/router.js'; import { resourceUrlFromServerUrl } from '../../shared/auth-utils.js'; - export class DemoInMemoryClientsStore implements OAuthRegisteredClientsStore { - private clients = new Map(); + private clients = new Map(); - async getClient(clientId: string) { - return this.clients.get(clientId); - } + async getClient(clientId: string) { + return this.clients.get(clientId); + } - async registerClient(clientMetadata: OAuthClientInformationFull) { - this.clients.set(clientMetadata.client_id, clientMetadata); - return clientMetadata; - } + async registerClient(clientMetadata: OAuthClientInformationFull) { + this.clients.set(clientMetadata.client_id, clientMetadata); + return clientMetadata; + } } /** @@ -30,193 +29,200 @@ export class DemoInMemoryClientsStore implements OAuthRegisteredClientsStore { * - Rate limiting */ export class DemoInMemoryAuthProvider implements OAuthServerProvider { - clientsStore = new DemoInMemoryClientsStore(); - private codes = new Map(); - private tokens = new Map(); - - constructor(private validateResource?: (resource?: URL) => boolean) {} - - async authorize( - client: OAuthClientInformationFull, - params: AuthorizationParams, - res: Response - ): Promise { - const code = randomUUID(); - - const searchParams = new URLSearchParams({ - code, - }); - if (params.state !== undefined) { - searchParams.set('state', params.state); + clientsStore = new DemoInMemoryClientsStore(); + private codes = new Map< + string, + { + params: AuthorizationParams; + client: OAuthClientInformationFull; + } + >(); + private tokens = new Map(); + + constructor(private validateResource?: (resource?: URL) => boolean) {} + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + const code = randomUUID(); + + const searchParams = new URLSearchParams({ + code + }); + if (params.state !== undefined) { + searchParams.set('state', params.state); + } + + this.codes.set(code, { + client, + params + }); + + const targetUrl = new URL(client.redirect_uris[0]); + targetUrl.search = searchParams.toString(); + res.redirect(targetUrl.toString()); } - this.codes.set(code, { - client, - params - }); - - const targetUrl = new URL(client.redirect_uris[0]); - targetUrl.search = searchParams.toString(); - res.redirect(targetUrl.toString()); - } - - async challengeForAuthorizationCode( - client: OAuthClientInformationFull, - authorizationCode: string - ): Promise { - - // Store the challenge with the code data - const codeData = this.codes.get(authorizationCode); - if (!codeData) { - throw new Error('Invalid authorization code'); - } + async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { + // Store the challenge with the code data + const codeData = this.codes.get(authorizationCode); + if (!codeData) { + throw new Error('Invalid authorization code'); + } - return codeData.params.codeChallenge; - } - - async exchangeAuthorizationCode( - client: OAuthClientInformationFull, - authorizationCode: string, - // Note: code verifier is checked in token.ts by default - // it's unused here for that reason. - _codeVerifier?: string - ): Promise { - const codeData = this.codes.get(authorizationCode); - if (!codeData) { - throw new Error('Invalid authorization code'); + return codeData.params.codeChallenge; } - if (codeData.client.client_id !== client.client_id) { - throw new Error(`Authorization code was not issued to this client, ${codeData.client.client_id} != ${client.client_id}`); + async exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + // Note: code verifier is checked in token.ts by default + // it's unused here for that reason. + _codeVerifier?: string + ): Promise { + const codeData = this.codes.get(authorizationCode); + if (!codeData) { + throw new Error('Invalid authorization code'); + } + + if (codeData.client.client_id !== client.client_id) { + throw new Error(`Authorization code was not issued to this client, ${codeData.client.client_id} != ${client.client_id}`); + } + + if (this.validateResource && !this.validateResource(codeData.params.resource)) { + throw new Error(`Invalid resource: ${codeData.params.resource}`); + } + + this.codes.delete(authorizationCode); + const token = randomUUID(); + + const tokenData = { + token, + clientId: client.client_id, + scopes: codeData.params.scopes || [], + expiresAt: Date.now() + 3600000, // 1 hour + resource: codeData.params.resource, + type: 'access' + }; + + this.tokens.set(token, tokenData); + + return { + access_token: token, + token_type: 'bearer', + expires_in: 3600, + scope: (codeData.params.scopes || []).join(' ') + }; } - if (this.validateResource && !this.validateResource(codeData.params.resource)) { - throw new Error(`Invalid resource: ${codeData.params.resource}`); + async exchangeRefreshToken( + _client: OAuthClientInformationFull, + _refreshToken: string, + _scopes?: string[], + _resource?: URL + ): Promise { + throw new Error('Not implemented for example demo'); } - this.codes.delete(authorizationCode); - const token = randomUUID(); - - const tokenData = { - token, - clientId: client.client_id, - scopes: codeData.params.scopes || [], - expiresAt: Date.now() + 3600000, // 1 hour - resource: codeData.params.resource, - type: 'access', - }; - - this.tokens.set(token, tokenData); - - return { - access_token: token, - token_type: 'bearer', - expires_in: 3600, - scope: (codeData.params.scopes || []).join(' '), - }; - } - - async exchangeRefreshToken( - _client: OAuthClientInformationFull, - _refreshToken: string, - _scopes?: string[], - _resource?: URL - ): Promise { - throw new Error('Not implemented for example demo'); - } - - async verifyAccessToken(token: string): Promise { - const tokenData = this.tokens.get(token); - if (!tokenData || !tokenData.expiresAt || tokenData.expiresAt < Date.now()) { - throw new Error('Invalid or expired token'); + async verifyAccessToken(token: string): Promise { + const tokenData = this.tokens.get(token); + if (!tokenData || !tokenData.expiresAt || tokenData.expiresAt < Date.now()) { + throw new Error('Invalid or expired token'); + } + + return { + token, + clientId: tokenData.clientId, + scopes: tokenData.scopes, + expiresAt: Math.floor(tokenData.expiresAt / 1000), + resource: tokenData.resource + }; } - - return { - token, - clientId: tokenData.clientId, - scopes: tokenData.scopes, - expiresAt: Math.floor(tokenData.expiresAt / 1000), - resource: tokenData.resource, - }; - } } +export const setupAuthServer = ({ + authServerUrl, + mcpServerUrl, + strictResource +}: { + authServerUrl: URL; + mcpServerUrl: URL; + strictResource: boolean; +}): OAuthMetadata => { + // Create separate auth server app + // NOTE: This is a separate app on a separate port to illustrate + // how to separate an OAuth Authorization Server from a Resource + // server in the SDK. The SDK is not intended to be provide a standalone + // authorization server. + + const validateResource = strictResource + ? (resource?: URL) => { + if (!resource) return false; + const expectedResource = resourceUrlFromServerUrl(mcpServerUrl); + return resource.toString() === expectedResource.toString(); + } + : undefined; + + const provider = new DemoInMemoryAuthProvider(validateResource); + const authApp = express(); + authApp.use(express.json()); + // For introspection requests + authApp.use(express.urlencoded()); + + // Add OAuth routes to the auth server + // NOTE: this will also add a protected resource metadata route, + // but it won't be used, so leave it. + authApp.use( + mcpAuthRouter({ + provider, + issuerUrl: authServerUrl, + scopesSupported: ['mcp:tools'] + }) + ); + + authApp.post('/introspect', async (req: Request, res: Response) => { + try { + const { token } = req.body; + if (!token) { + res.status(400).json({ error: 'Token is required' }); + return; + } + + const tokenInfo = await provider.verifyAccessToken(token); + res.json({ + active: true, + client_id: tokenInfo.clientId, + scope: tokenInfo.scopes.join(' '), + exp: tokenInfo.expiresAt, + aud: tokenInfo.resource + }); + return; + } catch (error) { + res.status(401).json({ + active: false, + error: 'Unauthorized', + error_description: `Invalid token: ${error}` + }); + } + }); -export const setupAuthServer = ({authServerUrl, mcpServerUrl, strictResource}: {authServerUrl: URL, mcpServerUrl: URL, strictResource: boolean}): OAuthMetadata => { - // Create separate auth server app - // NOTE: This is a separate app on a separate port to illustrate - // how to separate an OAuth Authorization Server from a Resource - // server in the SDK. The SDK is not intended to be provide a standalone - // authorization server. - - const validateResource = strictResource ? (resource?: URL) => { - if (!resource) return false; - const expectedResource = resourceUrlFromServerUrl(mcpServerUrl); - return resource.toString() === expectedResource.toString(); - } : undefined; - - const provider = new DemoInMemoryAuthProvider(validateResource); - const authApp = express(); - authApp.use(express.json()); - // For introspection requests - authApp.use(express.urlencoded()); - - // Add OAuth routes to the auth server - // NOTE: this will also add a protected resource metadata route, - // but it won't be used, so leave it. - authApp.use(mcpAuthRouter({ - provider, - issuerUrl: authServerUrl, - scopesSupported: ['mcp:tools'], - })); - - authApp.post('/introspect', async (req: Request, res: Response) => { - try { - const { token } = req.body; - if (!token) { - res.status(400).json({ error: 'Token is required' }); - return; - } - - const tokenInfo = await provider.verifyAccessToken(token); - res.json({ - active: true, - client_id: tokenInfo.clientId, - scope: tokenInfo.scopes.join(' '), - exp: tokenInfo.expiresAt, - aud: tokenInfo.resource, - }); - return - } catch (error) { - res.status(401).json({ - active: false, - error: 'Unauthorized', - error_description: `Invalid token: ${error}` - }); - } - }); - - const auth_port = authServerUrl.port; - // Start the auth server - authApp.listen(auth_port, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`OAuth Authorization Server listening on port ${auth_port}`); - }); + const auth_port = authServerUrl.port; + // Start the auth server + authApp.listen(auth_port, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`OAuth Authorization Server listening on port ${auth_port}`); + }); - // Note: we could fetch this from the server, but then we end up - // with some top level async which gets annoying. - const oauthMetadata: OAuthMetadata = createOAuthMetadata({ - provider, - issuerUrl: authServerUrl, - scopesSupported: ['mcp:tools'], - }) + // Note: we could fetch this from the server, but then we end up + // with some top level async which gets annoying. + const oauthMetadata: OAuthMetadata = createOAuthMetadata({ + provider, + issuerUrl: authServerUrl, + scopesSupported: ['mcp:tools'] + }); - oauthMetadata.introspection_endpoint = new URL("/introspect", authServerUrl).href; + oauthMetadata.introspection_endpoint = new URL('/introspect', authServerUrl).href; - return oauthMetadata; -} + return oauthMetadata; +}; diff --git a/src/examples/server/jsonResponseStreamableHttp.ts b/src/examples/server/jsonResponseStreamableHttp.ts index bc740c5fa..8b640777d 100644 --- a/src/examples/server/jsonResponseStreamableHttp.ts +++ b/src/examples/server/jsonResponseStreamableHttp.ts @@ -6,168 +6,181 @@ import { z } from 'zod'; import { CallToolResult, isInitializeRequest } from '../../types.js'; import cors from 'cors'; - // Create an MCP server with implementation details const getServer = () => { - const server = new McpServer({ - name: 'json-response-streamable-http-server', - version: '1.0.0', - }, { - capabilities: { - logging: {}, - } - }); - - // Register a simple tool that returns a greeting - server.tool( - 'greet', - 'A simple greeting tool', - { - name: z.string().describe('Name to greet'), - }, - async ({ name }): Promise => { - return { - content: [ - { - type: 'text', - text: `Hello, ${name}!`, - }, - ], - }; - } - ); - - // Register a tool that sends multiple greetings with notifications - server.tool( - 'multi-greet', - 'A tool that sends different greetings with delays between them', - { - name: z.string().describe('Name to greet'), - }, - async ({ name }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - - await server.sendLoggingMessage({ - level: "debug", - data: `Starting multi-greet for ${name}` - }, extra.sessionId); - - await sleep(1000); // Wait 1 second before first greeting - - await server.sendLoggingMessage({ - level: "info", - data: `Sending first greeting to ${name}` - }, extra.sessionId); - - await sleep(1000); // Wait another second before second greeting - - await server.sendLoggingMessage({ - level: "info", - data: `Sending second greeting to ${name}` - }, extra.sessionId); - - return { - content: [ - { - type: 'text', - text: `Good morning, ${name}!`, - } - ], - }; - } - ); - return server; -} + const server = new McpServer( + { + name: 'json-response-streamable-http-server', + version: '1.0.0' + }, + { + capabilities: { + logging: {} + } + } + ); + + // Register a simple tool that returns a greeting + server.tool( + 'greet', + 'A simple greeting tool', + { + name: z.string().describe('Name to greet') + }, + async ({ name }): Promise => { + return { + content: [ + { + type: 'text', + text: `Hello, ${name}!` + } + ] + }; + } + ); + + // Register a tool that sends multiple greetings with notifications + server.tool( + 'multi-greet', + 'A tool that sends different greetings with delays between them', + { + name: z.string().describe('Name to greet') + }, + async ({ name }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + await server.sendLoggingMessage( + { + level: 'debug', + data: `Starting multi-greet for ${name}` + }, + extra.sessionId + ); + + await sleep(1000); // Wait 1 second before first greeting + + await server.sendLoggingMessage( + { + level: 'info', + data: `Sending first greeting to ${name}` + }, + extra.sessionId + ); + + await sleep(1000); // Wait another second before second greeting + + await server.sendLoggingMessage( + { + level: 'info', + data: `Sending second greeting to ${name}` + }, + extra.sessionId + ); + + return { + content: [ + { + type: 'text', + text: `Good morning, ${name}!` + } + ] + }; + } + ); + return server; +}; const app = express(); app.use(express.json()); // Configure CORS to expose Mcp-Session-Id header for browser-based clients -app.use(cors({ - origin: '*', // Allow all origins - adjust as needed for production - exposedHeaders: ['Mcp-Session-Id'] -})); +app.use( + cors({ + origin: '*', // Allow all origins - adjust as needed for production + exposedHeaders: ['Mcp-Session-Id'] + }) +); // Map to store transports by session ID const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; app.post('/mcp', async (req: Request, res: Response) => { - console.log('Received MCP request:', req.body); - try { - // Check for existing session ID - const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; - - if (sessionId && transports[sessionId]) { - // Reuse existing transport - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { - // New initialization request - use JSON response mode - transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - enableJsonResponse: true, // Enable JSON response mode - onsessioninitialized: (sessionId) => { - // Store the transport by session ID when session is initialized - // This avoids race conditions where requests might come in before the session is stored - console.log(`Session initialized with ID: ${sessionId}`); - transports[sessionId] = transport; + console.log('Received MCP request:', req.body); + try { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined; + let transport: StreamableHTTPServerTransport; + + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request - use JSON response mode + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + enableJsonResponse: true, // Enable JSON response mode + onsessioninitialized: sessionId => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } + }); + + // Connect the transport to the MCP server BEFORE handling the request + const server = getServer(); + await server.connect(transport); + await transport.handleRequest(req, res, req.body); + return; // Already handled + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + return; } - }); - - // Connect the transport to the MCP server BEFORE handling the request - const server = getServer(); - await server.connect(transport); - await transport.handleRequest(req, res, req.body); - return; // Already handled - } else { - // Invalid request - no session ID or not initialization request - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: No valid session ID provided', - }, - id: null, - }); - return; - } - // Handle the request with existing transport - no need to reconnect - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); + // Handle the request with existing transport - no need to reconnect + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error' + }, + id: null + }); + } } - } }); // Handle GET requests for SSE streams according to spec app.get('/mcp', async (req: Request, res: Response) => { - // Since this is a very simple example, we don't support GET requests for this server - // The spec requires returning 405 Method Not Allowed in this case - res.status(405).set('Allow', 'POST').send('Method Not Allowed'); + // Since this is a very simple example, we don't support GET requests for this server + // The spec requires returning 405 Method Not Allowed in this case + res.status(405).set('Allow', 'POST').send('Method Not Allowed'); }); // Start the server const PORT = 3000; -app.listen(PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`MCP Streamable HTTP Server listening on port ${PORT}`); +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Streamable HTTP Server listening on port ${PORT}`); }); // Handle server shutdown process.on('SIGINT', async () => { - console.log('Shutting down server...'); - process.exit(0); + console.log('Shutting down server...'); + process.exit(0); }); diff --git a/src/examples/server/mcpServerOutputSchema.ts b/src/examples/server/mcpServerOutputSchema.ts index 75bfe6900..5d1cab0bd 100644 --- a/src/examples/server/mcpServerOutputSchema.ts +++ b/src/examples/server/mcpServerOutputSchema.ts @@ -4,77 +4,77 @@ * This demonstrates how to easily create tools with structured output */ -import { McpServer } from "../../server/mcp.js"; -import { StdioServerTransport } from "../../server/stdio.js"; -import { z } from "zod"; +import { McpServer } from '../../server/mcp.js'; +import { StdioServerTransport } from '../../server/stdio.js'; +import { z } from 'zod'; -const server = new McpServer( - { - name: "mcp-output-schema-high-level-example", - version: "1.0.0", - } -); +const server = new McpServer({ + name: 'mcp-output-schema-high-level-example', + version: '1.0.0' +}); // Define a tool with structured output - Weather data server.registerTool( - "get_weather", - { - description: "Get weather information for a city", - inputSchema: { - city: z.string().describe("City name"), - country: z.string().describe("Country code (e.g., US, UK)") - }, - outputSchema: { - temperature: z.object({ - celsius: z.number(), - fahrenheit: z.number() - }), - conditions: z.enum(["sunny", "cloudy", "rainy", "stormy", "snowy"]), - humidity: z.number().min(0).max(100), - wind: z.object({ - speed_kmh: z.number(), - direction: z.string() - }) + 'get_weather', + { + description: 'Get weather information for a city', + inputSchema: { + city: z.string().describe('City name'), + country: z.string().describe('Country code (e.g., US, UK)') + }, + outputSchema: { + temperature: z.object({ + celsius: z.number(), + fahrenheit: z.number() + }), + conditions: z.enum(['sunny', 'cloudy', 'rainy', 'stormy', 'snowy']), + humidity: z.number().min(0).max(100), + wind: z.object({ + speed_kmh: z.number(), + direction: z.string() + }) + } }, - }, - async ({ city, country }) => { - // Parameters are available but not used in this example - void city; - void country; - // Simulate weather API call - const temp_c = Math.round((Math.random() * 35 - 5) * 10) / 10; - const conditions = ["sunny", "cloudy", "rainy", "stormy", "snowy"][Math.floor(Math.random() * 5)]; + async ({ city, country }) => { + // Parameters are available but not used in this example + void city; + void country; + // Simulate weather API call + const temp_c = Math.round((Math.random() * 35 - 5) * 10) / 10; + const conditions = ['sunny', 'cloudy', 'rainy', 'stormy', 'snowy'][Math.floor(Math.random() * 5)]; - const structuredContent = { - temperature: { - celsius: temp_c, - fahrenheit: Math.round((temp_c * 9 / 5 + 32) * 10) / 10 - }, - conditions, - humidity: Math.round(Math.random() * 100), - wind: { - speed_kmh: Math.round(Math.random() * 50), - direction: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"][Math.floor(Math.random() * 8)] - } - }; + const structuredContent = { + temperature: { + celsius: temp_c, + fahrenheit: Math.round(((temp_c * 9) / 5 + 32) * 10) / 10 + }, + conditions, + humidity: Math.round(Math.random() * 100), + wind: { + speed_kmh: Math.round(Math.random() * 50), + direction: ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'][Math.floor(Math.random() * 8)] + } + }; - return { - content: [{ - type: "text", - text: JSON.stringify(structuredContent, null, 2) - }], - structuredContent - }; - } + return { + content: [ + { + type: 'text', + text: JSON.stringify(structuredContent, null, 2) + } + ], + structuredContent + }; + } ); async function main() { - const transport = new StdioServerTransport(); - await server.connect(transport); - console.error("High-level Output Schema Example Server running on stdio"); + const transport = new StdioServerTransport(); + await server.connect(transport); + console.error('High-level Output Schema Example Server running on stdio'); } -main().catch((error) => { - console.error("Server error:", error); - process.exit(1); -}); \ No newline at end of file +main().catch(error => { + console.error('Server error:', error); + process.exit(1); +}); diff --git a/src/examples/server/simpleSseServer.ts b/src/examples/server/simpleSseServer.ts index 664b15008..b99334369 100644 --- a/src/examples/server/simpleSseServer.ts +++ b/src/examples/server/simpleSseServer.ts @@ -16,55 +16,63 @@ import { CallToolResult } from '../../types.js'; // Create an MCP server instance const getServer = () => { - const server = new McpServer({ - name: 'simple-sse-server', - version: '1.0.0', - }, { capabilities: { logging: {} } }); - - server.tool( - 'start-notification-stream', - 'Starts sending periodic notifications', - { - interval: z.number().describe('Interval in milliseconds between notifications').default(1000), - count: z.number().describe('Number of notifications to send').default(10), - }, - async ({ interval, count }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - let counter = 0; - - // Send the initial notification - await server.sendLoggingMessage({ - level: "info", - data: `Starting notification stream with ${count} messages every ${interval}ms` - }, extra.sessionId); - - // Send periodic notifications - while (counter < count) { - counter++; - await sleep(interval); - - try { - await server.sendLoggingMessage({ - level: "info", - data: `Notification #${counter} at ${new Date().toISOString()}` - }, extra.sessionId); + const server = new McpServer( + { + name: 'simple-sse-server', + version: '1.0.0' + }, + { capabilities: { logging: {} } } + ); + + server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications', + { + interval: z.number().describe('Interval in milliseconds between notifications').default(1000), + count: z.number().describe('Number of notifications to send').default(10) + }, + async ({ interval, count }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; + + // Send the initial notification + await server.sendLoggingMessage( + { + level: 'info', + data: `Starting notification stream with ${count} messages every ${interval}ms` + }, + extra.sessionId + ); + + // Send periodic notifications + while (counter < count) { + counter++; + await sleep(interval); + + try { + await server.sendLoggingMessage( + { + level: 'info', + data: `Notification #${counter} at ${new Date().toISOString()}` + }, + extra.sessionId + ); + } catch (error) { + console.error('Error sending notification:', error); + } + } + + return { + content: [ + { + type: 'text', + text: `Completed sending ${count} notifications every ${interval}ms` + } + ] + }; } - catch (error) { - console.error("Error sending notification:", error); - } - } - - return { - content: [ - { - type: 'text', - text: `Completed sending ${count} notifications every ${interval}ms`, - } - ], - }; - } - ); - return server; + ); + return server; }; const app = express(); @@ -75,92 +83,92 @@ const transports: Record = {}; // SSE endpoint for establishing the stream app.get('/mcp', async (req: Request, res: Response) => { - console.log('Received GET request to /sse (establishing SSE stream)'); - - try { - // Create a new SSE transport for the client - // The endpoint for POST messages is '/messages' - const transport = new SSEServerTransport('/messages', res); - - // Store the transport by session ID - const sessionId = transport.sessionId; - transports[sessionId] = transport; - - // Set up onclose handler to clean up transport when closed - transport.onclose = () => { - console.log(`SSE transport closed for session ${sessionId}`); - delete transports[sessionId]; - }; - - // Connect the transport to the MCP server - const server = getServer(); - await server.connect(transport); - - console.log(`Established SSE stream with session ID: ${sessionId}`); - } catch (error) { - console.error('Error establishing SSE stream:', error); - if (!res.headersSent) { - res.status(500).send('Error establishing SSE stream'); + console.log('Received GET request to /sse (establishing SSE stream)'); + + try { + // Create a new SSE transport for the client + // The endpoint for POST messages is '/messages' + const transport = new SSEServerTransport('/messages', res); + + // Store the transport by session ID + const sessionId = transport.sessionId; + transports[sessionId] = transport; + + // Set up onclose handler to clean up transport when closed + transport.onclose = () => { + console.log(`SSE transport closed for session ${sessionId}`); + delete transports[sessionId]; + }; + + // Connect the transport to the MCP server + const server = getServer(); + await server.connect(transport); + + console.log(`Established SSE stream with session ID: ${sessionId}`); + } catch (error) { + console.error('Error establishing SSE stream:', error); + if (!res.headersSent) { + res.status(500).send('Error establishing SSE stream'); + } } - } }); // Messages endpoint for receiving client JSON-RPC requests app.post('/messages', async (req: Request, res: Response) => { - console.log('Received POST request to /messages'); - - // Extract session ID from URL query parameter - // In the SSE protocol, this is added by the client based on the endpoint event - const sessionId = req.query.sessionId as string | undefined; - - if (!sessionId) { - console.error('No session ID provided in request URL'); - res.status(400).send('Missing sessionId parameter'); - return; - } - - const transport = transports[sessionId]; - if (!transport) { - console.error(`No active transport found for session ID: ${sessionId}`); - res.status(404).send('Session not found'); - return; - } - - try { - // Handle the POST message with the transport - await transport.handlePostMessage(req, res, req.body); - } catch (error) { - console.error('Error handling request:', error); - if (!res.headersSent) { - res.status(500).send('Error handling request'); + console.log('Received POST request to /messages'); + + // Extract session ID from URL query parameter + // In the SSE protocol, this is added by the client based on the endpoint event + const sessionId = req.query.sessionId as string | undefined; + + if (!sessionId) { + console.error('No session ID provided in request URL'); + res.status(400).send('Missing sessionId parameter'); + return; + } + + const transport = transports[sessionId]; + if (!transport) { + console.error(`No active transport found for session ID: ${sessionId}`); + res.status(404).send('Session not found'); + return; + } + + try { + // Handle the POST message with the transport + await transport.handlePostMessage(req, res, req.body); + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) { + res.status(500).send('Error handling request'); + } } - } }); // Start the server const PORT = 3000; -app.listen(PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`Simple SSE Server (deprecated protocol version 2024-11-05) listening on port ${PORT}`); +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`Simple SSE Server (deprecated protocol version 2024-11-05) listening on port ${PORT}`); }); // Handle server shutdown process.on('SIGINT', async () => { - console.log('Shutting down server...'); + console.log('Shutting down server...'); - // Close all active transports to properly clean up resources - for (const sessionId in transports) { - try { - console.log(`Closing transport for session ${sessionId}`); - await transports[sessionId].close(); - delete transports[sessionId]; - } catch (error) { - console.error(`Error closing transport for session ${sessionId}:`, error); + // Close all active transports to properly clean up resources + for (const sessionId in transports) { + try { + console.log(`Closing transport for session ${sessionId}`); + await transports[sessionId].close(); + delete transports[sessionId]; + } catch (error) { + console.error(`Error closing transport for session ${sessionId}:`, error); + } } - } - console.log('Server shutdown complete'); - process.exit(0); + console.log('Server shutdown complete'); + process.exit(0); }); diff --git a/src/examples/server/simpleStatelessStreamableHttp.ts b/src/examples/server/simpleStatelessStreamableHttp.ts index d91f3a7b5..f71e5db6c 100644 --- a/src/examples/server/simpleStatelessStreamableHttp.ts +++ b/src/examples/server/simpleStatelessStreamableHttp.ts @@ -6,165 +6,175 @@ import { CallToolResult, GetPromptResult, ReadResourceResult } from '../../types import cors from 'cors'; const getServer = () => { - // Create an MCP server with implementation details - const server = new McpServer({ - name: 'stateless-streamable-http-server', - version: '1.0.0', - }, { capabilities: { logging: {} } }); + // Create an MCP server with implementation details + const server = new McpServer( + { + name: 'stateless-streamable-http-server', + version: '1.0.0' + }, + { capabilities: { logging: {} } } + ); - // Register a simple prompt - server.prompt( - 'greeting-template', - 'A simple greeting prompt template', - { - name: z.string().describe('Name to include in greeting'), - }, - async ({ name }): Promise => { - return { - messages: [ - { - role: 'user', - content: { - type: 'text', - text: `Please greet ${name} in a friendly manner.`, - }, - }, - ], - }; - } - ); + // Register a simple prompt + server.prompt( + 'greeting-template', + 'A simple greeting prompt template', + { + name: z.string().describe('Name to include in greeting') + }, + async ({ name }): Promise => { + return { + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please greet ${name} in a friendly manner.` + } + } + ] + }; + } + ); - // Register a tool specifically for testing resumability - server.tool( - 'start-notification-stream', - 'Starts sending periodic notifications for testing resumability', - { - interval: z.number().describe('Interval in milliseconds between notifications').default(100), - count: z.number().describe('Number of notifications to send (0 for 100)').default(10), - }, - async ({ interval, count }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - let counter = 0; + // Register a tool specifically for testing resumability + server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications for testing resumability', + { + interval: z.number().describe('Interval in milliseconds between notifications').default(100), + count: z.number().describe('Number of notifications to send (0 for 100)').default(10) + }, + async ({ interval, count }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; - while (count === 0 || counter < count) { - counter++; - try { - await server.sendLoggingMessage({ - level: "info", - data: `Periodic notification #${counter} at ${new Date().toISOString()}` - }, extra.sessionId); - } - catch (error) { - console.error("Error sending notification:", error); - } - // Wait for the specified interval - await sleep(interval); - } + while (count === 0 || counter < count) { + counter++; + try { + await server.sendLoggingMessage( + { + level: 'info', + data: `Periodic notification #${counter} at ${new Date().toISOString()}` + }, + extra.sessionId + ); + } catch (error) { + console.error('Error sending notification:', error); + } + // Wait for the specified interval + await sleep(interval); + } - return { - content: [ - { - type: 'text', - text: `Started sending periodic notifications every ${interval}ms`, - } - ], - }; - } - ); + return { + content: [ + { + type: 'text', + text: `Started sending periodic notifications every ${interval}ms` + } + ] + }; + } + ); - // Create a simple resource at a fixed URI - server.resource( - 'greeting-resource', - 'https://example.com/greetings/default', - { mimeType: 'text/plain' }, - async (): Promise => { - return { - contents: [ - { - uri: 'https://example.com/greetings/default', - text: 'Hello, world!', - }, - ], - }; - } - ); - return server; -} + // Create a simple resource at a fixed URI + server.resource( + 'greeting-resource', + 'https://example.com/greetings/default', + { mimeType: 'text/plain' }, + async (): Promise => { + return { + contents: [ + { + uri: 'https://example.com/greetings/default', + text: 'Hello, world!' + } + ] + }; + } + ); + return server; +}; const app = express(); app.use(express.json()); // Configure CORS to expose Mcp-Session-Id header for browser-based clients -app.use(cors({ - origin: '*', // Allow all origins - adjust as needed for production - exposedHeaders: ['Mcp-Session-Id'] -})); +app.use( + cors({ + origin: '*', // Allow all origins - adjust as needed for production + exposedHeaders: ['Mcp-Session-Id'] + }) +); app.post('/mcp', async (req: Request, res: Response) => { - const server = getServer(); - try { - const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); - await server.connect(transport); - await transport.handleRequest(req, res, req.body); - res.on('close', () => { - console.log('Request closed'); - transport.close(); - server.close(); - }); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); + const server = getServer(); + try { + const transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined + }); + await server.connect(transport); + await transport.handleRequest(req, res, req.body); + res.on('close', () => { + console.log('Request closed'); + transport.close(); + server.close(); + }); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error' + }, + id: null + }); + } } - } }); app.get('/mcp', async (req: Request, res: Response) => { - console.log('Received GET MCP request'); - res.writeHead(405).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Method not allowed." - }, - id: null - })); + console.log('Received GET MCP request'); + res.writeHead(405).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Method not allowed.' + }, + id: null + }) + ); }); app.delete('/mcp', async (req: Request, res: Response) => { - console.log('Received DELETE MCP request'); - res.writeHead(405).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Method not allowed." - }, - id: null - })); + console.log('Received DELETE MCP request'); + res.writeHead(405).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Method not allowed.' + }, + id: null + }) + ); }); - // Start the server const PORT = 3000; -app.listen(PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); }); // Handle server shutdown process.on('SIGINT', async () => { - console.log('Shutting down server...'); - process.exit(0); + console.log('Shutting down server...'); + process.exit(0); }); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 6f1e20080..c3cef9191 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -5,7 +5,14 @@ import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../../server/auth/router.js'; import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; -import { CallToolResult, GetPromptResult, isInitializeRequest, PrimitiveSchemaDefinition, ReadResourceResult, ResourceLink } from '../../types.js'; +import { + CallToolResult, + GetPromptResult, + isInitializeRequest, + PrimitiveSchemaDefinition, + ReadResourceResult, + ResourceLink +} from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from 'src/shared/auth.js'; @@ -19,406 +26,420 @@ const strictOAuth = process.argv.includes('--oauth-strict'); // Create an MCP server with implementation details const getServer = () => { - const server = new McpServer({ - name: 'simple-streamable-http-server', - version: '1.0.0', - icons: [{src: './mcp.svg', sizes: '512x512', mimeType: 'image/svg+xml'}], - websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk', - }, { capabilities: { logging: {} } }); - - // Register a simple tool that returns a greeting - server.registerTool( - 'greet', - { - title: 'Greeting Tool', // Display name for UI - description: 'A simple greeting tool', - inputSchema: { - name: z.string().describe('Name to greet'), - }, - }, - async ({ name }): Promise => { - return { - content: [ - { - type: 'text', - text: `Hello, ${name}!`, - }, - ], - }; - } - ); - - // Register a tool that sends multiple greetings with notifications (with annotations) - server.tool( - 'multi-greet', - 'A tool that sends different greetings with delays between them', - { - name: z.string().describe('Name to greet'), - }, - { - title: 'Multiple Greeting Tool', - readOnlyHint: true, - openWorldHint: false - }, - async ({ name }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - - await server.sendLoggingMessage({ - level: "debug", - data: `Starting multi-greet for ${name}` - }, extra.sessionId); - - await sleep(1000); // Wait 1 second before first greeting - - await server.sendLoggingMessage({ - level: "info", - data: `Sending first greeting to ${name}` - }, extra.sessionId); - - await sleep(1000); // Wait another second before second greeting - - await server.sendLoggingMessage({ - level: "info", - data: `Sending second greeting to ${name}` - }, extra.sessionId); - - return { - content: [ - { - type: 'text', - text: `Good morning, ${name}!`, - } - ], - }; - } - ); - // Register a tool that demonstrates elicitation (user input collection) - // This creates a closure that captures the server instance - server.tool( - 'collect-user-info', - 'A tool that collects user information through elicitation', - { - infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect'), - }, - async ({ infoType }): Promise => { - let message: string; - let requestedSchema: { - type: 'object'; - properties: Record; - required?: string[]; - }; - - switch (infoType) { - case 'contact': - message = 'Please provide your contact information'; - requestedSchema = { - type: 'object', - properties: { - name: { - type: 'string', - title: 'Full Name', - description: 'Your full name', - }, - email: { - type: 'string', - title: 'Email Address', - description: 'Your email address', - format: 'email', - }, - phone: { - type: 'string', - title: 'Phone Number', - description: 'Your phone number (optional)', - }, - }, - required: ['name', 'email'], - }; - break; - case 'preferences': - message = 'Please set your preferences'; - requestedSchema = { - type: 'object', - properties: { - theme: { - type: 'string', - title: 'Theme', - description: 'Choose your preferred theme', - enum: ['light', 'dark', 'auto'], - enumNames: ['Light', 'Dark', 'Auto'], - }, - notifications: { - type: 'boolean', - title: 'Enable Notifications', - description: 'Would you like to receive notifications?', - default: true, - }, - frequency: { - type: 'string', - title: 'Notification Frequency', - description: 'How often would you like notifications?', - enum: ['daily', 'weekly', 'monthly'], - enumNames: ['Daily', 'Weekly', 'Monthly'], - }, - }, - required: ['theme'], - }; - break; - case 'feedback': - message = 'Please provide your feedback'; - requestedSchema = { - type: 'object', - properties: { - rating: { - type: 'integer', - title: 'Rating', - description: 'Rate your experience (1-5)', - minimum: 1, - maximum: 5, - }, - comments: { - type: 'string', - title: 'Comments', - description: 'Additional comments (optional)', - maxLength: 500, - }, - recommend: { - type: 'boolean', - title: 'Would you recommend this?', - description: 'Would you recommend this to others?', - }, - }, - required: ['rating', 'recommend'], - }; - break; - default: - throw new Error(`Unknown info type: ${infoType}`); - } - - try { - // Use the underlying server instance to elicit input from the client - const result = await server.server.elicitInput({ - message, - requestedSchema, - }); - - if (result.action === 'accept') { - return { - content: [ - { - type: 'text', - text: `Thank you! Collected ${infoType} information: ${JSON.stringify(result.content, null, 2)}`, - }, - ], - }; - } else if (result.action === 'decline') { - return { - content: [ - { - type: 'text', - text: `No information was collected. User declined ${infoType} information request.`, - }, - ], - }; - } else { - return { - content: [ - { - type: 'text', - text: `Information collection was cancelled by the user.`, - }, - ], - }; + const server = new McpServer( + { + name: 'simple-streamable-http-server', + version: '1.0.0', + icons: [{ src: './mcp.svg', sizes: '512x512', mimeType: 'image/svg+xml' }], + websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' + }, + { capabilities: { logging: {} } } + ); + + // Register a simple tool that returns a greeting + server.registerTool( + 'greet', + { + title: 'Greeting Tool', // Display name for UI + description: 'A simple greeting tool', + inputSchema: { + name: z.string().describe('Name to greet') + } + }, + async ({ name }): Promise => { + return { + content: [ + { + type: 'text', + text: `Hello, ${name}!` + } + ] + }; } - } catch (error) { - return { - content: [ - { - type: 'text', - text: `Error collecting ${infoType} information: ${error}`, - }, - ], - }; - } - } - ); - - // Register a simple prompt with title - server.registerPrompt( - 'greeting-template', - { - title: 'Greeting Template', // Display name for UI - description: 'A simple greeting prompt template', - argsSchema: { - name: z.string().describe('Name to include in greeting'), - }, - }, - async ({ name }): Promise => { - return { - messages: [ - { - role: 'user', - content: { - type: 'text', - text: `Please greet ${name} in a friendly manner.`, - }, - }, - ], - }; - } - ); - - // Register a tool specifically for testing resumability - server.tool( - 'start-notification-stream', - 'Starts sending periodic notifications for testing resumability', - { - interval: z.number().describe('Interval in milliseconds between notifications').default(100), - count: z.number().describe('Number of notifications to send (0 for 100)').default(50), - }, - async ({ interval, count }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - let counter = 0; - - while (count === 0 || counter < count) { - counter++; - try { - await server.sendLoggingMessage( { - level: "info", - data: `Periodic notification #${counter} at ${new Date().toISOString()}` - }, extra.sessionId); + ); + + // Register a tool that sends multiple greetings with notifications (with annotations) + server.tool( + 'multi-greet', + 'A tool that sends different greetings with delays between them', + { + name: z.string().describe('Name to greet') + }, + { + title: 'Multiple Greeting Tool', + readOnlyHint: true, + openWorldHint: false + }, + async ({ name }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + await server.sendLoggingMessage( + { + level: 'debug', + data: `Starting multi-greet for ${name}` + }, + extra.sessionId + ); + + await sleep(1000); // Wait 1 second before first greeting + + await server.sendLoggingMessage( + { + level: 'info', + data: `Sending first greeting to ${name}` + }, + extra.sessionId + ); + + await sleep(1000); // Wait another second before second greeting + + await server.sendLoggingMessage( + { + level: 'info', + data: `Sending second greeting to ${name}` + }, + extra.sessionId + ); + + return { + content: [ + { + type: 'text', + text: `Good morning, ${name}!` + } + ] + }; } - catch (error) { - console.error("Error sending notification:", error); + ); + // Register a tool that demonstrates elicitation (user input collection) + // This creates a closure that captures the server instance + server.tool( + 'collect-user-info', + 'A tool that collects user information through elicitation', + { + infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') + }, + async ({ infoType }): Promise => { + let message: string; + let requestedSchema: { + type: 'object'; + properties: Record; + required?: string[]; + }; + + switch (infoType) { + case 'contact': + message = 'Please provide your contact information'; + requestedSchema = { + type: 'object', + properties: { + name: { + type: 'string', + title: 'Full Name', + description: 'Your full name' + }, + email: { + type: 'string', + title: 'Email Address', + description: 'Your email address', + format: 'email' + }, + phone: { + type: 'string', + title: 'Phone Number', + description: 'Your phone number (optional)' + } + }, + required: ['name', 'email'] + }; + break; + case 'preferences': + message = 'Please set your preferences'; + requestedSchema = { + type: 'object', + properties: { + theme: { + type: 'string', + title: 'Theme', + description: 'Choose your preferred theme', + enum: ['light', 'dark', 'auto'], + enumNames: ['Light', 'Dark', 'Auto'] + }, + notifications: { + type: 'boolean', + title: 'Enable Notifications', + description: 'Would you like to receive notifications?', + default: true + }, + frequency: { + type: 'string', + title: 'Notification Frequency', + description: 'How often would you like notifications?', + enum: ['daily', 'weekly', 'monthly'], + enumNames: ['Daily', 'Weekly', 'Monthly'] + } + }, + required: ['theme'] + }; + break; + case 'feedback': + message = 'Please provide your feedback'; + requestedSchema = { + type: 'object', + properties: { + rating: { + type: 'integer', + title: 'Rating', + description: 'Rate your experience (1-5)', + minimum: 1, + maximum: 5 + }, + comments: { + type: 'string', + title: 'Comments', + description: 'Additional comments (optional)', + maxLength: 500 + }, + recommend: { + type: 'boolean', + title: 'Would you recommend this?', + description: 'Would you recommend this to others?' + } + }, + required: ['rating', 'recommend'] + }; + break; + default: + throw new Error(`Unknown info type: ${infoType}`); + } + + try { + // Use the underlying server instance to elicit input from the client + const result = await server.server.elicitInput({ + message, + requestedSchema + }); + + if (result.action === 'accept') { + return { + content: [ + { + type: 'text', + text: `Thank you! Collected ${infoType} information: ${JSON.stringify(result.content, null, 2)}` + } + ] + }; + } else if (result.action === 'decline') { + return { + content: [ + { + type: 'text', + text: `No information was collected. User declined ${infoType} information request.` + } + ] + }; + } else { + return { + content: [ + { + type: 'text', + text: `Information collection was cancelled by the user.` + } + ] + }; + } + } catch (error) { + return { + content: [ + { + type: 'text', + text: `Error collecting ${infoType} information: ${error}` + } + ] + }; + } } - // Wait for the specified interval - await sleep(interval); - } - - return { - content: [ - { - type: 'text', - text: `Started sending periodic notifications every ${interval}ms`, - } - ], - }; - } - ); - - // Create a simple resource at a fixed URI - server.registerResource( - 'greeting-resource', - 'https://example.com/greetings/default', - { - title: 'Default Greeting', // Display name for UI - description: 'A simple greeting resource', - mimeType: 'text/plain' - }, - async (): Promise => { - return { - contents: [ - { - uri: 'https://example.com/greetings/default', - text: 'Hello, world!', - }, - ], - }; - } - ); - - // Create additional resources for ResourceLink demonstration - server.registerResource( - 'example-file-1', - 'file:///example/file1.txt', - { - title: 'Example File 1', - description: 'First example file for ResourceLink demonstration', - mimeType: 'text/plain' - }, - async (): Promise => { - return { - contents: [ - { - uri: 'file:///example/file1.txt', - text: 'This is the content of file 1', - }, - ], - }; - } - ); - - server.registerResource( - 'example-file-2', - 'file:///example/file2.txt', - { - title: 'Example File 2', - description: 'Second example file for ResourceLink demonstration', - mimeType: 'text/plain' - }, - async (): Promise => { - return { - contents: [ - { - uri: 'file:///example/file2.txt', - text: 'This is the content of file 2', - }, - ], - }; - } - ); - - // Register a tool that returns ResourceLinks - server.registerTool( - 'list-files', - { - title: 'List Files with ResourceLinks', - description: 'Returns a list of files as ResourceLinks without embedding their content', - inputSchema: { - includeDescriptions: z.boolean().optional().describe('Whether to include descriptions in the resource links'), - }, - }, - async ({ includeDescriptions = true }): Promise => { - const resourceLinks: ResourceLink[] = [ + ); + + // Register a simple prompt with title + server.registerPrompt( + 'greeting-template', { - type: 'resource_link', - uri: 'https://example.com/greetings/default', - name: 'Default Greeting', - mimeType: 'text/plain', - ...(includeDescriptions && { description: 'A simple greeting resource' }) + title: 'Greeting Template', // Display name for UI + description: 'A simple greeting prompt template', + argsSchema: { + name: z.string().describe('Name to include in greeting') + } }, + async ({ name }): Promise => { + return { + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please greet ${name} in a friendly manner.` + } + } + ] + }; + } + ); + + // Register a tool specifically for testing resumability + server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications for testing resumability', { - type: 'resource_link', - uri: 'file:///example/file1.txt', - name: 'Example File 1', - mimeType: 'text/plain', - ...(includeDescriptions && { description: 'First example file for ResourceLink demonstration' }) + interval: z.number().describe('Interval in milliseconds between notifications').default(100), + count: z.number().describe('Number of notifications to send (0 for 100)').default(50) }, + async ({ interval, count }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; + + while (count === 0 || counter < count) { + counter++; + try { + await server.sendLoggingMessage( + { + level: 'info', + data: `Periodic notification #${counter} at ${new Date().toISOString()}` + }, + extra.sessionId + ); + } catch (error) { + console.error('Error sending notification:', error); + } + // Wait for the specified interval + await sleep(interval); + } + + return { + content: [ + { + type: 'text', + text: `Started sending periodic notifications every ${interval}ms` + } + ] + }; + } + ); + + // Create a simple resource at a fixed URI + server.registerResource( + 'greeting-resource', + 'https://example.com/greetings/default', { - type: 'resource_link', - uri: 'file:///example/file2.txt', - name: 'Example File 2', - mimeType: 'text/plain', - ...(includeDescriptions && { description: 'Second example file for ResourceLink demonstration' }) + title: 'Default Greeting', // Display name for UI + description: 'A simple greeting resource', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'https://example.com/greetings/default', + text: 'Hello, world!' + } + ] + }; } - ]; - - return { - content: [ - { - type: 'text', - text: 'Here are the available files as resource links:', - }, - ...resourceLinks, - { - type: 'text', - text: '\nYou can read any of these resources using their URI.', - } - ], - }; - } - ); + ); + + // Create additional resources for ResourceLink demonstration + server.registerResource( + 'example-file-1', + 'file:///example/file1.txt', + { + title: 'Example File 1', + description: 'First example file for ResourceLink demonstration', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'file:///example/file1.txt', + text: 'This is the content of file 1' + } + ] + }; + } + ); - return server; + server.registerResource( + 'example-file-2', + 'file:///example/file2.txt', + { + title: 'Example File 2', + description: 'Second example file for ResourceLink demonstration', + mimeType: 'text/plain' + }, + async (): Promise => { + return { + contents: [ + { + uri: 'file:///example/file2.txt', + text: 'This is the content of file 2' + } + ] + }; + } + ); + + // Register a tool that returns ResourceLinks + server.registerTool( + 'list-files', + { + title: 'List Files with ResourceLinks', + description: 'Returns a list of files as ResourceLinks without embedding their content', + inputSchema: { + includeDescriptions: z.boolean().optional().describe('Whether to include descriptions in the resource links') + } + }, + async ({ includeDescriptions = true }): Promise => { + const resourceLinks: ResourceLink[] = [ + { + type: 'resource_link', + uri: 'https://example.com/greetings/default', + name: 'Default Greeting', + mimeType: 'text/plain', + ...(includeDescriptions && { description: 'A simple greeting resource' }) + }, + { + type: 'resource_link', + uri: 'file:///example/file1.txt', + name: 'Example File 1', + mimeType: 'text/plain', + ...(includeDescriptions && { description: 'First example file for ResourceLink demonstration' }) + }, + { + type: 'resource_link', + uri: 'file:///example/file2.txt', + name: 'Example File 2', + mimeType: 'text/plain', + ...(includeDescriptions && { description: 'Second example file for ResourceLink demonstration' }) + } + ]; + + return { + content: [ + { + type: 'text', + text: 'Here are the available files as resource links:' + }, + ...resourceLinks, + { + type: 'text', + text: '\nYou can read any of these resources using their URI.' + } + ] + }; + } + ); + + return server; }; const MCP_PORT = process.env.MCP_PORT ? parseInt(process.env.MCP_PORT, 10) : 3000; @@ -428,76 +449,79 @@ const app = express(); app.use(express.json()); // Allow CORS all domains, expose the Mcp-Session-Id header -app.use(cors({ - origin: '*', // Allow all origins - exposedHeaders: ["Mcp-Session-Id"] -})); +app.use( + cors({ + origin: '*', // Allow all origins + exposedHeaders: ['Mcp-Session-Id'] + }) +); // Set up OAuth if enabled let authMiddleware = null; if (useOAuth) { - // Create auth middleware for MCP endpoints - const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}/mcp`); - const authServerUrl = new URL(`http://localhost:${AUTH_PORT}`); - - const oauthMetadata: OAuthMetadata = setupAuthServer({ authServerUrl, mcpServerUrl, strictResource: strictOAuth }); - - const tokenVerifier = { - verifyAccessToken: async (token: string) => { - const endpoint = oauthMetadata.introspection_endpoint; - - if (!endpoint) { - throw new Error('No token verification endpoint available in metadata'); - } - - const response = await fetch(endpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - }, - body: new URLSearchParams({ - token: token - }).toString() - }); - - - if (!response.ok) { - throw new Error(`Invalid or expired token: ${await response.text()}`); - } - - const data = await response.json(); - - if (strictOAuth) { - if (!data.aud) { - throw new Error(`Resource Indicator (RFC8707) missing`); - } - if (!checkResourceAllowed({ requestedResource: data.aud, configuredResource: mcpServerUrl })) { - throw new Error(`Expected resource indicator ${mcpServerUrl}, got: ${data.aud}`); + // Create auth middleware for MCP endpoints + const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}/mcp`); + const authServerUrl = new URL(`http://localhost:${AUTH_PORT}`); + + const oauthMetadata: OAuthMetadata = setupAuthServer({ authServerUrl, mcpServerUrl, strictResource: strictOAuth }); + + const tokenVerifier = { + verifyAccessToken: async (token: string) => { + const endpoint = oauthMetadata.introspection_endpoint; + + if (!endpoint) { + throw new Error('No token verification endpoint available in metadata'); + } + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: new URLSearchParams({ + token: token + }).toString() + }); + + if (!response.ok) { + throw new Error(`Invalid or expired token: ${await response.text()}`); + } + + const data = await response.json(); + + if (strictOAuth) { + if (!data.aud) { + throw new Error(`Resource Indicator (RFC8707) missing`); + } + if (!checkResourceAllowed({ requestedResource: data.aud, configuredResource: mcpServerUrl })) { + throw new Error(`Expected resource indicator ${mcpServerUrl}, got: ${data.aud}`); + } + } + + // Convert the response to AuthInfo format + return { + token, + clientId: data.client_id, + scopes: data.scope ? data.scope.split(' ') : [], + expiresAt: data.exp + }; } - } - - // Convert the response to AuthInfo format - return { - token, - clientId: data.client_id, - scopes: data.scope ? data.scope.split(' ') : [], - expiresAt: data.exp, - }; - } - } - // Add metadata routes to the main MCP server - app.use(mcpAuthMetadataRouter({ - oauthMetadata, - resourceServerUrl: mcpServerUrl, - scopesSupported: ['mcp:tools'], - resourceName: 'MCP Demo Server', - })); - - authMiddleware = requireBearerAuth({ - verifier: tokenVerifier, - requiredScopes: [], - resourceMetadataUrl: getOAuthProtectedResourceMetadataUrl(mcpServerUrl), - }); + }; + // Add metadata routes to the main MCP server + app.use( + mcpAuthMetadataRouter({ + oauthMetadata, + resourceServerUrl: mcpServerUrl, + scopesSupported: ['mcp:tools'], + resourceName: 'MCP Demo Server' + }) + ); + + authMiddleware = requireBearerAuth({ + verifier: tokenVerifier, + requiredScopes: [], + resourceMetadataUrl: getOAuthProtectedResourceMetadataUrl(mcpServerUrl) + }); } // Map to store transports by session ID @@ -505,170 +529,170 @@ const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; // MCP POST endpoint with optional auth const mcpPostHandler = async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (sessionId) { - console.log(`Received MCP request for session: ${sessionId}`); - } else { - console.log('Request body:', req.body); - } - - if (useOAuth && req.auth) { - console.log('Authenticated user:', req.auth); - } - try { - let transport: StreamableHTTPServerTransport; - if (sessionId && transports[sessionId]) { - // Reuse existing transport - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { - // New initialization request - const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - eventStore, // Enable resumability - onsessioninitialized: (sessionId) => { - // Store the transport by session ID when session is initialized - // This avoids race conditions where requests might come in before the session is stored - console.log(`Session initialized with ID: ${sessionId}`); - transports[sessionId] = transport; - } - }); - - // Set up onclose handler to clean up transport when closed - transport.onclose = () => { - const sid = transport.sessionId; - if (sid && transports[sid]) { - console.log(`Transport closed for session ${sid}, removing from transports map`); - delete transports[sid]; - } - }; - - // Connect the transport to the MCP server BEFORE handling the request - // so responses can flow back through the same transport - const server = getServer(); - await server.connect(transport); - - await transport.handleRequest(req, res, req.body); - return; // Already handled + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) { + console.log(`Received MCP request for session: ${sessionId}`); } else { - // Invalid request - no session ID or not initialization request - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: No valid session ID provided', - }, - id: null, - }); - return; + console.log('Request body:', req.body); } - // Handle the request with existing transport - no need to reconnect - // The existing transport is already connected to the server - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); + if (useOAuth && req.auth) { + console.log('Authenticated user:', req.auth); + } + try { + let transport: StreamableHTTPServerTransport; + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request + const eventStore = new InMemoryEventStore(); + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore, // Enable resumability + onsessioninitialized: sessionId => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } + }); + + // Set up onclose handler to clean up transport when closed + transport.onclose = () => { + const sid = transport.sessionId; + if (sid && transports[sid]) { + console.log(`Transport closed for session ${sid}, removing from transports map`); + delete transports[sid]; + } + }; + + // Connect the transport to the MCP server BEFORE handling the request + // so responses can flow back through the same transport + const server = getServer(); + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); + return; // Already handled + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + return; + } + + // Handle the request with existing transport - no need to reconnect + // The existing transport is already connected to the server + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error' + }, + id: null + }); + } } - } }; // Set up routes with conditional auth middleware if (useOAuth && authMiddleware) { - app.post('/mcp', authMiddleware, mcpPostHandler); + app.post('/mcp', authMiddleware, mcpPostHandler); } else { - app.post('/mcp', mcpPostHandler); + app.post('/mcp', mcpPostHandler); } // Handle GET requests for SSE streams (using built-in support from StreamableHTTP) const mcpGetHandler = async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId || !transports[sessionId]) { - res.status(400).send('Invalid or missing session ID'); - return; - } - - if (useOAuth && req.auth) { - console.log('Authenticated SSE connection from user:', req.auth); - } - - // Check for Last-Event-ID header for resumability - const lastEventId = req.headers['last-event-id'] as string | undefined; - if (lastEventId) { - console.log(`Client reconnecting with Last-Event-ID: ${lastEventId}`); - } else { - console.log(`Establishing new SSE stream for session ${sessionId}`); - } - - const transport = transports[sessionId]; - await transport.handleRequest(req, res); + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } + + if (useOAuth && req.auth) { + console.log('Authenticated SSE connection from user:', req.auth); + } + + // Check for Last-Event-ID header for resumability + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + console.log(`Client reconnecting with Last-Event-ID: ${lastEventId}`); + } else { + console.log(`Establishing new SSE stream for session ${sessionId}`); + } + + const transport = transports[sessionId]; + await transport.handleRequest(req, res); }; // Set up GET route with conditional auth middleware if (useOAuth && authMiddleware) { - app.get('/mcp', authMiddleware, mcpGetHandler); + app.get('/mcp', authMiddleware, mcpGetHandler); } else { - app.get('/mcp', mcpGetHandler); + app.get('/mcp', mcpGetHandler); } // Handle DELETE requests for session termination (according to MCP spec) const mcpDeleteHandler = async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId || !transports[sessionId]) { - res.status(400).send('Invalid or missing session ID'); - return; - } + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } - console.log(`Received session termination request for session ${sessionId}`); + console.log(`Received session termination request for session ${sessionId}`); - try { - const transport = transports[sessionId]; - await transport.handleRequest(req, res); - } catch (error) { - console.error('Error handling session termination:', error); - if (!res.headersSent) { - res.status(500).send('Error processing session termination'); + try { + const transport = transports[sessionId]; + await transport.handleRequest(req, res); + } catch (error) { + console.error('Error handling session termination:', error); + if (!res.headersSent) { + res.status(500).send('Error processing session termination'); + } } - } }; // Set up DELETE route with conditional auth middleware if (useOAuth && authMiddleware) { - app.delete('/mcp', authMiddleware, mcpDeleteHandler); + app.delete('/mcp', authMiddleware, mcpDeleteHandler); } else { - app.delete('/mcp', mcpDeleteHandler); + app.delete('/mcp', mcpDeleteHandler); } -app.listen(MCP_PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`MCP Streamable HTTP Server listening on port ${MCP_PORT}`); +app.listen(MCP_PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Streamable HTTP Server listening on port ${MCP_PORT}`); }); // Handle server shutdown process.on('SIGINT', async () => { - console.log('Shutting down server...'); + console.log('Shutting down server...'); - // Close all active transports to properly clean up resources - for (const sessionId in transports) { - try { - console.log(`Closing transport for session ${sessionId}`); - await transports[sessionId].close(); - delete transports[sessionId]; - } catch (error) { - console.error(`Error closing transport for session ${sessionId}:`, error); + // Close all active transports to properly clean up resources + for (const sessionId in transports) { + try { + console.log(`Closing transport for session ${sessionId}`); + await transports[sessionId].close(); + delete transports[sessionId]; + } catch (error) { + console.error(`Error closing transport for session ${sessionId}:`, error); + } } - } - console.log('Server shutdown complete'); - process.exit(0); + console.log('Server shutdown complete'); + process.exit(0); }); diff --git a/src/examples/server/sseAndStreamableHttpCompatibleServer.ts b/src/examples/server/sseAndStreamableHttpCompatibleServer.ts index a9d9b63d7..50e2e5125 100644 --- a/src/examples/server/sseAndStreamableHttpCompatibleServer.ts +++ b/src/examples/server/sseAndStreamableHttpCompatibleServer.ts @@ -1,5 +1,5 @@ import express, { Request, Response } from 'express'; -import { randomUUID } from "node:crypto"; +import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { SSEServerTransport } from '../../server/sse.js'; @@ -20,49 +20,54 @@ import cors from 'cors'; */ const getServer = () => { - const server = new McpServer({ - name: 'backwards-compatible-server', - version: '1.0.0', - }, { capabilities: { logging: {} } }); - - // Register a simple tool that sends notifications over time - server.tool( - 'start-notification-stream', - 'Starts sending periodic notifications for testing resumability', - { - interval: z.number().describe('Interval in milliseconds between notifications').default(100), - count: z.number().describe('Number of notifications to send (0 for 100)').default(50), - }, - async ({ interval, count }, extra): Promise => { - const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); - let counter = 0; - - while (count === 0 || counter < count) { - counter++; - try { - await server.sendLoggingMessage({ - level: "info", - data: `Periodic notification #${counter} at ${new Date().toISOString()}` - }, extra.sessionId); - } - catch (error) { - console.error("Error sending notification:", error); + const server = new McpServer( + { + name: 'backwards-compatible-server', + version: '1.0.0' + }, + { capabilities: { logging: {} } } + ); + + // Register a simple tool that sends notifications over time + server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications for testing resumability', + { + interval: z.number().describe('Interval in milliseconds between notifications').default(100), + count: z.number().describe('Number of notifications to send (0 for 100)').default(50) + }, + async ({ interval, count }, extra): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; + + while (count === 0 || counter < count) { + counter++; + try { + await server.sendLoggingMessage( + { + level: 'info', + data: `Periodic notification #${counter} at ${new Date().toISOString()}` + }, + extra.sessionId + ); + } catch (error) { + console.error('Error sending notification:', error); + } + // Wait for the specified interval + await sleep(interval); + } + + return { + content: [ + { + type: 'text', + text: `Started sending periodic notifications every ${interval}ms` + } + ] + }; } - // Wait for the specified interval - await sleep(interval); - } - - return { - content: [ - { - type: 'text', - text: `Started sending periodic notifications every ${interval}ms`, - } - ], - }; - } - ); - return server; + ); + return server; }; // Create Express application @@ -70,10 +75,12 @@ const app = express(); app.use(express.json()); // Configure CORS to expose Mcp-Session-Id header for browser-based clients -app.use(cors({ - origin: '*', // Allow all origins - adjust as needed for production - exposedHeaders: ['Mcp-Session-Id'] -})); +app.use( + cors({ + origin: '*', // Allow all origins - adjust as needed for production + exposedHeaders: ['Mcp-Session-Id'] + }) +); // Store transports by session ID const transports: Record = {}; @@ -84,83 +91,83 @@ const transports: Record { - console.log(`Received ${req.method} request to /mcp`); + console.log(`Received ${req.method} request to /mcp`); - try { - // Check for existing session ID - const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; + try { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined; + let transport: StreamableHTTPServerTransport; - if (sessionId && transports[sessionId]) { - // Check if the transport is of the correct type - const existingTransport = transports[sessionId]; - if (existingTransport instanceof StreamableHTTPServerTransport) { - // Reuse existing transport - transport = existingTransport; - } else { - // Transport exists but is not a StreamableHTTPServerTransport (could be SSEServerTransport) - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: Session exists but uses a different transport protocol', - }, - id: null, - }); - return; - } - } else if (!sessionId && req.method === 'POST' && isInitializeRequest(req.body)) { - const eventStore = new InMemoryEventStore(); - transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - eventStore, // Enable resumability - onsessioninitialized: (sessionId) => { - // Store the transport by session ID when session is initialized - console.log(`StreamableHTTP session initialized with ID: ${sessionId}`); - transports[sessionId] = transport; - } - }); - - // Set up onclose handler to clean up transport when closed - transport.onclose = () => { - const sid = transport.sessionId; - if (sid && transports[sid]) { - console.log(`Transport closed for session ${sid}, removing from transports map`); - delete transports[sid]; - } - }; + if (sessionId && transports[sessionId]) { + // Check if the transport is of the correct type + const existingTransport = transports[sessionId]; + if (existingTransport instanceof StreamableHTTPServerTransport) { + // Reuse existing transport + transport = existingTransport; + } else { + // Transport exists but is not a StreamableHTTPServerTransport (could be SSEServerTransport) + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: Session exists but uses a different transport protocol' + }, + id: null + }); + return; + } + } else if (!sessionId && req.method === 'POST' && isInitializeRequest(req.body)) { + const eventStore = new InMemoryEventStore(); + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore, // Enable resumability + onsessioninitialized: sessionId => { + // Store the transport by session ID when session is initialized + console.log(`StreamableHTTP session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } + }); - // Connect the transport to the MCP server - const server = getServer(); - await server.connect(transport); - } else { - // Invalid request - no session ID or not initialization request - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: No valid session ID provided', - }, - id: null, - }); - return; - } + // Set up onclose handler to clean up transport when closed + transport.onclose = () => { + const sid = transport.sessionId; + if (sid && transports[sid]) { + console.log(`Transport closed for session ${sid}, removing from transports map`); + delete transports[sid]; + } + }; - // Handle the request with the transport - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); + // Connect the transport to the MCP server + const server = getServer(); + await server.connect(transport); + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + return; + } + + // Handle the request with the transport + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error' + }, + id: null + }); + } } - } }); //============================================================================= @@ -168,52 +175,51 @@ app.all('/mcp', async (req: Request, res: Response) => { //============================================================================= app.get('/sse', async (req: Request, res: Response) => { - console.log('Received GET request to /sse (deprecated SSE transport)'); - const transport = new SSEServerTransport('/messages', res); - transports[transport.sessionId] = transport; - res.on("close", () => { - delete transports[transport.sessionId]; - }); - const server = getServer(); - await server.connect(transport); -}); - -app.post("/messages", async (req: Request, res: Response) => { - const sessionId = req.query.sessionId as string; - let transport: SSEServerTransport; - const existingTransport = transports[sessionId]; - if (existingTransport instanceof SSEServerTransport) { - // Reuse existing transport - transport = existingTransport; - } else { - // Transport exists but is not a SSEServerTransport (could be StreamableHTTPServerTransport) - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: Session exists but uses a different transport protocol', - }, - id: null, + console.log('Received GET request to /sse (deprecated SSE transport)'); + const transport = new SSEServerTransport('/messages', res); + transports[transport.sessionId] = transport; + res.on('close', () => { + delete transports[transport.sessionId]; }); - return; - } - if (transport) { - await transport.handlePostMessage(req, res, req.body); - } else { - res.status(400).send('No transport found for sessionId'); - } + const server = getServer(); + await server.connect(transport); }); +app.post('/messages', async (req: Request, res: Response) => { + const sessionId = req.query.sessionId as string; + let transport: SSEServerTransport; + const existingTransport = transports[sessionId]; + if (existingTransport instanceof SSEServerTransport) { + // Reuse existing transport + transport = existingTransport; + } else { + // Transport exists but is not a SSEServerTransport (could be StreamableHTTPServerTransport) + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: Session exists but uses a different transport protocol' + }, + id: null + }); + return; + } + if (transport) { + await transport.handlePostMessage(req, res, req.body); + } else { + res.status(400).send('No transport found for sessionId'); + } +}); // Start the server const PORT = 3000; -app.listen(PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`Backwards compatible MCP server listening on port ${PORT}`); - console.log(` +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`Backwards compatible MCP server listening on port ${PORT}`); + console.log(` ============================================== SUPPORTED TRANSPORT OPTIONS: @@ -237,18 +243,18 @@ SUPPORTED TRANSPORT OPTIONS: // Handle server shutdown process.on('SIGINT', async () => { - console.log('Shutting down server...'); + console.log('Shutting down server...'); - // Close all active transports to properly clean up resources - for (const sessionId in transports) { - try { - console.log(`Closing transport for session ${sessionId}`); - await transports[sessionId].close(); - delete transports[sessionId]; - } catch (error) { - console.error(`Error closing transport for session ${sessionId}:`, error); + // Close all active transports to properly clean up resources + for (const sessionId in transports) { + try { + console.log(`Closing transport for session ${sessionId}`); + await transports[sessionId].close(); + delete transports[sessionId]; + } catch (error) { + console.error(`Error closing transport for session ${sessionId}:`, error); + } } - } - console.log('Server shutdown complete'); - process.exit(0); + console.log('Server shutdown complete'); + process.exit(0); }); diff --git a/src/examples/server/standaloneSseWithGetStreamableHttp.ts b/src/examples/server/standaloneSseWithGetStreamableHttp.ts index 279818139..6229c53a4 100644 --- a/src/examples/server/standaloneSseWithGetStreamableHttp.ts +++ b/src/examples/server/standaloneSseWithGetStreamableHttp.ts @@ -6,124 +6,122 @@ import { isInitializeRequest, ReadResourceResult } from '../../types.js'; // Create an MCP server with implementation details const server = new McpServer({ - name: 'resource-list-changed-notification-server', - version: '1.0.0', + name: 'resource-list-changed-notification-server', + version: '1.0.0' }); // Store transports by session ID to send notifications const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; const addResource = (name: string, content: string) => { - const uri = `https://mcp-example.com/dynamic/${encodeURIComponent(name)}`; - server.resource( - name, - uri, - { mimeType: 'text/plain', description: `Dynamic resource: ${name}` }, - async (): Promise => { - return { - contents: [{ uri, text: content }], - }; - } - ); - + const uri = `https://mcp-example.com/dynamic/${encodeURIComponent(name)}`; + server.resource( + name, + uri, + { mimeType: 'text/plain', description: `Dynamic resource: ${name}` }, + async (): Promise => { + return { + contents: [{ uri, text: content }] + }; + } + ); }; addResource('example-resource', 'Initial content for example-resource'); const resourceChangeInterval = setInterval(() => { - const name = randomUUID(); - addResource(name, `Content for ${name}`); + const name = randomUUID(); + addResource(name, `Content for ${name}`); }, 5000); // Change resources every 5 seconds for testing const app = express(); app.use(express.json()); app.post('/mcp', async (req: Request, res: Response) => { - console.log('Received MCP request:', req.body); - try { - // Check for existing session ID - const sessionId = req.headers['mcp-session-id'] as string | undefined; - let transport: StreamableHTTPServerTransport; - - if (sessionId && transports[sessionId]) { - // Reuse existing transport - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { - // New initialization request - transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: (sessionId) => { - // Store the transport by session ID when session is initialized - // This avoids race conditions where requests might come in before the session is stored - console.log(`Session initialized with ID: ${sessionId}`); - transports[sessionId] = transport; + console.log('Received MCP request:', req.body); + try { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined; + let transport: StreamableHTTPServerTransport; + + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: sessionId => { + // Store the transport by session ID when session is initialized + // This avoids race conditions where requests might come in before the session is stored + console.log(`Session initialized with ID: ${sessionId}`); + transports[sessionId] = transport; + } + }); + + // Connect the transport to the MCP server + await server.connect(transport); + + // Handle the request - the onsessioninitialized callback will store the transport + await transport.handleRequest(req, res, req.body); + return; // Already handled + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided' + }, + id: null + }); + return; } - }); - - // Connect the transport to the MCP server - await server.connect(transport); - - // Handle the request - the onsessioninitialized callback will store the transport - await transport.handleRequest(req, res, req.body); - return; // Already handled - } else { - // Invalid request - no session ID or not initialization request - res.status(400).json({ - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Bad Request: No valid session ID provided', - }, - id: null, - }); - return; - } - // Handle the request with existing transport - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { - code: -32603, - message: 'Internal server error', - }, - id: null, - }); + // Handle the request with existing transport + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error' + }, + id: null + }); + } } - } }); // Handle GET requests for SSE streams (now using built-in support from StreamableHTTP) app.get('/mcp', async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId || !transports[sessionId]) { - res.status(400).send('Invalid or missing session ID'); - return; - } + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !transports[sessionId]) { + res.status(400).send('Invalid or missing session ID'); + return; + } - console.log(`Establishing SSE stream for session ${sessionId}`); - const transport = transports[sessionId]; - await transport.handleRequest(req, res); + console.log(`Establishing SSE stream for session ${sessionId}`); + const transport = transports[sessionId]; + await transport.handleRequest(req, res); }); - // Start the server const PORT = 3000; -app.listen(PORT, (error) => { - if (error) { - console.error('Failed to start server:', error); - process.exit(1); - } - console.log(`Server listening on port ${PORT}`); +app.listen(PORT, error => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`Server listening on port ${PORT}`); }); // Handle server shutdown process.on('SIGINT', async () => { - console.log('Shutting down server...'); - clearInterval(resourceChangeInterval); - await server.close(); - process.exit(0); -}); \ No newline at end of file + console.log('Shutting down server...'); + clearInterval(resourceChangeInterval); + await server.close(); + process.exit(0); +}); diff --git a/src/examples/server/toolWithSampleServer.ts b/src/examples/server/toolWithSampleServer.ts index 44e5cecbb..ad5a01bdc 100644 --- a/src/examples/server/toolWithSampleServer.ts +++ b/src/examples/server/toolWithSampleServer.ts @@ -1,57 +1,56 @@ - // Run with: npx tsx src/examples/server/toolWithSampleServer.ts -import { McpServer } from "../../server/mcp.js"; -import { StdioServerTransport } from "../../server/stdio.js"; -import { z } from "zod"; +import { McpServer } from '../../server/mcp.js'; +import { StdioServerTransport } from '../../server/stdio.js'; +import { z } from 'zod'; const mcpServer = new McpServer({ - name: "tools-with-sample-server", - version: "1.0.0", + name: 'tools-with-sample-server', + version: '1.0.0' }); // Tool that uses LLM sampling to summarize any text mcpServer.registerTool( - "summarize", - { - description: "Summarize any text using an LLM", - inputSchema: { - text: z.string().describe("Text to summarize"), + 'summarize', + { + description: 'Summarize any text using an LLM', + inputSchema: { + text: z.string().describe('Text to summarize') + } }, - }, - async ({ text }) => { - // Call the LLM through MCP sampling - const response = await mcpServer.server.createMessage({ - messages: [ - { - role: "user", - content: { - type: "text", - text: `Please summarize the following text concisely:\n\n${text}`, - }, - }, - ], - maxTokens: 500, - }); + async ({ text }) => { + // Call the LLM through MCP sampling + const response = await mcpServer.server.createMessage({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please summarize the following text concisely:\n\n${text}` + } + } + ], + maxTokens: 500 + }); - return { - content: [ - { - type: "text", - text: response.content.type === "text" ? response.content.text : "Unable to generate summary", - }, - ], - }; - } + return { + content: [ + { + type: 'text', + text: response.content.type === 'text' ? response.content.text : 'Unable to generate summary' + } + ] + }; + } ); async function main() { - const transport = new StdioServerTransport(); - await mcpServer.connect(transport); - console.log("MCP server is running..."); + const transport = new StdioServerTransport(); + await mcpServer.connect(transport); + console.log('MCP server is running...'); } -main().catch((error) => { - console.error("Server error:", error); - process.exit(1); -}); \ No newline at end of file +main().catch(error => { + console.error('Server error:', error); + process.exit(1); +}); diff --git a/src/examples/shared/inMemoryEventStore.ts b/src/examples/shared/inMemoryEventStore.ts index fbdebe12a..d4d02eb91 100644 --- a/src/examples/shared/inMemoryEventStore.ts +++ b/src/examples/shared/inMemoryEventStore.ts @@ -7,71 +7,72 @@ import { EventStore } from '../../server/streamableHttp.js'; * where a persistent storage solution would be more appropriate. */ export class InMemoryEventStore implements EventStore { - private events: Map = new Map(); + private events: Map = new Map(); - /** - * Generates a unique event ID for a given stream ID - */ - private generateEventId(streamId: string): string { - return `${streamId}_${Date.now()}_${Math.random().toString(36).substring(2, 10)}`; - } - - /** - * Extracts the stream ID from an event ID - */ - private getStreamIdFromEventId(eventId: string): string { - const parts = eventId.split('_'); - return parts.length > 0 ? parts[0] : ''; - } - - /** - * Stores an event with a generated event ID - * Implements EventStore.storeEvent - */ - async storeEvent(streamId: string, message: JSONRPCMessage): Promise { - const eventId = this.generateEventId(streamId); - this.events.set(eventId, { streamId, message }); - return eventId; - } + /** + * Generates a unique event ID for a given stream ID + */ + private generateEventId(streamId: string): string { + return `${streamId}_${Date.now()}_${Math.random().toString(36).substring(2, 10)}`; + } - /** - * Replays events that occurred after a specific event ID - * Implements EventStore.replayEventsAfter - */ - async replayEventsAfter(lastEventId: string, - { send }: { send: (eventId: string, message: JSONRPCMessage) => Promise } - ): Promise { - if (!lastEventId || !this.events.has(lastEventId)) { - return ''; + /** + * Extracts the stream ID from an event ID + */ + private getStreamIdFromEventId(eventId: string): string { + const parts = eventId.split('_'); + return parts.length > 0 ? parts[0] : ''; } - // Extract the stream ID from the event ID - const streamId = this.getStreamIdFromEventId(lastEventId); - if (!streamId) { - return ''; + /** + * Stores an event with a generated event ID + * Implements EventStore.storeEvent + */ + async storeEvent(streamId: string, message: JSONRPCMessage): Promise { + const eventId = this.generateEventId(streamId); + this.events.set(eventId, { streamId, message }); + return eventId; } - let foundLastEvent = false; + /** + * Replays events that occurred after a specific event ID + * Implements EventStore.replayEventsAfter + */ + async replayEventsAfter( + lastEventId: string, + { send }: { send: (eventId: string, message: JSONRPCMessage) => Promise } + ): Promise { + if (!lastEventId || !this.events.has(lastEventId)) { + return ''; + } + + // Extract the stream ID from the event ID + const streamId = this.getStreamIdFromEventId(lastEventId); + if (!streamId) { + return ''; + } + + let foundLastEvent = false; - // Sort events by eventId for chronological ordering - const sortedEvents = [...this.events.entries()].sort((a, b) => a[0].localeCompare(b[0])); + // Sort events by eventId for chronological ordering + const sortedEvents = [...this.events.entries()].sort((a, b) => a[0].localeCompare(b[0])); - for (const [eventId, { streamId: eventStreamId, message }] of sortedEvents) { - // Only include events from the same stream - if (eventStreamId !== streamId) { - continue; - } + for (const [eventId, { streamId: eventStreamId, message }] of sortedEvents) { + // Only include events from the same stream + if (eventStreamId !== streamId) { + continue; + } - // Start sending events after we find the lastEventId - if (eventId === lastEventId) { - foundLastEvent = true; - continue; - } + // Start sending events after we find the lastEventId + if (eventId === lastEventId) { + foundLastEvent = true; + continue; + } - if (foundLastEvent) { - await send(eventId, message); - } + if (foundLastEvent) { + await send(eventId, message); + } + } + return streamId; } - return streamId; - } -} \ No newline at end of file +} diff --git a/src/inMemory.test.ts b/src/inMemory.test.ts index baf43446c..cb758ec0a 100644 --- a/src/inMemory.test.ts +++ b/src/inMemory.test.ts @@ -1,121 +1,119 @@ -import { InMemoryTransport } from "./inMemory.js"; -import { JSONRPCMessage } from "./types.js"; -import { AuthInfo } from "./server/auth/types.js"; - -describe("InMemoryTransport", () => { - let clientTransport: InMemoryTransport; - let serverTransport: InMemoryTransport; - - beforeEach(() => { - [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - }); - - test("should create linked pair", () => { - expect(clientTransport).toBeDefined(); - expect(serverTransport).toBeDefined(); - }); - - test("should start without error", async () => { - await expect(clientTransport.start()).resolves.not.toThrow(); - await expect(serverTransport.start()).resolves.not.toThrow(); - }); - - test("should send message from client to server", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - id: 1, - }; - - let receivedMessage: JSONRPCMessage | undefined; - serverTransport.onmessage = (msg) => { - receivedMessage = msg; - }; - - await clientTransport.send(message); - expect(receivedMessage).toEqual(message); - }); - - test("should send message with auth info from client to server", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - id: 1, - }; - - const authInfo: AuthInfo = { - token: "test-token", - clientId: "test-client", - scopes: ["read", "write"], - expiresAt: Date.now() / 1000 + 3600, - }; - - let receivedMessage: JSONRPCMessage | undefined; - let receivedAuthInfo: AuthInfo | undefined; - serverTransport.onmessage = (msg, extra) => { - receivedMessage = msg; - receivedAuthInfo = extra?.authInfo; - }; - - await clientTransport.send(message, { authInfo }); - expect(receivedMessage).toEqual(message); - expect(receivedAuthInfo).toEqual(authInfo); - }); - - test("should send message from server to client", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - id: 1, - }; - - let receivedMessage: JSONRPCMessage | undefined; - clientTransport.onmessage = (msg) => { - receivedMessage = msg; - }; - - await serverTransport.send(message); - expect(receivedMessage).toEqual(message); - }); - - test("should handle close", async () => { - let clientClosed = false; - let serverClosed = false; - - clientTransport.onclose = () => { - clientClosed = true; - }; - - serverTransport.onclose = () => { - serverClosed = true; - }; - - await clientTransport.close(); - expect(clientClosed).toBe(true); - expect(serverClosed).toBe(true); - }); - - test("should throw error when sending after close", async () => { - await clientTransport.close(); - await expect( - clientTransport.send({ jsonrpc: "2.0", method: "test", id: 1 }), - ).rejects.toThrow("Not connected"); - }); - - test("should queue messages sent before start", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - id: 1, - }; - - let receivedMessage: JSONRPCMessage | undefined; - serverTransport.onmessage = (msg) => { - receivedMessage = msg; - }; - - await clientTransport.send(message); - await serverTransport.start(); - expect(receivedMessage).toEqual(message); - }); +import { InMemoryTransport } from './inMemory.js'; +import { JSONRPCMessage } from './types.js'; +import { AuthInfo } from './server/auth/types.js'; + +describe('InMemoryTransport', () => { + let clientTransport: InMemoryTransport; + let serverTransport: InMemoryTransport; + + beforeEach(() => { + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + }); + + test('should create linked pair', () => { + expect(clientTransport).toBeDefined(); + expect(serverTransport).toBeDefined(); + }); + + test('should start without error', async () => { + await expect(clientTransport.start()).resolves.not.toThrow(); + await expect(serverTransport.start()).resolves.not.toThrow(); + }); + + test('should send message from client to server', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + id: 1 + }; + + let receivedMessage: JSONRPCMessage | undefined; + serverTransport.onmessage = msg => { + receivedMessage = msg; + }; + + await clientTransport.send(message); + expect(receivedMessage).toEqual(message); + }); + + test('should send message with auth info from client to server', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + id: 1 + }; + + const authInfo: AuthInfo = { + token: 'test-token', + clientId: 'test-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + + let receivedMessage: JSONRPCMessage | undefined; + let receivedAuthInfo: AuthInfo | undefined; + serverTransport.onmessage = (msg, extra) => { + receivedMessage = msg; + receivedAuthInfo = extra?.authInfo; + }; + + await clientTransport.send(message, { authInfo }); + expect(receivedMessage).toEqual(message); + expect(receivedAuthInfo).toEqual(authInfo); + }); + + test('should send message from server to client', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + id: 1 + }; + + let receivedMessage: JSONRPCMessage | undefined; + clientTransport.onmessage = msg => { + receivedMessage = msg; + }; + + await serverTransport.send(message); + expect(receivedMessage).toEqual(message); + }); + + test('should handle close', async () => { + let clientClosed = false; + let serverClosed = false; + + clientTransport.onclose = () => { + clientClosed = true; + }; + + serverTransport.onclose = () => { + serverClosed = true; + }; + + await clientTransport.close(); + expect(clientClosed).toBe(true); + expect(serverClosed).toBe(true); + }); + + test('should throw error when sending after close', async () => { + await clientTransport.close(); + await expect(clientTransport.send({ jsonrpc: '2.0', method: 'test', id: 1 })).rejects.toThrow('Not connected'); + }); + + test('should queue messages sent before start', async () => { + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + id: 1 + }; + + let receivedMessage: JSONRPCMessage | undefined; + serverTransport.onmessage = msg => { + receivedMessage = msg; + }; + + await clientTransport.send(message); + await serverTransport.start(); + expect(receivedMessage).toEqual(message); + }); }); diff --git a/src/inMemory.ts b/src/inMemory.ts index 5dd6e81e0..26062624d 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -1,63 +1,63 @@ -import { Transport } from "./shared/transport.js"; -import { JSONRPCMessage, RequestId } from "./types.js"; -import { AuthInfo } from "./server/auth/types.js"; +import { Transport } from './shared/transport.js'; +import { JSONRPCMessage, RequestId } from './types.js'; +import { AuthInfo } from './server/auth/types.js'; interface QueuedMessage { - message: JSONRPCMessage; - extra?: { authInfo?: AuthInfo }; + message: JSONRPCMessage; + extra?: { authInfo?: AuthInfo }; } /** * In-memory transport for creating clients and servers that talk to each other within the same process. */ export class InMemoryTransport implements Transport { - private _otherTransport?: InMemoryTransport; - private _messageQueue: QueuedMessage[] = []; - - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; - sessionId?: string; - - /** - * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. - */ - static createLinkedPair(): [InMemoryTransport, InMemoryTransport] { - const clientTransport = new InMemoryTransport(); - const serverTransport = new InMemoryTransport(); - clientTransport._otherTransport = serverTransport; - serverTransport._otherTransport = clientTransport; - return [clientTransport, serverTransport]; - } - - async start(): Promise { - // Process any messages that were queued before start was called - while (this._messageQueue.length > 0) { - const queuedMessage = this._messageQueue.shift()!; - this.onmessage?.(queuedMessage.message, queuedMessage.extra); + private _otherTransport?: InMemoryTransport; + private _messageQueue: QueuedMessage[] = []; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; + sessionId?: string; + + /** + * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. + */ + static createLinkedPair(): [InMemoryTransport, InMemoryTransport] { + const clientTransport = new InMemoryTransport(); + const serverTransport = new InMemoryTransport(); + clientTransport._otherTransport = serverTransport; + serverTransport._otherTransport = clientTransport; + return [clientTransport, serverTransport]; } - } - - async close(): Promise { - const other = this._otherTransport; - this._otherTransport = undefined; - await other?.close(); - this.onclose?.(); - } - - /** - * Sends a message with optional auth info. - * This is useful for testing authentication scenarios. - */ - async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo }): Promise { - if (!this._otherTransport) { - throw new Error("Not connected"); + + async start(): Promise { + // Process any messages that were queued before start was called + while (this._messageQueue.length > 0) { + const queuedMessage = this._messageQueue.shift()!; + this.onmessage?.(queuedMessage.message, queuedMessage.extra); + } + } + + async close(): Promise { + const other = this._otherTransport; + this._otherTransport = undefined; + await other?.close(); + this.onclose?.(); } - if (this._otherTransport.onmessage) { - this._otherTransport.onmessage(message, { authInfo: options?.authInfo }); - } else { - this._otherTransport._messageQueue.push({ message, extra: { authInfo: options?.authInfo } }); + /** + * Sends a message with optional auth info. + * This is useful for testing authentication scenarios. + */ + async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId; authInfo?: AuthInfo }): Promise { + if (!this._otherTransport) { + throw new Error('Not connected'); + } + + if (this._otherTransport.onmessage) { + this._otherTransport.onmessage(message, { authInfo: options?.authInfo }); + } else { + this._otherTransport._messageQueue.push({ message, extra: { authInfo: options?.authInfo } }); + } } - } } diff --git a/src/integration-tests/process-cleanup.test.ts b/src/integration-tests/process-cleanup.test.ts index 0dd7861a4..8c7c42b46 100644 --- a/src/integration-tests/process-cleanup.test.ts +++ b/src/integration-tests/process-cleanup.test.ts @@ -1,28 +1,28 @@ -import { Server } from "../server/index.js"; -import { StdioServerTransport } from "../server/stdio.js"; +import { Server } from '../server/index.js'; +import { StdioServerTransport } from '../server/stdio.js'; -describe("Process cleanup", () => { - jest.setTimeout(5000); // 5 second timeout +describe('Process cleanup', () => { + jest.setTimeout(5000); // 5 second timeout - it("should exit cleanly after closing transport", async () => { - const server = new Server( - { - name: "test-server", - version: "1.0.0", - }, - { - capabilities: {}, - } - ); + it('should exit cleanly after closing transport', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: {} + } + ); - const transport = new StdioServerTransport(); - await server.connect(transport); + const transport = new StdioServerTransport(); + await server.connect(transport); - // Close the transport - await transport.close(); + // Close the transport + await transport.close(); - // If we reach here without hanging, the test passes - // The test runner will fail if the process hangs - expect(true).toBe(true); - }); -}); \ No newline at end of file + // If we reach here without hanging, the test passes + // The test runner will fail if the process hangs + expect(true).toBe(true); + }); +}); diff --git a/src/integration-tests/stateManagementStreamableHttp.test.ts b/src/integration-tests/stateManagementStreamableHttp.test.ts index 4a191134b..629b01519 100644 --- a/src/integration-tests/stateManagementStreamableHttp.test.ts +++ b/src/integration-tests/stateManagementStreamableHttp.test.ts @@ -5,320 +5,353 @@ import { Client } from '../client/index.js'; import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; -import { CallToolResultSchema, ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema, LATEST_PROTOCOL_VERSION } from '../types.js'; +import { + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListPromptsResultSchema, + LATEST_PROTOCOL_VERSION +} from '../types.js'; import { z } from 'zod'; describe('Streamable HTTP Transport Session Management', () => { - // Function to set up the server with optional session management - async function setupServer(withSessionManagement: boolean) { - const server: Server = createServer(); - const mcpServer = new McpServer( - { name: 'test-server', version: '1.0.0' }, - { - capabilities: { - logging: {}, - tools: {}, - resources: {}, - prompts: {} - } - } - ); - - // Add a simple resource - mcpServer.resource( - 'test-resource', - '/test', - { description: 'A test resource' }, - async () => ({ - contents: [{ - uri: '/test', - text: 'This is a test resource content' - }] - }) - ); - - mcpServer.prompt( - 'test-prompt', - 'A test prompt', - async () => ({ - messages: [{ - role: 'user', - content: { - type: 'text', - text: 'This is a test prompt' - } - }] - }) - ); - - mcpServer.tool( - 'greet', - 'A simple greeting tool', - { - name: z.string().describe('Name to greet').default('World'), - }, - async ({ name }) => { - return { - content: [{ type: 'text', text: `Hello, ${name}!` }] - }; - } - ); - - // Create transport with or without session management - const serverTransport = new StreamableHTTPServerTransport({ - sessionIdGenerator: withSessionManagement - ? () => randomUUID() // With session management, generate UUID - : undefined // Without session management, return undefined + // Function to set up the server with optional session management + async function setupServer(withSessionManagement: boolean) { + const server: Server = createServer(); + const mcpServer = new McpServer( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: { + logging: {}, + tools: {}, + resources: {}, + prompts: {} + } + } + ); + + // Add a simple resource + mcpServer.resource('test-resource', '/test', { description: 'A test resource' }, async () => ({ + contents: [ + { + uri: '/test', + text: 'This is a test resource content' + } + ] + })); + + mcpServer.prompt('test-prompt', 'A test prompt', async () => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: 'This is a test prompt' + } + } + ] + })); + + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { + name: z.string().describe('Name to greet').default('World') + }, + async ({ name }) => { + return { + content: [{ type: 'text', text: `Hello, ${name}!` }] + }; + } + ); + + // Create transport with or without session management + const serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: withSessionManagement + ? () => randomUUID() // With session management, generate UUID + : undefined // Without session management, return undefined + }); + + await mcpServer.connect(serverTransport); + + server.on('request', async (req, res) => { + await serverTransport.handleRequest(req, res); + }); + + // Start the server on a random port + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, mcpServer, serverTransport, baseUrl }; + } + + describe('Stateless Mode', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const setup = await setupServer(false); + server = setup.server; + mcpServer = setup.mcpServer; + serverTransport = setup.serverTransport; + baseUrl = setup.baseUrl; + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + it('should support multiple client connections', async () => { + // Create and connect a client + const client1 = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport1 = new StreamableHTTPClientTransport(baseUrl); + await client1.connect(transport1); + + // Verify that no session ID was set + expect(transport1.sessionId).toBeUndefined(); + + // List available tools + await client1.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + + const client2 = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport2 = new StreamableHTTPClientTransport(baseUrl); + await client2.connect(transport2); + + // Verify that no session ID was set + expect(transport2.sessionId).toBeUndefined(); + + // List available tools + await client2.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + }); + it('should operate without session management', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify that no session ID was set + expect(transport.sessionId).toBeUndefined(); + + // List available tools + const toolsResult = await client.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + + // Verify tools are accessible + expect(toolsResult.tools).toContainEqual( + expect.objectContaining({ + name: 'greet' + }) + ); + + // List available resources + const resourcesResult = await client.request( + { + method: 'resources/list', + params: {} + }, + ListResourcesResultSchema + ); + + // Verify resources result structure + expect(resourcesResult).toHaveProperty('resources'); + + // List available prompts + const promptsResult = await client.request( + { + method: 'prompts/list', + params: {} + }, + ListPromptsResultSchema + ); + + // Verify prompts result structure + expect(promptsResult).toHaveProperty('prompts'); + expect(promptsResult.prompts).toContainEqual( + expect.objectContaining({ + name: 'test-prompt' + }) + ); + + // Call the greeting tool + const greetingResult = await client.request( + { + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Stateless Transport' + } + } + }, + CallToolResultSchema + ); + + // Verify tool result + expect(greetingResult.content).toEqual([{ type: 'text', text: 'Hello, Stateless Transport!' }]); + + // Clean up + await transport.close(); + }); + + it('should set protocol version after connecting', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + + // Verify protocol version is not set before connecting + expect(transport.protocolVersion).toBeUndefined(); + + await client.connect(transport); + + // Verify protocol version is set after connecting + expect(transport.protocolVersion).toBe(LATEST_PROTOCOL_VERSION); + + // Clean up + await transport.close(); + }); }); - await mcpServer.connect(serverTransport); - - server.on('request', async (req, res) => { - await serverTransport.handleRequest(req, res); - }); - - // Start the server on a random port - const baseUrl = await new Promise((resolve) => { - server.listen(0, '127.0.0.1', () => { - const addr = server.address() as AddressInfo; - resolve(new URL(`http://127.0.0.1:${addr.port}`)); - }); - }); - - return { server, mcpServer, serverTransport, baseUrl }; - } - - describe('Stateless Mode', () => { - let server: Server; - let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; - let baseUrl: URL; - - beforeEach(async () => { - const setup = await setupServer(false); - server = setup.server; - mcpServer = setup.mcpServer; - serverTransport = setup.serverTransport; - baseUrl = setup.baseUrl; - }); - - afterEach(async () => { - // Clean up resources - await mcpServer.close().catch(() => { }); - await serverTransport.close().catch(() => { }); - server.close(); - }); - - it('should support multiple client connections', async () => { - // Create and connect a client - const client1 = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport1 = new StreamableHTTPClientTransport(baseUrl); - await client1.connect(transport1); - - // Verify that no session ID was set - expect(transport1.sessionId).toBeUndefined(); - - // List available tools - await client1.request({ - method: 'tools/list', - params: {} - }, ListToolsResultSchema); - - const client2 = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport2 = new StreamableHTTPClientTransport(baseUrl); - await client2.connect(transport2); - - // Verify that no session ID was set - expect(transport2.sessionId).toBeUndefined(); - - // List available tools - await client2.request({ - method: 'tools/list', - params: {} - }, ListToolsResultSchema); - - - }); - it('should operate without session management', async () => { - // Create and connect a client - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Verify that no session ID was set - expect(transport.sessionId).toBeUndefined(); - - // List available tools - const toolsResult = await client.request({ - method: 'tools/list', - params: {} - }, ListToolsResultSchema); - - // Verify tools are accessible - expect(toolsResult.tools).toContainEqual(expect.objectContaining({ - name: 'greet' - })); - - // List available resources - const resourcesResult = await client.request({ - method: 'resources/list', - params: {} - }, ListResourcesResultSchema); - - // Verify resources result structure - expect(resourcesResult).toHaveProperty('resources'); - - // List available prompts - const promptsResult = await client.request({ - method: 'prompts/list', - params: {} - }, ListPromptsResultSchema); - - // Verify prompts result structure - expect(promptsResult).toHaveProperty('prompts'); - expect(promptsResult.prompts).toContainEqual(expect.objectContaining({ - name: 'test-prompt' - })); - - // Call the greeting tool - const greetingResult = await client.request({ - method: 'tools/call', - params: { - name: 'greet', - arguments: { - name: 'Stateless Transport' - } - } - }, CallToolResultSchema); - - // Verify tool result - expect(greetingResult.content).toEqual([ - { type: 'text', text: 'Hello, Stateless Transport!' } - ]); - - // Clean up - await transport.close(); - }); - - it('should set protocol version after connecting', async () => { - // Create and connect a client - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - - // Verify protocol version is not set before connecting - expect(transport.protocolVersion).toBeUndefined(); - - await client.connect(transport); - - // Verify protocol version is set after connecting - expect(transport.protocolVersion).toBe(LATEST_PROTOCOL_VERSION); - - // Clean up - await transport.close(); - }); - }); - - describe('Stateful Mode', () => { - let server: Server; - let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; - let baseUrl: URL; - - beforeEach(async () => { - const setup = await setupServer(true); - server = setup.server; - mcpServer = setup.mcpServer; - serverTransport = setup.serverTransport; - baseUrl = setup.baseUrl; - }); - - afterEach(async () => { - // Clean up resources - await mcpServer.close().catch(() => { }); - await serverTransport.close().catch(() => { }); - server.close(); - }); - - it('should operate with session management', async () => { - // Create and connect a client - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Verify that a session ID was set - expect(transport.sessionId).toBeDefined(); - expect(typeof transport.sessionId).toBe('string'); - - // List available tools - const toolsResult = await client.request({ - method: 'tools/list', - params: {} - }, ListToolsResultSchema); - - // Verify tools are accessible - expect(toolsResult.tools).toContainEqual(expect.objectContaining({ - name: 'greet' - })); - - // List available resources - const resourcesResult = await client.request({ - method: 'resources/list', - params: {} - }, ListResourcesResultSchema); - - // Verify resources result structure - expect(resourcesResult).toHaveProperty('resources'); - - // List available prompts - const promptsResult = await client.request({ - method: 'prompts/list', - params: {} - }, ListPromptsResultSchema); - - // Verify prompts result structure - expect(promptsResult).toHaveProperty('prompts'); - expect(promptsResult.prompts).toContainEqual(expect.objectContaining({ - name: 'test-prompt' - })); - - // Call the greeting tool - const greetingResult = await client.request({ - method: 'tools/call', - params: { - name: 'greet', - arguments: { - name: 'Stateful Transport' - } - } - }, CallToolResultSchema); - - // Verify tool result - expect(greetingResult.content).toEqual([ - { type: 'text', text: 'Hello, Stateful Transport!' } - ]); - - // Clean up - await transport.close(); + describe('Stateful Mode', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const setup = await setupServer(true); + server = setup.server; + mcpServer = setup.mcpServer; + serverTransport = setup.serverTransport; + baseUrl = setup.baseUrl; + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + it('should operate with session management', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify that a session ID was set + expect(transport.sessionId).toBeDefined(); + expect(typeof transport.sessionId).toBe('string'); + + // List available tools + const toolsResult = await client.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + + // Verify tools are accessible + expect(toolsResult.tools).toContainEqual( + expect.objectContaining({ + name: 'greet' + }) + ); + + // List available resources + const resourcesResult = await client.request( + { + method: 'resources/list', + params: {} + }, + ListResourcesResultSchema + ); + + // Verify resources result structure + expect(resourcesResult).toHaveProperty('resources'); + + // List available prompts + const promptsResult = await client.request( + { + method: 'prompts/list', + params: {} + }, + ListPromptsResultSchema + ); + + // Verify prompts result structure + expect(promptsResult).toHaveProperty('prompts'); + expect(promptsResult.prompts).toContainEqual( + expect.objectContaining({ + name: 'test-prompt' + }) + ); + + // Call the greeting tool + const greetingResult = await client.request( + { + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Stateful Transport' + } + } + }, + CallToolResultSchema + ); + + // Verify tool result + expect(greetingResult.content).toEqual([{ type: 'text', text: 'Hello, Stateful Transport!' }]); + + // Clean up + await transport.close(); + }); }); - }); -}); \ No newline at end of file +}); diff --git a/src/integration-tests/taskResumability.test.ts b/src/integration-tests/taskResumability.test.ts index efd2611f8..d397ffab3 100644 --- a/src/integration-tests/taskResumability.test.ts +++ b/src/integration-tests/taskResumability.test.ts @@ -9,264 +9,262 @@ import { CallToolResultSchema, LoggingMessageNotificationSchema } from '../types import { z } from 'zod'; import { InMemoryEventStore } from '../examples/shared/inMemoryEventStore.js'; - - describe('Transport resumability', () => { - let server: Server; - let mcpServer: McpServer; - let serverTransport: StreamableHTTPServerTransport; - let baseUrl: URL; - let eventStore: InMemoryEventStore; - - beforeEach(async () => { - // Create event store for resumability - eventStore = new InMemoryEventStore(); - - // Create a simple MCP server - mcpServer = new McpServer( - { name: 'test-server', version: '1.0.0' }, - { capabilities: { logging: {} } } - ); - - // Add a simple notification tool that completes quickly - mcpServer.tool( - 'send-notification', - 'Sends a single notification', - { - message: z.string().describe('Message to send').default('Test notification') - }, - async ({ message }, { sendNotification }) => { - // Send notification immediately - await sendNotification({ - method: "notifications/message", - params: { - level: "info", - data: message - } + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + let eventStore: InMemoryEventStore; + + beforeEach(async () => { + // Create event store for resumability + eventStore = new InMemoryEventStore(); + + // Create a simple MCP server + mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + // Add a simple notification tool that completes quickly + mcpServer.tool( + 'send-notification', + 'Sends a single notification', + { + message: z.string().describe('Message to send').default('Test notification') + }, + async ({ message }, { sendNotification }) => { + // Send notification immediately + await sendNotification({ + method: 'notifications/message', + params: { + level: 'info', + data: message + } + }); + + return { + content: [{ type: 'text', text: 'Notification sent' }] + }; + } + ); + + // Add a long-running tool that sends multiple notifications + mcpServer.tool( + 'run-notifications', + 'Sends multiple notifications over time', + { + count: z.number().describe('Number of notifications to send').default(10), + interval: z.number().describe('Interval between notifications in ms').default(50) + }, + async ({ count, interval }, { sendNotification }) => { + // Send notifications at specified intervals + for (let i = 0; i < count; i++) { + await sendNotification({ + method: 'notifications/message', + params: { + level: 'info', + data: `Notification ${i + 1} of ${count}` + } + }); + + // Wait for the specified interval before sending next notification + if (i < count - 1) { + await new Promise(resolve => setTimeout(resolve, interval)); + } + } + + return { + content: [{ type: 'text', text: `Sent ${count} notifications` }] + }; + } + ); + + // Create a transport with the event store + serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore }); - return { - content: [{ type: 'text', text: 'Notification sent' }] - }; - } - ); - - // Add a long-running tool that sends multiple notifications - mcpServer.tool( - 'run-notifications', - 'Sends multiple notifications over time', - { - count: z.number().describe('Number of notifications to send').default(10), - interval: z.number().describe('Interval between notifications in ms').default(50) - }, - async ({ count, interval }, { sendNotification }) => { - // Send notifications at specified intervals - for (let i = 0; i < count; i++) { - await sendNotification({ - method: "notifications/message", - params: { - level: "info", - data: `Notification ${i + 1} of ${count}` - } - }); - - // Wait for the specified interval before sending next notification - if (i < count - 1) { - await new Promise(resolve => setTimeout(resolve, interval)); - } - } - - return { - content: [{ type: 'text', text: `Sent ${count} notifications` }] - }; - } - ); - - // Create a transport with the event store - serverTransport = new StreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - eventStore - }); + // Connect the transport to the MCP server + await mcpServer.connect(serverTransport); - // Connect the transport to the MCP server - await mcpServer.connect(serverTransport); + // Create and start an HTTP server + server = createServer(async (req, res) => { + await serverTransport.handleRequest(req, res); + }); - // Create and start an HTTP server - server = createServer(async (req, res) => { - await serverTransport.handleRequest(req, res); + // Start the server on a random port + baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); }); - // Start the server on a random port - baseUrl = await new Promise((resolve) => { - server.listen(0, '127.0.0.1', () => { - const addr = server.address() as AddressInfo; - resolve(new URL(`http://127.0.0.1:${addr.port}`)); - }); - }); - }); - - afterEach(async () => { - // Clean up resources - await mcpServer.close().catch(() => { }); - await serverTransport.close().catch(() => { }); - server.close(); - }); - - it('should store session ID when client connects', async () => { - // Create and connect a client - const client = new Client({ - name: 'test-client', - version: '1.0.0' + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); }); - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); + it('should store session ID when client connects', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); - // Verify session ID was generated - expect(transport.sessionId).toBeDefined(); + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); - // Clean up - await transport.close(); - }); + // Verify session ID was generated + expect(transport.sessionId).toBeDefined(); - it('should have session ID functionality', async () => { - // The ability to store a session ID when connecting - const client = new Client({ - name: 'test-client-reconnection', - version: '1.0.0' + // Clean up + await transport.close(); }); - const transport = new StreamableHTTPClientTransport(baseUrl); - - // Make sure the client can connect and get a session ID - await client.connect(transport); - expect(transport.sessionId).toBeDefined(); - - // Clean up - await transport.close(); - }); - - // This test demonstrates the capability to resume long-running tools - // across client disconnection/reconnection - it('should resume long-running notifications with lastEventId', async () => { - // Create unique client ID for this test - const clientId = 'test-client-long-running'; - const notifications = []; - let lastEventId: string | undefined; - - // Create first client - const client1 = new Client({ - id: clientId, - name: 'test-client', - version: '1.0.0' - }); - - // Set up notification handler for first client - client1.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - if (notification.method === 'notifications/message') { - notifications.push(notification.params); - } - }); + it('should have session ID functionality', async () => { + // The ability to store a session ID when connecting + const client = new Client({ + name: 'test-client-reconnection', + version: '1.0.0' + }); - // Connect first client - const transport1 = new StreamableHTTPClientTransport(baseUrl); - await client1.connect(transport1); - const sessionId = transport1.sessionId; - expect(sessionId).toBeDefined(); + const transport = new StreamableHTTPClientTransport(baseUrl); - // Start a long-running notification stream with tracking of lastEventId - const onLastEventIdUpdate = jest.fn((eventId: string) => { - lastEventId = eventId; - }); - expect(lastEventId).toBeUndefined(); - // Start the notification tool with event tracking using request - const toolPromise = client1.request({ - method: 'tools/call', - params: { - name: 'run-notifications', - arguments: { - count: 3, - interval: 10 - } - } - }, CallToolResultSchema, { - resumptionToken: lastEventId, - onresumptiontoken: onLastEventIdUpdate - }); + // Make sure the client can connect and get a session ID + await client.connect(transport); + expect(transport.sessionId).toBeDefined(); - // Wait for some notifications to arrive (not all) - shorter wait time - await new Promise(resolve => setTimeout(resolve, 20)); - - // Verify we received some notifications and lastEventId was updated - expect(notifications.length).toBeGreaterThan(0); - expect(notifications.length).toBeLessThan(4); - expect(onLastEventIdUpdate).toHaveBeenCalled(); - expect(lastEventId).toBeDefined(); - - - // Disconnect first client without waiting for completion - // When we close the connection, it will cause a ConnectionClosed error for - // any in-progress requests, which is expected behavior - await transport1.close(); - // Save the promise so we can catch it after closing - const catchPromise = toolPromise.catch(err => { - // This error is expected - the connection was intentionally closed - if (err?.code !== -32000) { // ConnectionClosed error code - console.error("Unexpected error type during transport close:", err); - } + // Clean up + await transport.close(); }); + // This test demonstrates the capability to resume long-running tools + // across client disconnection/reconnection + it('should resume long-running notifications with lastEventId', async () => { + // Create unique client ID for this test + const clientId = 'test-client-long-running'; + const notifications = []; + let lastEventId: string | undefined; + + // Create first client + const client1 = new Client({ + id: clientId, + name: 'test-client', + version: '1.0.0' + }); + // Set up notification handler for first client + client1.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + if (notification.method === 'notifications/message') { + notifications.push(notification.params); + } + }); - // Add a short delay to ensure clean disconnect before reconnecting - await new Promise(resolve => setTimeout(resolve, 10)); - - // Wait for the rejection to be handled - await catchPromise; + // Connect first client + const transport1 = new StreamableHTTPClientTransport(baseUrl); + await client1.connect(transport1); + const sessionId = transport1.sessionId; + expect(sessionId).toBeDefined(); + // Start a long-running notification stream with tracking of lastEventId + const onLastEventIdUpdate = jest.fn((eventId: string) => { + lastEventId = eventId; + }); + expect(lastEventId).toBeUndefined(); + // Start the notification tool with event tracking using request + const toolPromise = client1.request( + { + method: 'tools/call', + params: { + name: 'run-notifications', + arguments: { + count: 3, + interval: 10 + } + } + }, + CallToolResultSchema, + { + resumptionToken: lastEventId, + onresumptiontoken: onLastEventIdUpdate + } + ); + + // Wait for some notifications to arrive (not all) - shorter wait time + await new Promise(resolve => setTimeout(resolve, 20)); + + // Verify we received some notifications and lastEventId was updated + expect(notifications.length).toBeGreaterThan(0); + expect(notifications.length).toBeLessThan(4); + expect(onLastEventIdUpdate).toHaveBeenCalled(); + expect(lastEventId).toBeDefined(); + + // Disconnect first client without waiting for completion + // When we close the connection, it will cause a ConnectionClosed error for + // any in-progress requests, which is expected behavior + await transport1.close(); + // Save the promise so we can catch it after closing + const catchPromise = toolPromise.catch(err => { + // This error is expected - the connection was intentionally closed + if (err?.code !== -32000) { + // ConnectionClosed error code + console.error('Unexpected error type during transport close:', err); + } + }); - // Create second client with same client ID - const client2 = new Client({ - id: clientId, - name: 'test-client', - version: '1.0.0' - }); + // Add a short delay to ensure clean disconnect before reconnecting + await new Promise(resolve => setTimeout(resolve, 10)); - // Set up notification handler for second client - client2.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - if (notification.method === 'notifications/message') { - notifications.push(notification.params); - } - }); + // Wait for the rejection to be handled + await catchPromise; - // Connect second client with same session ID - const transport2 = new StreamableHTTPClientTransport(baseUrl, { - sessionId - }); - await client2.connect(transport2); - - // Resume the notification stream using lastEventId - // This is the key part - we're resuming the same long-running tool using lastEventId - await client2.request({ - method: 'tools/call', - params: { - name: 'run-notifications', - arguments: { - count: 1, - interval: 5 - } - } - }, CallToolResultSchema, { - resumptionToken: lastEventId, // Pass the lastEventId from the previous session - onresumptiontoken: onLastEventIdUpdate - }); + // Create second client with same client ID + const client2 = new Client({ + id: clientId, + name: 'test-client', + version: '1.0.0' + }); - // Verify we eventually received at leaset a few motifications - expect(notifications.length).toBeGreaterThan(1); + // Set up notification handler for second client + client2.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + if (notification.method === 'notifications/message') { + notifications.push(notification.params); + } + }); + // Connect second client with same session ID + const transport2 = new StreamableHTTPClientTransport(baseUrl, { + sessionId + }); + await client2.connect(transport2); + + // Resume the notification stream using lastEventId + // This is the key part - we're resuming the same long-running tool using lastEventId + await client2.request( + { + method: 'tools/call', + params: { + name: 'run-notifications', + arguments: { + count: 1, + interval: 5 + } + } + }, + CallToolResultSchema, + { + resumptionToken: lastEventId, // Pass the lastEventId from the previous session + onresumptiontoken: onLastEventIdUpdate + } + ); - // Clean up - await transport2.close(); + // Verify we eventually received at leaset a few motifications + expect(notifications.length).toBeGreaterThan(1); - }); -}); \ No newline at end of file + // Clean up + await transport2.close(); + }); +}); diff --git a/src/server/auth/clients.ts b/src/server/auth/clients.ts index 8bbc6ac4d..4e3f8e17e 100644 --- a/src/server/auth/clients.ts +++ b/src/server/auth/clients.ts @@ -1,20 +1,22 @@ -import { OAuthClientInformationFull } from "../../shared/auth.js"; +import { OAuthClientInformationFull } from '../../shared/auth.js'; /** * Stores information about registered OAuth clients for this server. */ export interface OAuthRegisteredClientsStore { - /** - * Returns information about a registered client, based on its ID. - */ - getClient(clientId: string): OAuthClientInformationFull | undefined | Promise; + /** + * Returns information about a registered client, based on its ID. + */ + getClient(clientId: string): OAuthClientInformationFull | undefined | Promise; - /** - * Registers a new client with the server. The client ID and secret will be automatically generated by the library. A modified version of the client information can be returned to reflect specific values enforced by the server. - * - * NOTE: Implementations should NOT delete expired client secrets in-place. Auth middleware provided by this library will automatically check the `client_secret_expires_at` field and reject requests with expired secrets. Any custom logic for authenticating clients should check the `client_secret_expires_at` field as well. - * - * If unimplemented, dynamic client registration is unsupported. - */ - registerClient?(client: Omit): OAuthClientInformationFull | Promise; -} \ No newline at end of file + /** + * Registers a new client with the server. The client ID and secret will be automatically generated by the library. A modified version of the client information can be returned to reflect specific values enforced by the server. + * + * NOTE: Implementations should NOT delete expired client secrets in-place. Auth middleware provided by this library will automatically check the `client_secret_expires_at` field and reject requests with expired secrets. Any custom logic for authenticating clients should check the `client_secret_expires_at` field as well. + * + * If unimplemented, dynamic client registration is unsupported. + */ + registerClient?( + client: Omit + ): OAuthClientInformationFull | Promise; +} diff --git a/src/server/auth/errors.ts b/src/server/auth/errors.ts index 791b3b86c..e6871a19b 100644 --- a/src/server/auth/errors.ts +++ b/src/server/auth/errors.ts @@ -1,38 +1,38 @@ -import { OAuthErrorResponse } from "../../shared/auth.js"; +import { OAuthErrorResponse } from '../../shared/auth.js'; /** * Base class for all OAuth errors */ export class OAuthError extends Error { - static errorCode: string; - - constructor( - message: string, - public readonly errorUri?: string - ) { - super(message); - this.name = this.constructor.name; - } - - /** - * Converts the error to a standard OAuth error response object - */ - toResponseObject(): OAuthErrorResponse { - const response: OAuthErrorResponse = { - error: this.errorCode, - error_description: this.message - }; - - if (this.errorUri) { - response.error_uri = this.errorUri; + static errorCode: string; + + constructor( + message: string, + public readonly errorUri?: string + ) { + super(message); + this.name = this.constructor.name; } - return response; - } + /** + * Converts the error to a standard OAuth error response object + */ + toResponseObject(): OAuthErrorResponse { + const response: OAuthErrorResponse = { + error: this.errorCode, + error_description: this.message + }; - get errorCode(): string { - return (this.constructor as typeof OAuthError).errorCode - } + if (this.errorUri) { + response.error_uri = this.errorUri; + } + + return response; + } + + get errorCode(): string { + return (this.constructor as typeof OAuthError).errorCode; + } } /** @@ -41,7 +41,7 @@ export class OAuthError extends Error { * or is otherwise malformed. */ export class InvalidRequestError extends OAuthError { - static errorCode = "invalid_request"; + static errorCode = 'invalid_request'; } /** @@ -49,7 +49,7 @@ export class InvalidRequestError extends OAuthError { * authentication included, or unsupported authentication method). */ export class InvalidClientError extends OAuthError { - static errorCode = "invalid_client"; + static errorCode = 'invalid_client'; } /** @@ -58,7 +58,7 @@ export class InvalidClientError extends OAuthError { * authorization request, or was issued to another client. */ export class InvalidGrantError extends OAuthError { - static errorCode = "invalid_grant"; + static errorCode = 'invalid_grant'; } /** @@ -66,7 +66,7 @@ export class InvalidGrantError extends OAuthError { * this authorization grant type. */ export class UnauthorizedClientError extends OAuthError { - static errorCode = "unauthorized_client"; + static errorCode = 'unauthorized_client'; } /** @@ -74,7 +74,7 @@ export class UnauthorizedClientError extends OAuthError { * by the authorization server. */ export class UnsupportedGrantTypeError extends OAuthError { - static errorCode = "unsupported_grant_type"; + static errorCode = 'unsupported_grant_type'; } /** @@ -82,14 +82,14 @@ export class UnsupportedGrantTypeError extends OAuthError { * exceeds the scope granted by the resource owner. */ export class InvalidScopeError extends OAuthError { - static errorCode = "invalid_scope"; + static errorCode = 'invalid_scope'; } /** * Access denied error - The resource owner or authorization server denied the request. */ export class AccessDeniedError extends OAuthError { - static errorCode = "access_denied"; + static errorCode = 'access_denied'; } /** @@ -97,7 +97,7 @@ export class AccessDeniedError extends OAuthError { * that prevented it from fulfilling the request. */ export class ServerError extends OAuthError { - static errorCode = "server_error"; + static errorCode = 'server_error'; } /** @@ -105,7 +105,7 @@ export class ServerError extends OAuthError { * handle the request due to a temporary overloading or maintenance of the server. */ export class TemporarilyUnavailableError extends OAuthError { - static errorCode = "temporarily_unavailable"; + static errorCode = 'temporarily_unavailable'; } /** @@ -113,7 +113,7 @@ export class TemporarilyUnavailableError extends OAuthError { * obtaining an authorization code using this method. */ export class UnsupportedResponseTypeError extends OAuthError { - static errorCode = "unsupported_response_type"; + static errorCode = 'unsupported_response_type'; } /** @@ -121,7 +121,7 @@ export class UnsupportedResponseTypeError extends OAuthError { * the requested token type. */ export class UnsupportedTokenTypeError extends OAuthError { - static errorCode = "unsupported_token_type"; + static errorCode = 'unsupported_token_type'; } /** @@ -129,7 +129,7 @@ export class UnsupportedTokenTypeError extends OAuthError { * or invalid for other reasons. */ export class InvalidTokenError extends OAuthError { - static errorCode = "invalid_token"; + static errorCode = 'invalid_token'; } /** @@ -137,7 +137,7 @@ export class InvalidTokenError extends OAuthError { * (Custom, non-standard error) */ export class MethodNotAllowedError extends OAuthError { - static errorCode = "method_not_allowed"; + static errorCode = 'method_not_allowed'; } /** @@ -145,7 +145,7 @@ export class MethodNotAllowedError extends OAuthError { * (Custom, non-standard error based on RFC 6585) */ export class TooManyRequestsError extends OAuthError { - static errorCode = "too_many_requests"; + static errorCode = 'too_many_requests'; } /** @@ -153,47 +153,51 @@ export class TooManyRequestsError extends OAuthError { * (Custom error for dynamic client registration - RFC 7591) */ export class InvalidClientMetadataError extends OAuthError { - static errorCode = "invalid_client_metadata"; + static errorCode = 'invalid_client_metadata'; } /** * Insufficient scope error - The request requires higher privileges than provided by the access token. */ export class InsufficientScopeError extends OAuthError { - static errorCode = "insufficient_scope"; + static errorCode = 'insufficient_scope'; } /** * A utility class for defining one-off error codes */ export class CustomOAuthError extends OAuthError { - constructor(private readonly customErrorCode: string, message: string, errorUri?: string) { - super(message, errorUri); - } + constructor( + private readonly customErrorCode: string, + message: string, + errorUri?: string + ) { + super(message, errorUri); + } - get errorCode(): string { - return this.customErrorCode; - } + get errorCode(): string { + return this.customErrorCode; + } } /** * A full list of all OAuthErrors, enabling parsing from error responses */ export const OAUTH_ERRORS = { - [InvalidRequestError.errorCode]: InvalidRequestError, - [InvalidClientError.errorCode]: InvalidClientError, - [InvalidGrantError.errorCode]: InvalidGrantError, - [UnauthorizedClientError.errorCode]: UnauthorizedClientError, - [UnsupportedGrantTypeError.errorCode]: UnsupportedGrantTypeError, - [InvalidScopeError.errorCode]: InvalidScopeError, - [AccessDeniedError.errorCode]: AccessDeniedError, - [ServerError.errorCode]: ServerError, - [TemporarilyUnavailableError.errorCode]: TemporarilyUnavailableError, - [UnsupportedResponseTypeError.errorCode]: UnsupportedResponseTypeError, - [UnsupportedTokenTypeError.errorCode]: UnsupportedTokenTypeError, - [InvalidTokenError.errorCode]: InvalidTokenError, - [MethodNotAllowedError.errorCode]: MethodNotAllowedError, - [TooManyRequestsError.errorCode]: TooManyRequestsError, - [InvalidClientMetadataError.errorCode]: InvalidClientMetadataError, - [InsufficientScopeError.errorCode]: InsufficientScopeError, + [InvalidRequestError.errorCode]: InvalidRequestError, + [InvalidClientError.errorCode]: InvalidClientError, + [InvalidGrantError.errorCode]: InvalidGrantError, + [UnauthorizedClientError.errorCode]: UnauthorizedClientError, + [UnsupportedGrantTypeError.errorCode]: UnsupportedGrantTypeError, + [InvalidScopeError.errorCode]: InvalidScopeError, + [AccessDeniedError.errorCode]: AccessDeniedError, + [ServerError.errorCode]: ServerError, + [TemporarilyUnavailableError.errorCode]: TemporarilyUnavailableError, + [UnsupportedResponseTypeError.errorCode]: UnsupportedResponseTypeError, + [UnsupportedTokenTypeError.errorCode]: UnsupportedTokenTypeError, + [InvalidTokenError.errorCode]: InvalidTokenError, + [MethodNotAllowedError.errorCode]: MethodNotAllowedError, + [TooManyRequestsError.errorCode]: TooManyRequestsError, + [InvalidClientMetadataError.errorCode]: InvalidClientMetadataError, + [InsufficientScopeError.errorCode]: InsufficientScopeError } as const; diff --git a/src/server/auth/handlers/authorize.test.ts b/src/server/auth/handlers/authorize.test.ts index 438db6a6e..5dabd19e4 100644 --- a/src/server/auth/handlers/authorize.test.ts +++ b/src/server/auth/handlers/authorize.test.ts @@ -8,354 +8,319 @@ import { AuthInfo } from '../types.js'; import { InvalidTokenError } from '../errors.js'; describe('Authorization Handler', () => { - // Mock client data - const validClient: OAuthClientInformationFull = { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'], - scope: 'profile email' - }; - - const multiRedirectClient: OAuthClientInformationFull = { - client_id: 'multi-redirect-client', - client_secret: 'valid-secret', - redirect_uris: [ - 'https://example.com/callback1', - 'https://example.com/callback2' - ], - scope: 'profile email' - }; - - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } else if (clientId === 'multi-redirect-client') { - return multiRedirectClient; - } - return undefined; - } - }; - - // Mock provider - const mockProvider: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - // Mock implementation - redirects to redirectUri with code and state - const redirectUrl = new URL(params.redirectUri); - redirectUrl.searchParams.set('code', 'mock_auth_code'); - if (params.state) { - redirectUrl.searchParams.set('state', params.state); - } - res.redirect(302, redirectUrl.toString()); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(): Promise { - // Do nothing in mock - } - }; - - // Setup express app with handler - let app: express.Express; - let options: AuthorizationHandlerOptions; - - beforeEach(() => { - app = express(); - options = { provider: mockProvider }; - const handler = authorizationHandler(options); - app.use('/authorize', handler); - }); - - describe('HTTP method validation', () => { - it('rejects non-GET/POST methods', async () => { - const response = await supertest(app) - .put('/authorize') - .query({ client_id: 'valid-client' }); - - expect(response.status).toBe(405); // Method not allowed response from handler + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'], + scope: 'profile email' + }; + + const multiRedirectClient: OAuthClientInformationFull = { + client_id: 'multi-redirect-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback1', 'https://example.com/callback2'], + scope: 'profile email' + }; + + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return validClient; + } else if (clientId === 'multi-redirect-client') { + return multiRedirectClient; + } + return undefined; + } + }; + + // Mock provider + const mockProvider: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + // Mock implementation - redirects to redirectUri with code and state + const redirectUrl = new URL(params.redirectUri); + redirectUrl.searchParams.set('code', 'mock_auth_code'); + if (params.state) { + redirectUrl.searchParams.set('state', params.state); + } + res.redirect(302, redirectUrl.toString()); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(): Promise { + // Do nothing in mock + } + }; + + // Setup express app with handler + let app: express.Express; + let options: AuthorizationHandlerOptions; + + beforeEach(() => { + app = express(); + options = { provider: mockProvider }; + const handler = authorizationHandler(options); + app.use('/authorize', handler); }); - }); - describe('Client validation', () => { - it('requires client_id parameter', async () => { - const response = await supertest(app) - .get('/authorize'); + describe('HTTP method validation', () => { + it('rejects non-GET/POST methods', async () => { + const response = await supertest(app).put('/authorize').query({ client_id: 'valid-client' }); - expect(response.status).toBe(400); - expect(response.text).toContain('client_id'); + expect(response.status).toBe(405); // Method not allowed response from handler + }); }); - it('validates that client exists', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ client_id: 'nonexistent-client' }); + describe('Client validation', () => { + it('requires client_id parameter', async () => { + const response = await supertest(app).get('/authorize'); - expect(response.status).toBe(400); - }); - }); - - describe('Redirect URI validation', () => { - it('uses the only redirect_uri if client has just one and none provided', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' + expect(response.status).toBe(400); + expect(response.text).toContain('client_id'); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - }); + it('validates that client exists', async () => { + const response = await supertest(app).get('/authorize').query({ client_id: 'nonexistent-client' }); - it('requires redirect_uri if client has multiple', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'multi-redirect-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' + expect(response.status).toBe(400); }); - - expect(response.status).toBe(400); }); - it('validates redirect_uri against client registered URIs', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://malicious.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' + describe('Redirect URI validation', () => { + it('uses the only redirect_uri if client has just one and none provided', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.origin + location.pathname).toBe('https://example.com/callback'); }); - expect(response.status).toBe(400); - }); - - it('accepts valid redirect_uri that client registered with', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' - }); + it('requires redirect_uri if client has multiple', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'multi-redirect-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - }); - }); - - describe('Authorization request validation', () => { - it('requires response_type=code', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'token', // invalid - we only support code flow - code_challenge: 'challenge123', - code_challenge_method: 'S256' + expect(response.status).toBe(400); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); + it('validates redirect_uri against client registered URIs', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://malicious.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); - it('requires code_challenge parameter', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge_method: 'S256' - // Missing code_challenge + expect(response.status).toBe(400); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.get('error')).toBe('invalid_request'); + it('accepts valid redirect_uri that client registered with', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.origin + location.pathname).toBe('https://example.com/callback'); + }); }); - it('requires code_challenge_method=S256', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'plain' // Only S256 is supported + describe('Authorization request validation', () => { + it('requires response_type=code', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'token', // invalid - we only support code flow + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.get('error')).toBe('invalid_request'); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.get('error')).toBe('invalid_request'); - }); - }); - - describe('Scope validation', () => { - it('validates requested scopes against client registered scopes', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - scope: 'profile email admin' // 'admin' not in client scopes + it('requires code_challenge parameter', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge_method: 'S256' + // Missing code_challenge + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.get('error')).toBe('invalid_request'); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.get('error')).toBe('invalid_scope'); + it('requires code_challenge_method=S256', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'plain' // Only S256 is supported + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.get('error')).toBe('invalid_request'); + }); }); - it('accepts valid scopes subset', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - scope: 'profile' // subset of client scopes + describe('Scope validation', () => { + it('validates requested scopes against client registered scopes', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + scope: 'profile email admin' // 'admin' not in client scopes + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.get('error')).toBe('invalid_scope'); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.has('code')).toBe(true); - }); - }); - - describe('Resource parameter validation', () => { - it('propagates resource parameter', async () => { - const mockProviderWithResource = jest.spyOn(mockProvider, 'authorize'); - - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - resource: 'https://api.example.com/resource' + it('accepts valid scopes subset', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + scope: 'profile' // subset of client scopes + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.has('code')).toBe(true); }); - - expect(response.status).toBe(302); - expect(mockProviderWithResource).toHaveBeenCalledWith( - validClient, - expect.objectContaining({ - resource: new URL('https://api.example.com/resource'), - redirectUri: 'https://example.com/callback', - codeChallenge: 'challenge123' - }), - expect.any(Object) - ); }); - }); - - describe('Successful authorization', () => { - it('handles successful authorization with all parameters', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - scope: 'profile email', - state: 'xyz789' - }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.origin + location.pathname).toBe('https://example.com/callback'); - expect(location.searchParams.get('code')).toBe('mock_auth_code'); - expect(location.searchParams.get('state')).toBe('xyz789'); + describe('Resource parameter validation', () => { + it('propagates resource parameter', async () => { + const mockProviderWithResource = jest.spyOn(mockProvider, 'authorize'); + + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + resource: 'https://api.example.com/resource' + }); + + expect(response.status).toBe(302); + expect(mockProviderWithResource).toHaveBeenCalledWith( + validClient, + expect.objectContaining({ + resource: new URL('https://api.example.com/resource'), + redirectUri: 'https://example.com/callback', + codeChallenge: 'challenge123' + }), + expect.any(Object) + ); + }); }); - it('preserves state parameter in response', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - redirect_uri: 'https://example.com/callback', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256', - state: 'state-value-123' + describe('Successful authorization', () => { + it('handles successful authorization with all parameters', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + scope: 'profile email', + state: 'xyz789' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.origin + location.pathname).toBe('https://example.com/callback'); + expect(location.searchParams.get('code')).toBe('mock_auth_code'); + expect(location.searchParams.get('state')).toBe('xyz789'); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.get('state')).toBe('state-value-123'); - }); - - it('handles POST requests the same as GET', async () => { - const response = await supertest(app) - .post('/authorize') - .type('form') - .send({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' + it('preserves state parameter in response', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + state: 'state-value-123' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.get('state')).toBe('state-value-123'); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.has('code')).toBe(true); + it('handles POST requests the same as GET', async () => { + const response = await supertest(app).post('/authorize').type('form').send({ + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.has('code')).toBe(true); + }); }); - }); -}); \ No newline at end of file +}); diff --git a/src/server/auth/handlers/authorize.ts b/src/server/auth/handlers/authorize.ts index 126ce006b..14c7121ae 100644 --- a/src/server/auth/handlers/authorize.ts +++ b/src/server/auth/handlers/authorize.ts @@ -1,171 +1,173 @@ -import { RequestHandler } from "express"; -import { z } from "zod"; -import express from "express"; -import { OAuthServerProvider } from "../provider.js"; -import { rateLimit, Options as RateLimitOptions } from "express-rate-limit"; -import { allowedMethods } from "../middleware/allowedMethods.js"; -import { - InvalidRequestError, - InvalidClientError, - InvalidScopeError, - ServerError, - TooManyRequestsError, - OAuthError -} from "../errors.js"; +import { RequestHandler } from 'express'; +import { z } from 'zod'; +import express from 'express'; +import { OAuthServerProvider } from '../provider.js'; +import { rateLimit, Options as RateLimitOptions } from 'express-rate-limit'; +import { allowedMethods } from '../middleware/allowedMethods.js'; +import { InvalidRequestError, InvalidClientError, InvalidScopeError, ServerError, TooManyRequestsError, OAuthError } from '../errors.js'; export type AuthorizationHandlerOptions = { - provider: OAuthServerProvider; - /** - * Rate limiting configuration for the authorization endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial | false; + provider: OAuthServerProvider; + /** + * Rate limiting configuration for the authorization endpoint. + * Set to false to disable rate limiting for this endpoint. + */ + rateLimit?: Partial | false; }; // Parameters that must be validated in order to issue redirects. const ClientAuthorizationParamsSchema = z.object({ - client_id: z.string(), - redirect_uri: z.string().optional().refine((value) => value === undefined || URL.canParse(value), { message: "redirect_uri must be a valid URL" }), + client_id: z.string(), + redirect_uri: z + .string() + .optional() + .refine(value => value === undefined || URL.canParse(value), { message: 'redirect_uri must be a valid URL' }) }); // Parameters that must be validated for a successful authorization request. Failure can be reported to the redirect URI. const RequestAuthorizationParamsSchema = z.object({ - response_type: z.literal("code"), - code_challenge: z.string(), - code_challenge_method: z.literal("S256"), - scope: z.string().optional(), - state: z.string().optional(), - resource: z.string().url().optional(), + response_type: z.literal('code'), + code_challenge: z.string(), + code_challenge_method: z.literal('S256'), + scope: z.string().optional(), + state: z.string().optional(), + resource: z.string().url().optional() }); export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): RequestHandler { - // Create a router to apply middleware - const router = express.Router(); - router.use(allowedMethods(["GET", "POST"])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use(rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 100, // 100 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), - ...rateLimitConfig - })); - } - - router.all("/", async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); - - // In the authorization flow, errors are split into two categories: - // 1. Pre-redirect errors (direct response with 400) - // 2. Post-redirect errors (redirect with error parameters) - - // Phase 1: Validate client_id and redirect_uri. Any errors here must be direct responses. - let client_id, redirect_uri, client; - try { - const result = ClientAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); - if (!result.success) { - throw new InvalidRequestError(result.error.message); - } - - client_id = result.data.client_id; - redirect_uri = result.data.redirect_uri; - - client = await provider.clientsStore.getClient(client_id); - if (!client) { - throw new InvalidClientError("Invalid client_id"); - } - - if (redirect_uri !== undefined) { - if (!client.redirect_uris.includes(redirect_uri)) { - throw new InvalidRequestError("Unregistered redirect_uri"); - } - } else if (client.redirect_uris.length === 1) { - redirect_uri = client.redirect_uris[0]; - } else { - throw new InvalidRequestError("redirect_uri must be specified when client has multiple registered URIs"); - } - } catch (error) { - // Pre-redirect errors - return direct response - // - // These don't need to be JSON encoded, as they'll be displayed in a user - // agent, but OTOH they all represent exceptional situations (arguably, - // "programmer error"), so presenting a nice HTML page doesn't help the - // user anyway. - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError("Internal Server Error"); - res.status(500).json(serverError.toResponseObject()); - } - - return; + // Create a router to apply middleware + const router = express.Router(); + router.use(allowedMethods(['GET', 'POST'])); + router.use(express.urlencoded({ extended: false })); + + // Apply rate limiting unless explicitly disabled + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 15 * 60 * 1000, // 15 minutes + max: 100, // 100 requests per windowMs + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), + ...rateLimitConfig + }) + ); } - // Phase 2: Validate other parameters. Any errors here should go into redirect responses. - let state; - try { - // Parse and validate authorization parameters - const parseResult = RequestAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); - if (!parseResult.success) { - throw new InvalidRequestError(parseResult.error.message); - } - - const { scope, code_challenge, resource } = parseResult.data; - state = parseResult.data.state; - - // Validate scopes - let requestedScopes: string[] = []; - if (scope !== undefined) { - requestedScopes = scope.split(" "); - const allowedScopes = new Set(client.scope?.split(" ")); - - // Check each requested scope against allowed scopes - for (const scope of requestedScopes) { - if (!allowedScopes.has(scope)) { - throw new InvalidScopeError(`Client was not registered with scope ${scope}`); - } + router.all('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + // In the authorization flow, errors are split into two categories: + // 1. Pre-redirect errors (direct response with 400) + // 2. Post-redirect errors (redirect with error parameters) + + // Phase 1: Validate client_id and redirect_uri. Any errors here must be direct responses. + let client_id, redirect_uri, client; + try { + const result = ClientAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + if (!result.success) { + throw new InvalidRequestError(result.error.message); + } + + client_id = result.data.client_id; + redirect_uri = result.data.redirect_uri; + + client = await provider.clientsStore.getClient(client_id); + if (!client) { + throw new InvalidClientError('Invalid client_id'); + } + + if (redirect_uri !== undefined) { + if (!client.redirect_uris.includes(redirect_uri)) { + throw new InvalidRequestError('Unregistered redirect_uri'); + } + } else if (client.redirect_uris.length === 1) { + redirect_uri = client.redirect_uris[0]; + } else { + throw new InvalidRequestError('redirect_uri must be specified when client has multiple registered URIs'); + } + } catch (error) { + // Pre-redirect errors - return direct response + // + // These don't need to be JSON encoded, as they'll be displayed in a user + // agent, but OTOH they all represent exceptional situations (arguably, + // "programmer error"), so presenting a nice HTML page doesn't help the + // user anyway. + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + + return; } - } - - // All validation passed, proceed with authorization - await provider.authorize(client, { - state, - scopes: requestedScopes, - redirectUri: redirect_uri, - codeChallenge: code_challenge, - resource: resource ? new URL(resource) : undefined, - }, res); - } catch (error) { - // Post-redirect errors - redirect with error parameters - if (error instanceof OAuthError) { - res.redirect(302, createErrorRedirect(redirect_uri, error, state)); - } else { - const serverError = new ServerError("Internal Server Error"); - res.redirect(302, createErrorRedirect(redirect_uri, serverError, state)); - } - } - }); - return router; + // Phase 2: Validate other parameters. Any errors here should go into redirect responses. + let state; + try { + // Parse and validate authorization parameters + const parseResult = RequestAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { scope, code_challenge, resource } = parseResult.data; + state = parseResult.data.state; + + // Validate scopes + let requestedScopes: string[] = []; + if (scope !== undefined) { + requestedScopes = scope.split(' '); + const allowedScopes = new Set(client.scope?.split(' ')); + + // Check each requested scope against allowed scopes + for (const scope of requestedScopes) { + if (!allowedScopes.has(scope)) { + throw new InvalidScopeError(`Client was not registered with scope ${scope}`); + } + } + } + + // All validation passed, proceed with authorization + await provider.authorize( + client, + { + state, + scopes: requestedScopes, + redirectUri: redirect_uri, + codeChallenge: code_challenge, + resource: resource ? new URL(resource) : undefined + }, + res + ); + } catch (error) { + // Post-redirect errors - redirect with error parameters + if (error instanceof OAuthError) { + res.redirect(302, createErrorRedirect(redirect_uri, error, state)); + } else { + const serverError = new ServerError('Internal Server Error'); + res.redirect(302, createErrorRedirect(redirect_uri, serverError, state)); + } + } + }); + + return router; } /** * Helper function to create redirect URL with error parameters */ function createErrorRedirect(redirectUri: string, error: OAuthError, state?: string): string { - const errorUrl = new URL(redirectUri); - errorUrl.searchParams.set("error", error.errorCode); - errorUrl.searchParams.set("error_description", error.message); - if (error.errorUri) { - errorUrl.searchParams.set("error_uri", error.errorUri); - } - if (state) { - errorUrl.searchParams.set("state", state); - } - return errorUrl.href; -} \ No newline at end of file + const errorUrl = new URL(redirectUri); + errorUrl.searchParams.set('error', error.errorCode); + errorUrl.searchParams.set('error_description', error.message); + if (error.errorUri) { + errorUrl.searchParams.set('error_uri', error.errorUri); + } + if (state) { + errorUrl.searchParams.set('state', state); + } + return errorUrl.href; +} diff --git a/src/server/auth/handlers/metadata.test.ts b/src/server/auth/handlers/metadata.test.ts index 9f70b9654..32feb6429 100644 --- a/src/server/auth/handlers/metadata.test.ts +++ b/src/server/auth/handlers/metadata.test.ts @@ -4,81 +4,75 @@ import express from 'express'; import supertest from 'supertest'; describe('Metadata Handler', () => { - const exampleMetadata: OAuthMetadata = { - issuer: 'https://auth.example.com', - authorization_endpoint: 'https://auth.example.com/authorize', - token_endpoint: 'https://auth.example.com/token', - registration_endpoint: 'https://auth.example.com/register', - revocation_endpoint: 'https://auth.example.com/revoke', - scopes_supported: ['profile', 'email'], - response_types_supported: ['code'], - grant_types_supported: ['authorization_code', 'refresh_token'], - token_endpoint_auth_methods_supported: ['client_secret_basic'], - code_challenge_methods_supported: ['S256'] - }; + const exampleMetadata: OAuthMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + revocation_endpoint: 'https://auth.example.com/revoke', + scopes_supported: ['profile', 'email'], + response_types_supported: ['code'], + grant_types_supported: ['authorization_code', 'refresh_token'], + token_endpoint_auth_methods_supported: ['client_secret_basic'], + code_challenge_methods_supported: ['S256'] + }; - let app: express.Express; + let app: express.Express; - beforeEach(() => { - // Setup express app with metadata handler - app = express(); - app.use('/.well-known/oauth-authorization-server', metadataHandler(exampleMetadata)); - }); + beforeEach(() => { + // Setup express app with metadata handler + app = express(); + app.use('/.well-known/oauth-authorization-server', metadataHandler(exampleMetadata)); + }); - it('requires GET method', async () => { - const response = await supertest(app) - .post('/.well-known/oauth-authorization-server') - .send({}); + it('requires GET method', async () => { + const response = await supertest(app).post('/.well-known/oauth-authorization-server').send({}); - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('GET'); - expect(response.body).toEqual({ - error: "method_not_allowed", - error_description: "The method POST is not allowed for this endpoint" + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('GET'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method POST is not allowed for this endpoint' + }); }); - }); - it('returns the metadata object', async () => { - const response = await supertest(app) - .get('/.well-known/oauth-authorization-server'); + it('returns the metadata object', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server'); - expect(response.status).toBe(200); - expect(response.body).toEqual(exampleMetadata); - }); + expect(response.status).toBe(200); + expect(response.body).toEqual(exampleMetadata); + }); - it('includes CORS headers in response', async () => { - const response = await supertest(app) - .get('/.well-known/oauth-authorization-server') - .set('Origin', 'https://example.com'); + it('includes CORS headers in response', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server').set('Origin', 'https://example.com'); - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + expect(response.header['access-control-allow-origin']).toBe('*'); + }); - it('supports OPTIONS preflight requests', async () => { - const response = await supertest(app) - .options('/.well-known/oauth-authorization-server') - .set('Origin', 'https://example.com') - .set('Access-Control-Request-Method', 'GET'); + it('supports OPTIONS preflight requests', async () => { + const response = await supertest(app) + .options('/.well-known/oauth-authorization-server') + .set('Origin', 'https://example.com') + .set('Access-Control-Request-Method', 'GET'); - expect(response.status).toBe(204); - expect(response.header['access-control-allow-origin']).toBe('*'); - }); + expect(response.status).toBe(204); + expect(response.header['access-control-allow-origin']).toBe('*'); + }); - it('works with minimal metadata', async () => { - // Setup a new express app with minimal metadata - const minimalApp = express(); - const minimalMetadata: OAuthMetadata = { - issuer: 'https://auth.example.com', - authorization_endpoint: 'https://auth.example.com/authorize', - token_endpoint: 'https://auth.example.com/token', - response_types_supported: ['code'] - }; - minimalApp.use('/.well-known/oauth-authorization-server', metadataHandler(minimalMetadata)); + it('works with minimal metadata', async () => { + // Setup a new express app with minimal metadata + const minimalApp = express(); + const minimalMetadata: OAuthMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'] + }; + minimalApp.use('/.well-known/oauth-authorization-server', metadataHandler(minimalMetadata)); - const response = await supertest(minimalApp) - .get('/.well-known/oauth-authorization-server'); + const response = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); - expect(response.status).toBe(200); - expect(response.body).toEqual(minimalMetadata); - }); -}); \ No newline at end of file + expect(response.status).toBe(200); + expect(response.body).toEqual(minimalMetadata); + }); +}); diff --git a/src/server/auth/handlers/metadata.ts b/src/server/auth/handlers/metadata.ts index 444b85054..d8ca0e62d 100644 --- a/src/server/auth/handlers/metadata.ts +++ b/src/server/auth/handlers/metadata.ts @@ -1,19 +1,19 @@ -import express, { RequestHandler } from "express"; -import { OAuthMetadata, OAuthProtectedResourceMetadata } from "../../../shared/auth.js"; +import express, { RequestHandler } from 'express'; +import { OAuthMetadata, OAuthProtectedResourceMetadata } from '../../../shared/auth.js'; import cors from 'cors'; -import { allowedMethods } from "../middleware/allowedMethods.js"; +import { allowedMethods } from '../middleware/allowedMethods.js'; export function metadataHandler(metadata: OAuthMetadata | OAuthProtectedResourceMetadata): RequestHandler { - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); - router.use(allowedMethods(['GET'])); - router.get("/", (req, res) => { - res.status(200).json(metadata); - }); + router.use(allowedMethods(['GET'])); + router.get('/', (req, res) => { + res.status(200).json(metadata); + }); - return router; + return router; } diff --git a/src/server/auth/handlers/register.test.ts b/src/server/auth/handlers/register.test.ts index 1c3f16cb0..c4821431a 100644 --- a/src/server/auth/handlers/register.test.ts +++ b/src/server/auth/handlers/register.test.ts @@ -5,279 +5,267 @@ import express from 'express'; import supertest from 'supertest'; describe('Client Registration Handler', () => { - // Mock client store with registration support - const mockClientStoreWithRegistration: OAuthRegisteredClientsStore = { - async getClient(_clientId: string): Promise { - return undefined; - }, - - async registerClient(client: OAuthClientInformationFull): Promise { - // Return the client info as-is in the mock - return client; - } - }; - - // Mock client store without registration support - const mockClientStoreWithoutRegistration: OAuthRegisteredClientsStore = { - async getClient(_clientId: string): Promise { - return undefined; - } - // No registerClient method - }; - - describe('Handler creation', () => { - it('throws error if client store does not support registration', () => { - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithoutRegistration - }; - - expect(() => clientRegistrationHandler(options)).toThrow('does not support registering clients'); - }); + // Mock client store with registration support + const mockClientStoreWithRegistration: OAuthRegisteredClientsStore = { + async getClient(_clientId: string): Promise { + return undefined; + }, + + async registerClient(client: OAuthClientInformationFull): Promise { + // Return the client info as-is in the mock + return client; + } + }; + + // Mock client store without registration support + const mockClientStoreWithoutRegistration: OAuthRegisteredClientsStore = { + async getClient(_clientId: string): Promise { + return undefined; + } + // No registerClient method + }; + + describe('Handler creation', () => { + it('throws error if client store does not support registration', () => { + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithoutRegistration + }; + + expect(() => clientRegistrationHandler(options)).toThrow('does not support registering clients'); + }); - it('creates handler if client store supports registration', () => { - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration - }; + it('creates handler if client store supports registration', () => { + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration + }; - expect(() => clientRegistrationHandler(options)).not.toThrow(); + expect(() => clientRegistrationHandler(options)).not.toThrow(); + }); }); - }); - describe('Request handling', () => { - let app: express.Express; - let spyRegisterClient: jest.SpyInstance; + describe('Request handling', () => { + let app: express.Express; + let spyRegisterClient: jest.SpyInstance; - beforeEach(() => { - // Setup express app with registration handler - app = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 86400 // 1 day for testing - }; + beforeEach(() => { + // Setup express app with registration handler + app = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientSecretExpirySeconds: 86400 // 1 day for testing + }; - app.use('/register', clientRegistrationHandler(options)); + app.use('/register', clientRegistrationHandler(options)); - // Spy on the registerClient method - spyRegisterClient = jest.spyOn(mockClientStoreWithRegistration, 'registerClient'); - }); + // Spy on the registerClient method + spyRegisterClient = jest.spyOn(mockClientStoreWithRegistration, 'registerClient'); + }); - afterEach(() => { - spyRegisterClient.mockRestore(); - }); + afterEach(() => { + spyRegisterClient.mockRestore(); + }); - it('requires POST method', async () => { - const response = await supertest(app) - .get('/register') - .send({ - redirect_uris: ['https://example.com/callback'] + it('requires POST method', async () => { + const response = await supertest(app) + .get('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('POST'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method GET is not allowed for this endpoint' + }); + expect(spyRegisterClient).not.toHaveBeenCalled(); }); - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: "method_not_allowed", - error_description: "The method GET is not allowed for this endpoint" - }); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); + it('validates required client metadata', async () => { + const response = await supertest(app).post('/register').send({ + // Missing redirect_uris (required) + client_name: 'Test Client' + }); - it('validates required client metadata', async () => { - const response = await supertest(app) - .post('/register') - .send({ - // Missing redirect_uris (required) - client_name: 'Test Client' + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client_metadata'); + expect(spyRegisterClient).not.toHaveBeenCalled(); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client_metadata'); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); - - it('validates redirect URIs format', async () => { - const response = await supertest(app) - .post('/register') - .send({ - redirect_uris: ['invalid-url'] // Invalid URL format + it('validates redirect URIs format', async () => { + const response = await supertest(app) + .post('/register') + .send({ + redirect_uris: ['invalid-url'] // Invalid URL format + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client_metadata'); + expect(response.body.error_description).toContain('redirect_uris'); + expect(spyRegisterClient).not.toHaveBeenCalled(); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client_metadata'); - expect(response.body.error_description).toContain('redirect_uris'); - expect(spyRegisterClient).not.toHaveBeenCalled(); - }); + it('successfully registers client with minimal metadata', async () => { + const clientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'] + }; - it('successfully registers client with minimal metadata', async () => { - const clientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'] - }; + const response = await supertest(app).post('/register').send(clientMetadata); - const response = await supertest(app) - .post('/register') - .send(clientMetadata); + expect(response.status).toBe(201); - expect(response.status).toBe(201); + // Verify the generated client information + expect(response.body.client_id).toBeDefined(); + expect(response.body.client_secret).toBeDefined(); + expect(response.body.client_id_issued_at).toBeDefined(); + expect(response.body.client_secret_expires_at).toBeDefined(); + expect(response.body.redirect_uris).toEqual(['https://example.com/callback']); - // Verify the generated client information - expect(response.body.client_id).toBeDefined(); - expect(response.body.client_secret).toBeDefined(); - expect(response.body.client_id_issued_at).toBeDefined(); - expect(response.body.client_secret_expires_at).toBeDefined(); - expect(response.body.redirect_uris).toEqual(['https://example.com/callback']); + // Verify client was registered + expect(spyRegisterClient).toHaveBeenCalledTimes(1); + }); - // Verify client was registered - expect(spyRegisterClient).toHaveBeenCalledTimes(1); - }); + it('sets client_secret to undefined for token_endpoint_auth_method=none', async () => { + const clientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'none' + }; - it('sets client_secret to undefined for token_endpoint_auth_method=none', async () => { - const clientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'none' - }; + const response = await supertest(app).post('/register').send(clientMetadata); - const response = await supertest(app) - .post('/register') - .send(clientMetadata); + expect(response.status).toBe(201); + expect(response.body.client_secret).toBeUndefined(); + expect(response.body.client_secret_expires_at).toBeUndefined(); + }); - expect(response.status).toBe(201); - expect(response.body.client_secret).toBeUndefined(); - expect(response.body.client_secret_expires_at).toBeUndefined(); - }); - - it('sets client_secret_expires_at for public clients only', async () => { - // Test for public client (token_endpoint_auth_method not 'none') - const publicClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'client_secret_basic' - }; - - const publicResponse = await supertest(app) - .post('/register') - .send(publicClientMetadata); - - expect(publicResponse.status).toBe(201); - expect(publicResponse.body.client_secret).toBeDefined(); - expect(publicResponse.body.client_secret_expires_at).toBeDefined(); - - // Test for non-public client (token_endpoint_auth_method is 'none') - const nonPublicClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'none' - }; - - const nonPublicResponse = await supertest(app) - .post('/register') - .send(nonPublicClientMetadata); - - expect(nonPublicResponse.status).toBe(201); - expect(nonPublicResponse.body.client_secret).toBeUndefined(); - expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined(); - }); + it('sets client_secret_expires_at for public clients only', async () => { + // Test for public client (token_endpoint_auth_method not 'none') + const publicClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'client_secret_basic' + }; + + const publicResponse = await supertest(app).post('/register').send(publicClientMetadata); - it('sets expiry based on clientSecretExpirySeconds', async () => { - // Create handler with custom expiry time - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 3600 // 1 hour - }; + expect(publicResponse.status).toBe(201); + expect(publicResponse.body.client_secret).toBeDefined(); + expect(publicResponse.body.client_secret_expires_at).toBeDefined(); - customApp.use('/register', clientRegistrationHandler(options)); + // Test for non-public client (token_endpoint_auth_method is 'none') + const nonPublicClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'none' + }; - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] + const nonPublicResponse = await supertest(app).post('/register').send(nonPublicClientMetadata); + + expect(nonPublicResponse.status).toBe(201); + expect(nonPublicResponse.body.client_secret).toBeUndefined(); + expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined(); }); - expect(response.status).toBe(201); + it('sets expiry based on clientSecretExpirySeconds', async () => { + // Create handler with custom expiry time + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientSecretExpirySeconds: 3600 // 1 hour + }; - // Verify the expiration time (~1 hour from now) - const issuedAt = response.body.client_id_issued_at; - const expiresAt = response.body.client_secret_expires_at; - expect(expiresAt - issuedAt).toBe(3600); - }); + customApp.use('/register', clientRegistrationHandler(options)); - it('sets no expiry when clientSecretExpirySeconds=0', async () => { - // Create handler with no expiry - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientSecretExpirySeconds: 0 // No expiry - }; + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); - customApp.use('/register', clientRegistrationHandler(options)); + expect(response.status).toBe(201); - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] + // Verify the expiration time (~1 hour from now) + const issuedAt = response.body.client_id_issued_at; + const expiresAt = response.body.client_secret_expires_at; + expect(expiresAt - issuedAt).toBe(3600); }); - expect(response.status).toBe(201); - expect(response.body.client_secret_expires_at).toBe(0); - }); + it('sets no expiry when clientSecretExpirySeconds=0', async () => { + // Create handler with no expiry + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientSecretExpirySeconds: 0 // No expiry + }; - it('sets no client_id when clientIdGeneration=false', async () => { - // Create handler with no expiry - const customApp = express(); - const options: ClientRegistrationHandlerOptions = { - clientsStore: mockClientStoreWithRegistration, - clientIdGeneration: false - }; + customApp.use('/register', clientRegistrationHandler(options)); - customApp.use('/register', clientRegistrationHandler(options)); + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); - const response = await supertest(customApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] + expect(response.status).toBe(201); + expect(response.body.client_secret_expires_at).toBe(0); }); - expect(response.status).toBe(201); - expect(response.body.client_id).toBeUndefined(); - expect(response.body.client_id_issued_at).toBeUndefined(); - }); - - it('handles client with all metadata fields', async () => { - const fullClientMetadata: OAuthClientMetadata = { - redirect_uris: ['https://example.com/callback'], - token_endpoint_auth_method: 'client_secret_basic', - grant_types: ['authorization_code', 'refresh_token'], - response_types: ['code'], - client_name: 'Test Client', - client_uri: 'https://example.com', - logo_uri: 'https://example.com/logo.png', - scope: 'profile email', - contacts: ['dev@example.com'], - tos_uri: 'https://example.com/tos', - policy_uri: 'https://example.com/privacy', - jwks_uri: 'https://example.com/jwks', - software_id: 'test-software', - software_version: '1.0.0' - }; - - const response = await supertest(app) - .post('/register') - .send(fullClientMetadata); - - expect(response.status).toBe(201); - - // Verify all metadata was preserved - Object.entries(fullClientMetadata).forEach(([key, value]) => { - expect(response.body[key]).toEqual(value); - }); - }); + it('sets no client_id when clientIdGeneration=false', async () => { + // Create handler with no expiry + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientIdGeneration: false + }; + + customApp.use('/register', clientRegistrationHandler(options)); + + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(201); + expect(response.body.client_id).toBeUndefined(); + expect(response.body.client_id_issued_at).toBeUndefined(); + }); - it('includes CORS headers in response', async () => { - const response = await supertest(app) - .post('/register') - .set('Origin', 'https://example.com') - .send({ - redirect_uris: ['https://example.com/callback'] + it('handles client with all metadata fields', async () => { + const fullClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'client_secret_basic', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + client_name: 'Test Client', + client_uri: 'https://example.com', + logo_uri: 'https://example.com/logo.png', + scope: 'profile email', + contacts: ['dev@example.com'], + tos_uri: 'https://example.com/tos', + policy_uri: 'https://example.com/privacy', + jwks_uri: 'https://example.com/jwks', + software_id: 'test-software', + software_version: '1.0.0' + }; + + const response = await supertest(app).post('/register').send(fullClientMetadata); + + expect(response.status).toBe(201); + + // Verify all metadata was preserved + Object.entries(fullClientMetadata).forEach(([key, value]) => { + expect(response.body[key]).toEqual(value); + }); }); - expect(response.header['access-control-allow-origin']).toBe('*'); + it('includes CORS headers in response', async () => { + const response = await supertest(app) + .post('/register') + .set('Origin', 'https://example.com') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.header['access-control-allow-origin']).toBe('*'); + }); }); - }); -}); \ No newline at end of file +}); diff --git a/src/server/auth/handlers/register.ts b/src/server/auth/handlers/register.ts index 4d8bea1ac..1830619b4 100644 --- a/src/server/auth/handlers/register.ts +++ b/src/server/auth/handlers/register.ts @@ -1,124 +1,119 @@ -import express, { RequestHandler } from "express"; -import { OAuthClientInformationFull, OAuthClientMetadataSchema } from "../../../shared/auth.js"; +import express, { RequestHandler } from 'express'; +import { OAuthClientInformationFull, OAuthClientMetadataSchema } from '../../../shared/auth.js'; import crypto from 'node:crypto'; import cors from 'cors'; -import { OAuthRegisteredClientsStore } from "../clients.js"; -import { rateLimit, Options as RateLimitOptions } from "express-rate-limit"; -import { allowedMethods } from "../middleware/allowedMethods.js"; -import { - InvalidClientMetadataError, - ServerError, - TooManyRequestsError, - OAuthError -} from "../errors.js"; +import { OAuthRegisteredClientsStore } from '../clients.js'; +import { rateLimit, Options as RateLimitOptions } from 'express-rate-limit'; +import { allowedMethods } from '../middleware/allowedMethods.js'; +import { InvalidClientMetadataError, ServerError, TooManyRequestsError, OAuthError } from '../errors.js'; export type ClientRegistrationHandlerOptions = { - /** - * A store used to save information about dynamically registered OAuth clients. - */ - clientsStore: OAuthRegisteredClientsStore; - - /** - * The number of seconds after which to expire issued client secrets, or 0 to prevent expiration of client secrets (not recommended). - * - * If not set, defaults to 30 days. - */ - clientSecretExpirySeconds?: number; - - /** - * Rate limiting configuration for the client registration endpoint. - * Set to false to disable rate limiting for this endpoint. - * Registration endpoints are particularly sensitive to abuse and should be rate limited. - */ - rateLimit?: Partial | false; - - /** - * Whether to generate a client ID before calling the client registration endpoint. - * - * If not set, defaults to true. - */ - clientIdGeneration?: boolean; + /** + * A store used to save information about dynamically registered OAuth clients. + */ + clientsStore: OAuthRegisteredClientsStore; + + /** + * The number of seconds after which to expire issued client secrets, or 0 to prevent expiration of client secrets (not recommended). + * + * If not set, defaults to 30 days. + */ + clientSecretExpirySeconds?: number; + + /** + * Rate limiting configuration for the client registration endpoint. + * Set to false to disable rate limiting for this endpoint. + * Registration endpoints are particularly sensitive to abuse and should be rate limited. + */ + rateLimit?: Partial | false; + + /** + * Whether to generate a client ID before calling the client registration endpoint. + * + * If not set, defaults to true. + */ + clientIdGeneration?: boolean; }; const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days export function clientRegistrationHandler({ - clientsStore, - clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, - rateLimit: rateLimitConfig, - clientIdGeneration = true, + clientsStore, + clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, + rateLimit: rateLimitConfig, + clientIdGeneration = true }: ClientRegistrationHandlerOptions): RequestHandler { - if (!clientsStore.registerClient) { - throw new Error("Client registration store does not support registering clients"); - } - - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); - - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); - - router.use(allowedMethods(["POST"])); - router.use(express.json()); - - // Apply rate limiting unless explicitly disabled - stricter limits for registration - if (rateLimitConfig !== false) { - router.use(rateLimit({ - windowMs: 60 * 60 * 1000, // 1 hour - max: 20, // 20 requests per hour - stricter as registration is sensitive - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), - ...rateLimitConfig - })); - } - - router.post("/", async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); - - try { - const parseResult = OAuthClientMetadataSchema.safeParse(req.body); - if (!parseResult.success) { - throw new InvalidClientMetadataError(parseResult.error.message); - } - - const clientMetadata = parseResult.data; - const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none' - - // Generate client credentials - const clientSecret = isPublicClient - ? undefined - : crypto.randomBytes(32).toString('hex'); - const clientIdIssuedAt = Math.floor(Date.now() / 1000); - - // Calculate client secret expiry time - const clientsDoExpire = clientSecretExpirySeconds > 0 - const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0 - const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime - - let clientInfo: Omit & { client_id?: string } = { - ...clientMetadata, - client_secret: clientSecret, - client_secret_expires_at: clientSecretExpiresAt, - }; - - if (clientIdGeneration) { - clientInfo.client_id = crypto.randomUUID(); - clientInfo.client_id_issued_at = clientIdIssuedAt; - } - - clientInfo = await clientsStore.registerClient!(clientInfo); - res.status(201).json(clientInfo); - } catch (error) { - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError("Internal Server Error"); - res.status(500).json(serverError.toResponseObject()); - } + if (!clientsStore.registerClient) { + throw new Error('Client registration store does not support registering clients'); } - }); - return router; -} \ No newline at end of file + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); + + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); + + router.use(allowedMethods(['POST'])); + router.use(express.json()); + + // Apply rate limiting unless explicitly disabled - stricter limits for registration + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 60 * 60 * 1000, // 1 hour + max: 20, // 20 requests per hour - stricter as registration is sensitive + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } + + router.post('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + try { + const parseResult = OAuthClientMetadataSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidClientMetadataError(parseResult.error.message); + } + + const clientMetadata = parseResult.data; + const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none'; + + // Generate client credentials + const clientSecret = isPublicClient ? undefined : crypto.randomBytes(32).toString('hex'); + const clientIdIssuedAt = Math.floor(Date.now() / 1000); + + // Calculate client secret expiry time + const clientsDoExpire = clientSecretExpirySeconds > 0; + const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0; + const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime; + + let clientInfo: Omit & { client_id?: string } = { + ...clientMetadata, + client_secret: clientSecret, + client_secret_expires_at: clientSecretExpiresAt + }; + + if (clientIdGeneration) { + clientInfo.client_id = crypto.randomUUID(); + clientInfo.client_id_issued_at = clientIdIssuedAt; + } + + clientInfo = await clientsStore.registerClient!(clientInfo); + res.status(201).json(clientInfo); + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }); + + return router; +} diff --git a/src/server/auth/handlers/revoke.test.ts b/src/server/auth/handlers/revoke.test.ts index bd34cab76..594b689e9 100644 --- a/src/server/auth/handlers/revoke.test.ts +++ b/src/server/auth/handlers/revoke.test.ts @@ -8,240 +8,222 @@ import { AuthInfo } from '../types.js'; import { InvalidTokenError } from '../errors.js'; describe('Revocation Handler', () => { - // Mock client data - const validClient: OAuthClientInformationFull = { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; - - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } - return undefined; - } - }; - - // Mock provider with revocation capability - const mockProviderWithRevocation: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Success - do nothing in mock - } - }; - - // Mock provider without revocation capability - const mockProviderWithoutRevocation: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - } - // No revokeToken method - }; - - describe('Handler creation', () => { - it('throws error if provider does not support token revocation', () => { - const options: RevocationHandlerOptions = { provider: mockProviderWithoutRevocation }; - expect(() => revocationHandler(options)).toThrow('does not support revoking tokens'); - }); + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return validClient; + } + return undefined; + } + }; + + // Mock provider with revocation capability + const mockProviderWithRevocation: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + res.redirect('https://example.com/callback?code=mock_auth_code'); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // Success - do nothing in mock + } + }; + + // Mock provider without revocation capability + const mockProviderWithoutRevocation: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + res.redirect('https://example.com/callback?code=mock_auth_code'); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + } + // No revokeToken method + }; + + describe('Handler creation', () => { + it('throws error if provider does not support token revocation', () => { + const options: RevocationHandlerOptions = { provider: mockProviderWithoutRevocation }; + expect(() => revocationHandler(options)).toThrow('does not support revoking tokens'); + }); - it('creates handler if provider supports token revocation', () => { - const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; - expect(() => revocationHandler(options)).not.toThrow(); + it('creates handler if provider supports token revocation', () => { + const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; + expect(() => revocationHandler(options)).not.toThrow(); + }); }); - }); - - describe('Request handling', () => { - let app: express.Express; - let spyRevokeToken: jest.SpyInstance; - beforeEach(() => { - // Setup express app with revocation handler - app = express(); - const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; - app.use('/revoke', revocationHandler(options)); + describe('Request handling', () => { + let app: express.Express; + let spyRevokeToken: jest.SpyInstance; - // Spy on the revokeToken method - spyRevokeToken = jest.spyOn(mockProviderWithRevocation, 'revokeToken'); - }); - - afterEach(() => { - spyRevokeToken.mockRestore(); - }); + beforeEach(() => { + // Setup express app with revocation handler + app = express(); + const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; + app.use('/revoke', revocationHandler(options)); - it('requires POST method', async () => { - const response = await supertest(app) - .get('/revoke') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' + // Spy on the revokeToken method + spyRevokeToken = jest.spyOn(mockProviderWithRevocation, 'revokeToken'); }); - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: "method_not_allowed", - error_description: "The method GET is not allowed for this endpoint" - }); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); + afterEach(() => { + spyRevokeToken.mockRestore(); + }); - it('requires token parameter', async () => { - const response = await supertest(app) - .post('/revoke') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - // Missing token + it('requires POST method', async () => { + const response = await supertest(app).get('/revoke').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('POST'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method GET is not allowed for this endpoint' + }); + expect(spyRevokeToken).not.toHaveBeenCalled(); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); + it('requires token parameter', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret' + // Missing token + }); - it('authenticates client before revoking token', async () => { - const response = await supertest(app) - .post('/revoke') - .type('form') - .send({ - client_id: 'invalid-client', - client_secret: 'wrong-secret', - token: 'token_to_revoke' + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + expect(spyRevokeToken).not.toHaveBeenCalled(); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(spyRevokeToken).not.toHaveBeenCalled(); - }); + it('authenticates client before revoking token', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'invalid-client', + client_secret: 'wrong-secret', + token: 'token_to_revoke' + }); - it('successfully revokes token', async () => { - const response = await supertest(app) - .post('/revoke') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(spyRevokeToken).not.toHaveBeenCalled(); }); - expect(response.status).toBe(200); - expect(response.body).toEqual({}); // Empty response on success - expect(spyRevokeToken).toHaveBeenCalledTimes(1); - expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { - token: 'token_to_revoke' - }); - }); + it('successfully revokes token', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + + expect(response.status).toBe(200); + expect(response.body).toEqual({}); // Empty response on success + expect(spyRevokeToken).toHaveBeenCalledTimes(1); + expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { + token: 'token_to_revoke' + }); + }); - it('accepts optional token_type_hint', async () => { - const response = await supertest(app) - .post('/revoke') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke', - token_type_hint: 'refresh_token' + it('accepts optional token_type_hint', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke', + token_type_hint: 'refresh_token' + }); + + expect(response.status).toBe(200); + expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { + token: 'token_to_revoke', + token_type_hint: 'refresh_token' + }); }); - expect(response.status).toBe(200); - expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { - token: 'token_to_revoke', - token_type_hint: 'refresh_token' - }); - }); + it('includes CORS headers in response', async () => { + const response = await supertest(app).post('/revoke').type('form').set('Origin', 'https://example.com').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); - it('includes CORS headers in response', async () => { - const response = await supertest(app) - .post('/revoke') - .type('form') - .set('Origin', 'https://example.com') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' + expect(response.header['access-control-allow-origin']).toBe('*'); }); - - expect(response.header['access-control-allow-origin']).toBe('*'); }); - }); -}); \ No newline at end of file +}); diff --git a/src/server/auth/handlers/revoke.ts b/src/server/auth/handlers/revoke.ts index 0d1b30e07..da7ef04f8 100644 --- a/src/server/auth/handlers/revoke.ts +++ b/src/server/auth/handlers/revoke.ts @@ -1,89 +1,79 @@ -import { OAuthServerProvider } from "../provider.js"; -import express, { RequestHandler } from "express"; -import cors from "cors"; -import { authenticateClient } from "../middleware/clientAuth.js"; -import { OAuthTokenRevocationRequestSchema } from "../../../shared/auth.js"; -import { rateLimit, Options as RateLimitOptions } from "express-rate-limit"; -import { allowedMethods } from "../middleware/allowedMethods.js"; -import { - InvalidRequestError, - ServerError, - TooManyRequestsError, - OAuthError, -} from "../errors.js"; +import { OAuthServerProvider } from '../provider.js'; +import express, { RequestHandler } from 'express'; +import cors from 'cors'; +import { authenticateClient } from '../middleware/clientAuth.js'; +import { OAuthTokenRevocationRequestSchema } from '../../../shared/auth.js'; +import { rateLimit, Options as RateLimitOptions } from 'express-rate-limit'; +import { allowedMethods } from '../middleware/allowedMethods.js'; +import { InvalidRequestError, ServerError, TooManyRequestsError, OAuthError } from '../errors.js'; export type RevocationHandlerOptions = { - provider: OAuthServerProvider; - /** - * Rate limiting configuration for the token revocation endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial | false; + provider: OAuthServerProvider; + /** + * Rate limiting configuration for the token revocation endpoint. + * Set to false to disable rate limiting for this endpoint. + */ + rateLimit?: Partial | false; }; -export function revocationHandler({ - provider, - rateLimit: rateLimitConfig, -}: RevocationHandlerOptions): RequestHandler { - if (!provider.revokeToken) { - throw new Error("Auth provider does not support revoking tokens"); - } +export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): RequestHandler { + if (!provider.revokeToken) { + throw new Error('Auth provider does not support revoking tokens'); + } - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); - router.use(allowedMethods(["POST"])); - router.use(express.urlencoded({ extended: false })); + router.use(allowedMethods(['POST'])); + router.use(express.urlencoded({ extended: false })); - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use( - rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 50, // 50 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError( - "You have exceeded the rate limit for token revocation requests" - ).toResponseObject(), - ...rateLimitConfig, - }) - ); - } + // Apply rate limiting unless explicitly disabled + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 15 * 60 * 1000, // 15 minutes + max: 50, // 50 requests per windowMs + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } - // Authenticate and extract client details - router.use(authenticateClient({ clientsStore: provider.clientsStore })); + // Authenticate and extract client details + router.use(authenticateClient({ clientsStore: provider.clientsStore })); - router.post("/", async (req, res) => { - res.setHeader("Cache-Control", "no-store"); + router.post('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); - try { - const parseResult = OAuthTokenRevocationRequestSchema.safeParse(req.body); - if (!parseResult.success) { - throw new InvalidRequestError(parseResult.error.message); - } + try { + const parseResult = OAuthTokenRevocationRequestSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } - const client = req.client; - if (!client) { - // This should never happen - throw new ServerError("Internal Server Error"); - } + const client = req.client; + if (!client) { + // This should never happen + throw new ServerError('Internal Server Error'); + } - await provider.revokeToken!(client, parseResult.data); - res.status(200).json({}); - } catch (error) { - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError("Internal Server Error"); - res.status(500).json(serverError.toResponseObject()); - } - } - }); + await provider.revokeToken!(client, parseResult.data); + res.status(200).json({}); + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }); - return router; + return router; } diff --git a/src/server/auth/handlers/token.test.ts b/src/server/auth/handlers/token.test.ts index 946cc6910..e0338f030 100644 --- a/src/server/auth/handlers/token.test.ts +++ b/src/server/auth/handlers/token.test.ts @@ -11,9 +11,9 @@ import { ProxyOAuthServerProvider } from '../providers/proxyProvider.js'; // Mock pkce-challenge jest.mock('pkce-challenge', () => ({ - verifyChallenge: jest.fn().mockImplementation(async (verifier, challenge) => { - return verifier === 'valid_verifier' && challenge === 'mock_challenge'; - }) + verifyChallenge: jest.fn().mockImplementation(async (verifier, challenge) => { + return verifier === 'valid_verifier' && challenge === 'mock_challenge'; + }) })); const mockTokens = { @@ -26,506 +26,453 @@ const mockTokens = { const mockTokensWithIdToken = { ...mockTokens, id_token: 'mock_id_token' -} +}; describe('Token Handler', () => { - // Mock client data - const validClient: OAuthClientInformationFull = { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; - - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return validClient; - } - return undefined; - } - }; - - // Mock provider - let mockProvider: OAuthServerProvider; - let app: express.Express; - - beforeEach(() => { - // Create fresh mocks for each test - mockProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); - }, - - async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { - return 'mock_challenge'; - } else if (authorizationCode === 'expired_code') { - throw new InvalidGrantError('The authorization code has expired'); - } - throw new InvalidGrantError('The authorization code is invalid'); - }, + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; - async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { - return mockTokens; - } - throw new InvalidGrantError('The authorization code is invalid or has expired'); - }, - - async exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise { - if (refreshToken === 'valid_refresh_token') { - const response: OAuthTokens = { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - - if (scopes) { - response.scope = scopes.join(' '); - } - - return response; - } - throw new InvalidGrantError('The refresh token is invalid or has expired'); - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return validClient; + } + return undefined; } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Do nothing in mock - } }; - // Mock PKCE verification - (pkceChallenge.verifyChallenge as jest.Mock).mockImplementation( - async (verifier: string, challenge: string) => { - return verifier === 'valid_verifier' && challenge === 'mock_challenge'; - } - ); - - // Setup express app with token handler - app = express(); - const options: TokenHandlerOptions = { provider: mockProvider }; - app.use('/token', tokenHandler(options)); - }); - - describe('Basic request validation', () => { - it('requires POST method', async () => { - const response = await supertest(app) - .get('/token') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code' - }); + // Mock provider + let mockProvider: OAuthServerProvider; + let app: express.Express; - expect(response.status).toBe(405); - expect(response.headers.allow).toBe('POST'); - expect(response.body).toEqual({ - error: "method_not_allowed", - error_description: "The method GET is not allowed for this endpoint" - }); - }); + beforeEach(() => { + // Create fresh mocks for each test + mockProvider = { + clientsStore: mockClientStore, - it('requires grant_type parameter', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - // Missing grant_type - }); + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + res.redirect('https://example.com/callback?code=mock_auth_code'); + }, - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); + async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { + if (authorizationCode === 'valid_code') { + return 'mock_challenge'; + } else if (authorizationCode === 'expired_code') { + throw new InvalidGrantError('The authorization code has expired'); + } + throw new InvalidGrantError('The authorization code is invalid'); + }, - it('rejects unsupported grant types', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'password' // Unsupported grant type - }); + async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { + if (authorizationCode === 'valid_code') { + return mockTokens; + } + throw new InvalidGrantError('The authorization code is invalid or has expired'); + }, - expect(response.status).toBe(400); - expect(response.body.error).toBe('unsupported_grant_type'); - }); - }); - - describe('Client authentication', () => { - it('requires valid client credentials', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'invalid-client', - client_secret: 'wrong-secret', - grant_type: 'authorization_code' - }); + async exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise { + if (refreshToken === 'valid_refresh_token') { + const response: OAuthTokens = { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + + if (scopes) { + response.scope = scopes.join(' '); + } + + return response; + } + throw new InvalidGrantError('The refresh token is invalid or has expired'); + }, - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - }); + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, - it('accepts valid client credentials', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // Do nothing in mock + } + }; - expect(response.status).toBe(200); - }); - }); - - describe('Authorization code grant', () => { - it('requires code parameter', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - // Missing code - code_verifier: 'valid_verifier' + // Mock PKCE verification + (pkceChallenge.verifyChallenge as jest.Mock).mockImplementation(async (verifier: string, challenge: string) => { + return verifier === 'valid_verifier' && challenge === 'mock_challenge'; }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); + // Setup express app with token handler + app = express(); + const options: TokenHandlerOptions = { provider: mockProvider }; + app.use('/token', tokenHandler(options)); }); - it('requires code_verifier parameter', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code' - // Missing code_verifier + describe('Basic request validation', () => { + it('requires POST method', async () => { + const response = await supertest(app).get('/token').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code' + }); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('POST'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method GET is not allowed for this endpoint' + }); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); + it('requires grant_type parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret' + // Missing grant_type + }); - it('verifies code_verifier against challenge', async () => { - // Setup invalid verifier - (pkceChallenge.verifyChallenge as jest.Mock).mockResolvedValueOnce(false); - - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'invalid_verifier' + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - expect(response.body.error_description).toContain('code_verifier'); - }); + it('rejects unsupported grant types', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'password' // Unsupported grant type + }); - it('rejects expired or invalid authorization codes', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'expired_code', - code_verifier: 'valid_verifier' + expect(response.status).toBe(400); + expect(response.body.error).toBe('unsupported_grant_type'); }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); }); - it('returns tokens for valid code exchange', async () => { - const mockExchangeCode = jest.spyOn(mockProvider, 'exchangeAuthorizationCode'); - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - resource: 'https://api.example.com/resource', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' + describe('Client authentication', () => { + it('requires valid client credentials', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'invalid-client', + client_secret: 'wrong-secret', + grant_type: 'authorization_code' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); }); - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - expect(response.body.token_type).toBe('bearer'); - expect(response.body.expires_in).toBe(3600); - expect(response.body.refresh_token).toBe('mock_refresh_token'); - expect(mockExchangeCode).toHaveBeenCalledWith( - validClient, - 'valid_code', - undefined, // code_verifier is undefined after PKCE validation - undefined, // redirect_uri - new URL('https://api.example.com/resource') // resource parameter - ); + it('accepts valid client credentials', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + }); }); - it('returns id token in code exchange if provided', async () => { - mockProvider.exchangeAuthorizationCode = async (client: OAuthClientInformationFull, authorizationCode: string): Promise => { - if (authorizationCode === 'valid_code') { - return mockTokensWithIdToken; - } - throw new InvalidGrantError('The authorization code is invalid or has expired'); - }; - - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' - }); - - expect(response.status).toBe(200); - expect(response.body.id_token).toBe('mock_id_token'); - }); - - it('passes through code verifier when using proxy provider', async () => { - const originalFetch = global.fetch; - - try { - global.fetch = jest.fn().mockResolvedValue({ - ok: true, - json: () => Promise.resolve(mockTokens) + describe('Authorization code grant', () => { + it('requires code parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + // Missing code + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); }); - const proxyProvider = new ProxyOAuthServerProvider({ - endpoints: { - authorizationUrl: 'https://example.com/authorize', - tokenUrl: 'https://example.com/token' - }, - verifyAccessToken: async (token) => ({ - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }), - getClient: async (clientId) => clientId === 'valid-client' ? validClient : undefined + it('requires code_verifier parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code' + // Missing code_verifier + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); }); - const proxyApp = express(); - const options: TokenHandlerOptions = { provider: proxyProvider }; - proxyApp.use('/token', tokenHandler(options)); - - const response = await supertest(proxyApp) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'any_verifier', - redirect_uri: 'https://example.com/callback' - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - }, - body: expect.stringContaining('code_verifier=any_verifier') - }) - ); - } finally { - global.fetch = originalFetch; - } - }); + it('verifies code_verifier against challenge', async () => { + // Setup invalid verifier + (pkceChallenge.verifyChallenge as jest.Mock).mockResolvedValueOnce(false); + + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'invalid_verifier' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_grant'); + expect(response.body.error_description).toContain('code_verifier'); + }); - it('passes through redirect_uri when using proxy provider', async () => { - const originalFetch = global.fetch; + it('rejects expired or invalid authorization codes', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'expired_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_grant'); + }); - try { - global.fetch = jest.fn().mockResolvedValue({ - ok: true, - json: () => Promise.resolve(mockTokens) + it('returns tokens for valid code exchange', async () => { + const mockExchangeCode = jest.spyOn(mockProvider, 'exchangeAuthorizationCode'); + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + resource: 'https://api.example.com/resource', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + expect(response.body.token_type).toBe('bearer'); + expect(response.body.expires_in).toBe(3600); + expect(response.body.refresh_token).toBe('mock_refresh_token'); + expect(mockExchangeCode).toHaveBeenCalledWith( + validClient, + 'valid_code', + undefined, // code_verifier is undefined after PKCE validation + undefined, // redirect_uri + new URL('https://api.example.com/resource') // resource parameter + ); }); - const proxyProvider = new ProxyOAuthServerProvider({ - endpoints: { - authorizationUrl: 'https://example.com/authorize', - tokenUrl: 'https://example.com/token' - }, - verifyAccessToken: async (token) => ({ - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }), - getClient: async (clientId) => clientId === 'valid-client' ? validClient : undefined + it('returns id token in code exchange if provided', async () => { + mockProvider.exchangeAuthorizationCode = async ( + client: OAuthClientInformationFull, + authorizationCode: string + ): Promise => { + if (authorizationCode === 'valid_code') { + return mockTokensWithIdToken; + } + throw new InvalidGrantError('The authorization code is invalid or has expired'); + }; + + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + expect(response.body.id_token).toBe('mock_id_token'); }); - const proxyApp = express(); - const options: TokenHandlerOptions = { provider: proxyProvider }; - proxyApp.use('/token', tokenHandler(options)); - - const redirectUri = 'https://example.com/callback'; - const response = await supertest(proxyApp) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'any_verifier', - redirect_uri: redirectUri - }); - - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('mock_access_token'); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - }, - body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) - }) - ); - } finally { - global.fetch = originalFetch; - } - }); - }); - - describe('Refresh token grant', () => { - it('requires refresh_token parameter', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token' - // Missing refresh_token + it('passes through code verifier when using proxy provider', async () => { + const originalFetch = global.fetch; + + try { + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve(mockTokens) + }); + + const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://example.com/authorize', + tokenUrl: 'https://example.com/token' + }, + verifyAccessToken: async token => ({ + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }), + getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) + }); + + const proxyApp = express(); + const options: TokenHandlerOptions = { provider: proxyProvider }; + proxyApp.use('/token', tokenHandler(options)); + + const response = await supertest(proxyApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'any_verifier', + redirect_uri: 'https://example.com/callback' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('code_verifier=any_verifier') + }) + ); + } finally { + global.fetch = originalFetch; + } }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); + it('passes through redirect_uri when using proxy provider', async () => { + const originalFetch = global.fetch; + + try { + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve(mockTokens) + }); + + const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://example.com/authorize', + tokenUrl: 'https://example.com/token' + }, + verifyAccessToken: async token => ({ + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }), + getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) + }); + + const proxyApp = express(); + const options: TokenHandlerOptions = { provider: proxyProvider }; + proxyApp.use('/token', tokenHandler(options)); + + const redirectUri = 'https://example.com/callback'; + const response = await supertest(proxyApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'any_verifier', + redirect_uri: redirectUri + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) + }) + ); + } finally { + global.fetch = originalFetch; + } + }); }); - it('rejects invalid refresh tokens', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token', - refresh_token: 'invalid_refresh_token' + describe('Refresh token grant', () => { + it('requires refresh_token parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'refresh_token' + // Missing refresh_token + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); }); - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_grant'); - }); + it('rejects invalid refresh tokens', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'refresh_token', + refresh_token: 'invalid_refresh_token' + }); - it('returns new tokens for valid refresh token', async () => { - const mockExchangeRefresh = jest.spyOn(mockProvider, 'exchangeRefreshToken'); - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - resource: 'https://api.example.com/resource', - grant_type: 'refresh_token', - refresh_token: 'valid_refresh_token' + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_grant'); }); - expect(response.status).toBe(200); - expect(response.body.access_token).toBe('new_mock_access_token'); - expect(response.body.token_type).toBe('bearer'); - expect(response.body.expires_in).toBe(3600); - expect(response.body.refresh_token).toBe('new_mock_refresh_token'); - expect(mockExchangeRefresh).toHaveBeenCalledWith( - validClient, - 'valid_refresh_token', - undefined, // scopes - new URL('https://api.example.com/resource') // resource parameter - ); - }); - - it('respects requested scopes on refresh', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'refresh_token', - refresh_token: 'valid_refresh_token', - scope: 'profile email' + it('returns new tokens for valid refresh token', async () => { + const mockExchangeRefresh = jest.spyOn(mockProvider, 'exchangeRefreshToken'); + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + resource: 'https://api.example.com/resource', + grant_type: 'refresh_token', + refresh_token: 'valid_refresh_token' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('new_mock_access_token'); + expect(response.body.token_type).toBe('bearer'); + expect(response.body.expires_in).toBe(3600); + expect(response.body.refresh_token).toBe('new_mock_refresh_token'); + expect(mockExchangeRefresh).toHaveBeenCalledWith( + validClient, + 'valid_refresh_token', + undefined, // scopes + new URL('https://api.example.com/resource') // resource parameter + ); }); - expect(response.status).toBe(200); - expect(response.body.scope).toBe('profile email'); - }); - }); - - describe('CORS support', () => { - it('includes CORS headers in response', async () => { - const response = await supertest(app) - .post('/token') - .type('form') - .set('Origin', 'https://example.com') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' + it('respects requested scopes on refresh', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'refresh_token', + refresh_token: 'valid_refresh_token', + scope: 'profile email' + }); + + expect(response.status).toBe(200); + expect(response.body.scope).toBe('profile email'); }); + }); - expect(response.header['access-control-allow-origin']).toBe('*'); + describe('CORS support', () => { + it('includes CORS headers in response', async () => { + const response = await supertest(app).post('/token').type('form').set('Origin', 'https://example.com').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.header['access-control-allow-origin']).toBe('*'); + }); }); - }); -}); \ No newline at end of file +}); diff --git a/src/server/auth/handlers/token.ts b/src/server/auth/handlers/token.ts index b2ab74391..c387ff7bf 100644 --- a/src/server/auth/handlers/token.ts +++ b/src/server/auth/handlers/token.ts @@ -1,152 +1,157 @@ -import { z } from "zod"; -import express, { RequestHandler } from "express"; -import { OAuthServerProvider } from "../provider.js"; -import cors from "cors"; -import { verifyChallenge } from "pkce-challenge"; -import { authenticateClient } from "../middleware/clientAuth.js"; -import { rateLimit, Options as RateLimitOptions } from "express-rate-limit"; -import { allowedMethods } from "../middleware/allowedMethods.js"; +import { z } from 'zod'; +import express, { RequestHandler } from 'express'; +import { OAuthServerProvider } from '../provider.js'; +import cors from 'cors'; +import { verifyChallenge } from 'pkce-challenge'; +import { authenticateClient } from '../middleware/clientAuth.js'; +import { rateLimit, Options as RateLimitOptions } from 'express-rate-limit'; +import { allowedMethods } from '../middleware/allowedMethods.js'; import { - InvalidRequestError, - InvalidGrantError, - UnsupportedGrantTypeError, - ServerError, - TooManyRequestsError, - OAuthError -} from "../errors.js"; + InvalidRequestError, + InvalidGrantError, + UnsupportedGrantTypeError, + ServerError, + TooManyRequestsError, + OAuthError +} from '../errors.js'; export type TokenHandlerOptions = { - provider: OAuthServerProvider; - /** - * Rate limiting configuration for the token endpoint. - * Set to false to disable rate limiting for this endpoint. - */ - rateLimit?: Partial | false; + provider: OAuthServerProvider; + /** + * Rate limiting configuration for the token endpoint. + * Set to false to disable rate limiting for this endpoint. + */ + rateLimit?: Partial | false; }; const TokenRequestSchema = z.object({ - grant_type: z.string(), + grant_type: z.string() }); const AuthorizationCodeGrantSchema = z.object({ - code: z.string(), - code_verifier: z.string(), - redirect_uri: z.string().optional(), - resource: z.string().url().optional(), + code: z.string(), + code_verifier: z.string(), + redirect_uri: z.string().optional(), + resource: z.string().url().optional() }); const RefreshTokenGrantSchema = z.object({ - refresh_token: z.string(), - scope: z.string().optional(), - resource: z.string().url().optional(), + refresh_token: z.string(), + scope: z.string().optional(), + resource: z.string().url().optional() }); export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): RequestHandler { - // Nested router so we can configure middleware and restrict HTTP method - const router = express.Router(); - - // Configure CORS to allow any origin, to make accessible to web-based MCP clients - router.use(cors()); - - router.use(allowedMethods(["POST"])); - router.use(express.urlencoded({ extended: false })); - - // Apply rate limiting unless explicitly disabled - if (rateLimitConfig !== false) { - router.use(rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes - max: 50, // 50 requests per windowMs - standardHeaders: true, - legacyHeaders: false, - message: new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), - ...rateLimitConfig - })); - } - - // Authenticate and extract client details - router.use(authenticateClient({ clientsStore: provider.clientsStore })); - - router.post("/", async (req, res) => { - res.setHeader('Cache-Control', 'no-store'); - - try { - const parseResult = TokenRequestSchema.safeParse(req.body); - if (!parseResult.success) { - throw new InvalidRequestError(parseResult.error.message); - } - - const { grant_type } = parseResult.data; - - const client = req.client; - if (!client) { - // This should never happen - throw new ServerError("Internal Server Error"); - } - - switch (grant_type) { - case "authorization_code": { - const parseResult = AuthorizationCodeGrantSchema.safeParse(req.body); - if (!parseResult.success) { - throw new InvalidRequestError(parseResult.error.message); - } - - const { code, code_verifier, redirect_uri, resource } = parseResult.data; - - const skipLocalPkceValidation = provider.skipLocalPkceValidation; - - // Perform local PKCE validation unless explicitly skipped - // (e.g. to validate code_verifier in upstream server) - if (!skipLocalPkceValidation) { - const codeChallenge = await provider.challengeForAuthorizationCode(client, code); - if (!(await verifyChallenge(code_verifier, codeChallenge))) { - throw new InvalidGrantError("code_verifier does not match the challenge"); + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); + + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); + + router.use(allowedMethods(['POST'])); + router.use(express.urlencoded({ extended: false })); + + // Apply rate limiting unless explicitly disabled + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 15 * 60 * 1000, // 15 minutes + max: 50, // 50 requests per windowMs + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } + + // Authenticate and extract client details + router.use(authenticateClient({ clientsStore: provider.clientsStore })); + + router.post('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + try { + const parseResult = TokenRequestSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); } - } - - // Passes the code_verifier to the provider if PKCE validation didn't occur locally - const tokens = await provider.exchangeAuthorizationCode( - client, - code, - skipLocalPkceValidation ? code_verifier : undefined, - redirect_uri, - resource ? new URL(resource) : undefined - ); - res.status(200).json(tokens); - break; - } - case "refresh_token": { - const parseResult = RefreshTokenGrantSchema.safeParse(req.body); - if (!parseResult.success) { - throw new InvalidRequestError(parseResult.error.message); - } + const { grant_type } = parseResult.data; - const { refresh_token, scope, resource } = parseResult.data; + const client = req.client; + if (!client) { + // This should never happen + throw new ServerError('Internal Server Error'); + } - const scopes = scope?.split(" "); - const tokens = await provider.exchangeRefreshToken(client, refresh_token, scopes, resource ? new URL(resource) : undefined); - res.status(200).json(tokens); - break; + switch (grant_type) { + case 'authorization_code': { + const parseResult = AuthorizationCodeGrantSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { code, code_verifier, redirect_uri, resource } = parseResult.data; + + const skipLocalPkceValidation = provider.skipLocalPkceValidation; + + // Perform local PKCE validation unless explicitly skipped + // (e.g. to validate code_verifier in upstream server) + if (!skipLocalPkceValidation) { + const codeChallenge = await provider.challengeForAuthorizationCode(client, code); + if (!(await verifyChallenge(code_verifier, codeChallenge))) { + throw new InvalidGrantError('code_verifier does not match the challenge'); + } + } + + // Passes the code_verifier to the provider if PKCE validation didn't occur locally + const tokens = await provider.exchangeAuthorizationCode( + client, + code, + skipLocalPkceValidation ? code_verifier : undefined, + redirect_uri, + resource ? new URL(resource) : undefined + ); + res.status(200).json(tokens); + break; + } + + case 'refresh_token': { + const parseResult = RefreshTokenGrantSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { refresh_token, scope, resource } = parseResult.data; + + const scopes = scope?.split(' '); + const tokens = await provider.exchangeRefreshToken( + client, + refresh_token, + scopes, + resource ? new URL(resource) : undefined + ); + res.status(200).json(tokens); + break; + } + + // Not supported right now + //case "client_credentials": + + default: + throw new UnsupportedGrantTypeError('The grant type is not supported by this authorization server.'); + } + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } } + }); - // Not supported right now - //case "client_credentials": - - default: - throw new UnsupportedGrantTypeError( - "The grant type is not supported by this authorization server." - ); - } - } catch (error) { - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError("Internal Server Error"); - res.status(500).json(serverError.toResponseObject()); - } - } - }); - - return router; -} \ No newline at end of file + return router; +} diff --git a/src/server/auth/middleware/allowedMethods.test.ts b/src/server/auth/middleware/allowedMethods.test.ts index 61f8c8018..1f30fea85 100644 --- a/src/server/auth/middleware/allowedMethods.test.ts +++ b/src/server/auth/middleware/allowedMethods.test.ts @@ -1,75 +1,75 @@ -import { allowedMethods } from "./allowedMethods.js"; -import express, { Request, Response } from "express"; -import request from "supertest"; +import { allowedMethods } from './allowedMethods.js'; +import express, { Request, Response } from 'express'; +import request from 'supertest'; -describe("allowedMethods", () => { - let app: express.Express; +describe('allowedMethods', () => { + let app: express.Express; - beforeEach(() => { - app = express(); + beforeEach(() => { + app = express(); - // Set up a test router with a GET handler and 405 middleware - const router = express.Router(); + // Set up a test router with a GET handler and 405 middleware + const router = express.Router(); - router.get("/test", (req, res) => { - res.status(200).send("GET success"); + router.get('/test', (req, res) => { + res.status(200).send('GET success'); + }); + + // Add method not allowed middleware for all other methods + router.all('/test', allowedMethods(['GET'])); + + app.use(router); + }); + + test('allows specified HTTP method', async () => { + const response = await request(app).get('/test'); + expect(response.status).toBe(200); + expect(response.text).toBe('GET success'); }); - // Add method not allowed middleware for all other methods - router.all("/test", allowedMethods(["GET"])); - - app.use(router); - }); - - test("allows specified HTTP method", async () => { - const response = await request(app).get("/test"); - expect(response.status).toBe(200); - expect(response.text).toBe("GET success"); - }); - - test("returns 405 for unspecified HTTP methods", async () => { - const methods = ["post", "put", "delete", "patch"]; - - for (const method of methods) { - // @ts-expect-error - dynamic method call - const response = await request(app)[method]("/test"); - expect(response.status).toBe(405); - expect(response.body).toEqual({ - error: "method_not_allowed", - error_description: `The method ${method.toUpperCase()} is not allowed for this endpoint` - }); - } - }); - - test("includes Allow header with specified methods", async () => { - const response = await request(app).post("/test"); - expect(response.headers.allow).toBe("GET"); - }); - - test("works with multiple allowed methods", async () => { - const multiMethodApp = express(); - const router = express.Router(); - - router.get("/multi", (req: Request, res: Response) => { - res.status(200).send("GET"); + test('returns 405 for unspecified HTTP methods', async () => { + const methods = ['post', 'put', 'delete', 'patch']; + + for (const method of methods) { + // @ts-expect-error - dynamic method call + const response = await request(app)[method]('/test'); + expect(response.status).toBe(405); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: `The method ${method.toUpperCase()} is not allowed for this endpoint` + }); + } }); - router.post("/multi", (req: Request, res: Response) => { - res.status(200).send("POST"); + + test('includes Allow header with specified methods', async () => { + const response = await request(app).post('/test'); + expect(response.headers.allow).toBe('GET'); }); - router.all("/multi", allowedMethods(["GET", "POST"])); - multiMethodApp.use(router); + test('works with multiple allowed methods', async () => { + const multiMethodApp = express(); + const router = express.Router(); + + router.get('/multi', (req: Request, res: Response) => { + res.status(200).send('GET'); + }); + router.post('/multi', (req: Request, res: Response) => { + res.status(200).send('POST'); + }); + router.all('/multi', allowedMethods(['GET', 'POST'])); - // Allowed methods should work - const getResponse = await request(multiMethodApp).get("/multi"); - expect(getResponse.status).toBe(200); + multiMethodApp.use(router); - const postResponse = await request(multiMethodApp).post("/multi"); - expect(postResponse.status).toBe(200); + // Allowed methods should work + const getResponse = await request(multiMethodApp).get('/multi'); + expect(getResponse.status).toBe(200); - // Unallowed methods should return 405 - const putResponse = await request(multiMethodApp).put("/multi"); - expect(putResponse.status).toBe(405); - expect(putResponse.headers.allow).toBe("GET, POST"); - }); -}); \ No newline at end of file + const postResponse = await request(multiMethodApp).post('/multi'); + expect(postResponse.status).toBe(200); + + // Unallowed methods should return 405 + const putResponse = await request(multiMethodApp).put('/multi'); + expect(putResponse.status).toBe(405); + expect(putResponse.headers.allow).toBe('GET, POST'); + }); +}); diff --git a/src/server/auth/middleware/allowedMethods.ts b/src/server/auth/middleware/allowedMethods.ts index cd80c7c21..74633aa57 100644 --- a/src/server/auth/middleware/allowedMethods.ts +++ b/src/server/auth/middleware/allowedMethods.ts @@ -1,22 +1,20 @@ -import { RequestHandler } from "express"; -import { MethodNotAllowedError } from "../errors.js"; +import { RequestHandler } from 'express'; +import { MethodNotAllowedError } from '../errors.js'; /** * Middleware to handle unsupported HTTP methods with a 405 Method Not Allowed response. - * + * * @param allowedMethods Array of allowed HTTP methods for this endpoint (e.g., ['GET', 'POST']) * @returns Express middleware that returns a 405 error if method not in allowed list */ export function allowedMethods(allowedMethods: string[]): RequestHandler { - return (req, res, next) => { - if (allowedMethods.includes(req.method)) { - next(); - return; - } + return (req, res, next) => { + if (allowedMethods.includes(req.method)) { + next(); + return; + } - const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); - res.status(405) - .set('Allow', allowedMethods.join(', ')) - .json(error.toResponseObject()); - }; -} \ No newline at end of file + const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); + res.status(405).set('Allow', allowedMethods.join(', ')).json(error.toResponseObject()); + }; +} diff --git a/src/server/auth/middleware/bearerAuth.test.ts b/src/server/auth/middleware/bearerAuth.test.ts index 38639b1de..5790a0eb0 100644 --- a/src/server/auth/middleware/bearerAuth.test.ts +++ b/src/server/auth/middleware/bearerAuth.test.ts @@ -1,459 +1,438 @@ -import { Request, Response } from "express"; -import { requireBearerAuth } from "./bearerAuth.js"; -import { AuthInfo } from "../types.js"; -import { InsufficientScopeError, InvalidTokenError, CustomOAuthError, ServerError } from "../errors.js"; -import { OAuthTokenVerifier } from "../provider.js"; +import { Request, Response } from 'express'; +import { requireBearerAuth } from './bearerAuth.js'; +import { AuthInfo } from '../types.js'; +import { InsufficientScopeError, InvalidTokenError, CustomOAuthError, ServerError } from '../errors.js'; +import { OAuthTokenVerifier } from '../provider.js'; // Mock verifier const mockVerifyAccessToken = jest.fn(); const mockVerifier: OAuthTokenVerifier = { - verifyAccessToken: mockVerifyAccessToken, + verifyAccessToken: mockVerifyAccessToken }; -describe("requireBearerAuth middleware", () => { - let mockRequest: Partial; - let mockResponse: Partial; - let nextFunction: jest.Mock; - - beforeEach(() => { - mockRequest = { - headers: {}, - }; - mockResponse = { - status: jest.fn().mockReturnThis(), - json: jest.fn(), - set: jest.fn().mockReturnThis(), - }; - nextFunction = jest.fn(); - jest.spyOn(console, 'error').mockImplementation(() => {}); - }) - - afterEach(() => { - jest.clearAllMocks(); - }); - - it("should call next when token is valid", async () => { - const validAuthInfo: AuthInfo = { - token: "valid-token", - clientId: "client-123", - scopes: ["read", "write"], - expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(validAuthInfo); - - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockRequest.auth).toEqual(validAuthInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it.each([ - [100], // Token expired 100 seconds ago - [0], // Token expires at the same time as now - ])("should reject expired tokens (expired %s seconds ago)", async (expiredSecondsAgo: number) => { - const expiresAt = Math.floor(Date.now() / 1000) - expiredSecondsAgo; - const expiredAuthInfo: AuthInfo = { - token: "expired-token", - clientId: "client-123", - scopes: ["read", "write"], - expiresAt - }; - mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); - - mockRequest.headers = { - authorization: "Bearer expired-token", - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token"); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="invalid_token"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "invalid_token", error_description: "Token has expired" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it.each([ - [undefined], // Token has no expiration time - [NaN], // Token has no expiration time - ])("should reject tokens with no expiration time (expiresAt: %s)", async (expiresAt: number | undefined) => { - const noExpirationAuthInfo: AuthInfo = { - token: "no-expiration-token", - clientId: "client-123", - scopes: ["read", "write"], - expiresAt - }; - mockVerifyAccessToken.mockResolvedValue(noExpirationAuthInfo); - - mockRequest.headers = { - authorization: "Bearer expired-token", - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token"); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="invalid_token"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "invalid_token", error_description: "Token has no expiration time" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should accept non-expired tokens", async () => { - const nonExpiredAuthInfo: AuthInfo = { - token: "valid-token", - clientId: "client-123", - scopes: ["read", "write"], - expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo); - - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockRequest.auth).toEqual(nonExpiredAuthInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it("should require specific scopes when configured", async () => { - const authInfo: AuthInfo = { - token: "valid-token", - clientId: "client-123", - scopes: ["read"], - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ["read", "write"] +describe('requireBearerAuth middleware', () => { + let mockRequest: Partial; + let mockResponse: Partial; + let nextFunction: jest.Mock; + + beforeEach(() => { + mockRequest = { + headers: {} + }; + mockResponse = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + set: jest.fn().mockReturnThis() + }; + nextFunction = jest.fn(); + jest.spyOn(console, 'error').mockImplementation(() => {}); }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="insufficient_scope"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "insufficient_scope", error_description: "Insufficient scope" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should accept token with all required scopes", async () => { - const authInfo: AuthInfo = { - token: "valid-token", - clientId: "client-123", - scopes: ["read", "write", "admin"], - expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ["read", "write"] + afterEach(() => { + jest.clearAllMocks(); }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockRequest.auth).toEqual(authInfo); - expect(nextFunction).toHaveBeenCalled(); - expect(mockResponse.status).not.toHaveBeenCalled(); - expect(mockResponse.json).not.toHaveBeenCalled(); - }); - - it("should return 401 when no Authorization header is present", async () => { - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).not.toHaveBeenCalled(); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="invalid_token"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "invalid_token", error_description: "Missing Authorization header" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should return 401 when Authorization header format is invalid", async () => { - mockRequest.headers = { - authorization: "InvalidFormat", - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).not.toHaveBeenCalled(); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="invalid_token"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ - error: "invalid_token", - error_description: "Invalid Authorization header format, expected 'Bearer TOKEN'" - }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should return 401 when token verification fails with InvalidTokenError", async () => { - mockRequest.headers = { - authorization: "Bearer invalid-token", - }; - - mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError("Token expired")); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("invalid-token"); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="invalid_token"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "invalid_token", error_description: "Token expired" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should return 403 when access token has insufficient scopes", async () => { - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError("Required scopes: read, write")); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - expect.stringContaining('Bearer error="insufficient_scope"') - ); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "insufficient_scope", error_description: "Required scopes: read, write" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should return 500 when a ServerError occurs", async () => { - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - mockVerifyAccessToken.mockRejectedValue(new ServerError("Internal server issue")); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "server_error", error_description: "Internal server issue" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should return 400 for generic OAuthError", async () => { - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError("custom_error", "Some OAuth error")); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockResponse.status).toHaveBeenCalledWith(400); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "custom_error", error_description: "Some OAuth error" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - it("should return 500 when unexpected error occurs", async () => { - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - mockVerifyAccessToken.mockRejectedValue(new Error("Unexpected error")); - - const middleware = requireBearerAuth({ verifier: mockVerifier }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.json).toHaveBeenCalledWith( - expect.objectContaining({ error: "server_error", error_description: "Internal Server Error" }) - ); - expect(nextFunction).not.toHaveBeenCalled(); - }); - - describe("with resourceMetadataUrl", () => { - const resourceMetadataUrl = "https://api.example.com/.well-known/oauth-protected-resource"; - - it("should include resource_metadata in WWW-Authenticate header for 401 responses", async () => { - mockRequest.headers = {}; - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - `Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); + it('should call next when token is valid', async () => { + const validAuthInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(validAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockRequest.auth).toEqual(validAuthInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); }); - it("should include resource_metadata in WWW-Authenticate header when token verification fails", async () => { - mockRequest.headers = { - authorization: "Bearer invalid-token", - }; + it.each([ + [100], // Token expired 100 seconds ago + [0] // Token expires at the same time as now + ])('should reject expired tokens (expired %s seconds ago)', async (expiredSecondsAgo: number) => { + const expiresAt = Math.floor(Date.now() / 1000) - expiredSecondsAgo; + const expiredAuthInfo: AuthInfo = { + token: 'expired-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt + }; + mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer expired-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Token has expired' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it.each([ + [undefined], // Token has no expiration time + [NaN] // Token has no expiration time + ])('should reject tokens with no expiration time (expiresAt: %s)', async (expiresAt: number | undefined) => { + const noExpirationAuthInfo: AuthInfo = { + token: 'no-expiration-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt + }; + mockVerifyAccessToken.mockResolvedValue(noExpirationAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer expired-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Token has no expiration time' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should accept non-expired tokens', async () => { + const nonExpiredAuthInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockRequest.auth).toEqual(nonExpiredAuthInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); + }); + + it('should require specific scopes when configured', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read'] + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'] + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'insufficient_scope', error_description: 'Insufficient scope' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should accept token with all required scopes', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read', 'write', 'admin'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'] + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockRequest.auth).toEqual(authInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); + }); + + it('should return 401 when no Authorization header is present', async () => { + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).not.toHaveBeenCalled(); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Missing Authorization header' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); - mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError("Token expired")); + it('should return 401 when Authorization header format is invalid', async () => { + mockRequest.headers = { + authorization: 'InvalidFormat' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).not.toHaveBeenCalled(); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: 'invalid_token', + error_description: "Invalid Authorization header format, expected 'Bearer TOKEN'" + }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 401 when token verification fails with InvalidTokenError', async () => { + mockRequest.headers = { + authorization: 'Bearer invalid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - `Bearer error="invalid_token", error_description="Token expired", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect(mockVerifyAccessToken).toHaveBeenCalledWith('invalid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Token expired' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); }); - it("should include resource_metadata in WWW-Authenticate header for insufficient scope errors", async () => { - mockRequest.headers = { - authorization: "Bearer valid-token", - }; + it('should return 403 when access token has insufficient scopes', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; - mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError("Required scopes: admin")); + mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: read, write')); - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - `Bearer error="insufficient_scope", error_description="Required scopes: admin", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'insufficient_scope', error_description: 'Required scopes: read, write' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); }); - it("should include resource_metadata when token is expired", async () => { - const expiredAuthInfo: AuthInfo = { - token: "expired-token", - clientId: "client-123", - scopes: ["read", "write"], - expiresAt: Math.floor(Date.now() / 1000) - 100, - }; - mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); - - mockRequest.headers = { - authorization: "Bearer expired-token", - }; - - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(401); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - `Bearer error="invalid_token", error_description="Token has expired", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); + it('should return 500 when a ServerError occurs', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(500); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'server_error', error_description: 'Internal server issue' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); }); - it("should include resource_metadata when scope check fails", async () => { - const authInfo: AuthInfo = { - token: "valid-token", - clientId: "client-123", - scopes: ["read"], - }; - mockVerifyAccessToken.mockResolvedValue(authInfo); - - mockRequest.headers = { - authorization: "Bearer valid-token", - }; - - const middleware = requireBearerAuth({ - verifier: mockVerifier, - requiredScopes: ["read", "write"], - resourceMetadataUrl - }); - - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); - - expect(mockResponse.status).toHaveBeenCalledWith(403); - expect(mockResponse.set).toHaveBeenCalledWith( - "WWW-Authenticate", - `Bearer error="insufficient_scope", error_description="Insufficient scope", resource_metadata="${resourceMetadataUrl}"` - ); - expect(nextFunction).not.toHaveBeenCalled(); + it('should return 400 for generic OAuthError', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError('custom_error', 'Some OAuth error')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(400); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'custom_error', error_description: 'Some OAuth error' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); }); - it("should not affect server errors (no WWW-Authenticate header)", async () => { - mockRequest.headers = { - authorization: "Bearer valid-token", - }; + it('should return 500 when unexpected error occurs', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; - mockVerifyAccessToken.mockRejectedValue(new ServerError("Internal server issue")); + mockVerifyAccessToken.mockRejectedValue(new Error('Unexpected error')); - const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); - await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(500); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'server_error', error_description: 'Internal Server Error' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); - expect(mockResponse.status).toHaveBeenCalledWith(500); - expect(mockResponse.set).not.toHaveBeenCalledWith("WWW-Authenticate", expect.anything()); - expect(nextFunction).not.toHaveBeenCalled(); + describe('with resourceMetadataUrl', () => { + const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; + + it('should include resource_metadata in WWW-Authenticate header for 401 responses', async () => { + mockRequest.headers = {}; + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata in WWW-Authenticate header when token verification fails', async () => { + mockRequest.headers = { + authorization: 'Bearer invalid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Token expired", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata in WWW-Authenticate header for insufficient scope errors', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: admin')); + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="insufficient_scope", error_description="Required scopes: admin", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata when token is expired', async () => { + const expiredAuthInfo: AuthInfo = { + token: 'expired-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt: Math.floor(Date.now() / 1000) - 100 + }; + mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer expired-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Token has expired", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata when scope check fails', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read'] + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'], + resourceMetadataUrl + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="insufficient_scope", error_description="Insufficient scope", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should not affect server errors (no WWW-Authenticate header)', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(500); + expect(mockResponse.set).not.toHaveBeenCalledWith('WWW-Authenticate', expect.anything()); + expect(nextFunction).not.toHaveBeenCalled(); + }); }); - }); }); diff --git a/src/server/auth/middleware/bearerAuth.ts b/src/server/auth/middleware/bearerAuth.ts index 7b6d8f61f..363fd7a42 100644 --- a/src/server/auth/middleware/bearerAuth.ts +++ b/src/server/auth/middleware/bearerAuth.ts @@ -1,32 +1,32 @@ -import { RequestHandler } from "express"; -import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from "../errors.js"; -import { OAuthTokenVerifier } from "../provider.js"; -import { AuthInfo } from "../types.js"; +import { RequestHandler } from 'express'; +import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from '../errors.js'; +import { OAuthTokenVerifier } from '../provider.js'; +import { AuthInfo } from '../types.js'; export type BearerAuthMiddlewareOptions = { - /** - * A provider used to verify tokens. - */ - verifier: OAuthTokenVerifier; - - /** - * Optional scopes that the token must have. - */ - requiredScopes?: string[]; + /** + * A provider used to verify tokens. + */ + verifier: OAuthTokenVerifier; - /** - * Optional resource metadata URL to include in WWW-Authenticate header. - */ - resourceMetadataUrl?: string; -}; + /** + * Optional scopes that the token must have. + */ + requiredScopes?: string[]; -declare module "express-serve-static-core" { - interface Request { /** - * Information about the validated access token, if the `requireBearerAuth` middleware was used. + * Optional resource metadata URL to include in WWW-Authenticate header. */ - auth?: AuthInfo; - } + resourceMetadataUrl?: string; +}; + +declare module 'express-serve-static-core' { + interface Request { + /** + * Information about the validated access token, if the `requireBearerAuth` middleware was used. + */ + auth?: AuthInfo; + } } /** @@ -38,61 +38,59 @@ declare module "express-serve-static-core" { * for 401 responses as per the OAuth 2.0 Protected Resource Metadata spec. */ export function requireBearerAuth({ verifier, requiredScopes = [], resourceMetadataUrl }: BearerAuthMiddlewareOptions): RequestHandler { - return async (req, res, next) => { - try { - const authHeader = req.headers.authorization; - if (!authHeader) { - throw new InvalidTokenError("Missing Authorization header"); - } + return async (req, res, next) => { + try { + const authHeader = req.headers.authorization; + if (!authHeader) { + throw new InvalidTokenError('Missing Authorization header'); + } - const [type, token] = authHeader.split(' '); - if (type.toLowerCase() !== 'bearer' || !token) { - throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); - } + const [type, token] = authHeader.split(' '); + if (type.toLowerCase() !== 'bearer' || !token) { + throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); + } - const authInfo = await verifier.verifyAccessToken(token); + const authInfo = await verifier.verifyAccessToken(token); - // Check if token has the required scopes (if any) - if (requiredScopes.length > 0) { - const hasAllScopes = requiredScopes.every(scope => - authInfo.scopes.includes(scope) - ); + // Check if token has the required scopes (if any) + if (requiredScopes.length > 0) { + const hasAllScopes = requiredScopes.every(scope => authInfo.scopes.includes(scope)); - if (!hasAllScopes) { - throw new InsufficientScopeError("Insufficient scope"); - } - } + if (!hasAllScopes) { + throw new InsufficientScopeError('Insufficient scope'); + } + } - // Check if the token is set to expire or if it is expired - if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { - throw new InvalidTokenError("Token has no expiration time"); - } else if (authInfo.expiresAt < Date.now() / 1000) { - throw new InvalidTokenError("Token has expired"); - } + // Check if the token is set to expire or if it is expired + if (typeof authInfo.expiresAt !== 'number' || isNaN(authInfo.expiresAt)) { + throw new InvalidTokenError('Token has no expiration time'); + } else if (authInfo.expiresAt < Date.now() / 1000) { + throw new InvalidTokenError('Token has expired'); + } - req.auth = authInfo; - next(); - } catch (error) { - if (error instanceof InvalidTokenError) { - const wwwAuthValue = resourceMetadataUrl - ? `Bearer error="${error.errorCode}", error_description="${error.message}", resource_metadata="${resourceMetadataUrl}"` - : `Bearer error="${error.errorCode}", error_description="${error.message}"`; - res.set("WWW-Authenticate", wwwAuthValue); - res.status(401).json(error.toResponseObject()); - } else if (error instanceof InsufficientScopeError) { - const wwwAuthValue = resourceMetadataUrl - ? `Bearer error="${error.errorCode}", error_description="${error.message}", resource_metadata="${resourceMetadataUrl}"` - : `Bearer error="${error.errorCode}", error_description="${error.message}"`; - res.set("WWW-Authenticate", wwwAuthValue); - res.status(403).json(error.toResponseObject()); - } else if (error instanceof ServerError) { - res.status(500).json(error.toResponseObject()); - } else if (error instanceof OAuthError) { - res.status(400).json(error.toResponseObject()); - } else { - const serverError = new ServerError("Internal Server Error"); - res.status(500).json(serverError.toResponseObject()); - } - } - }; + req.auth = authInfo; + next(); + } catch (error) { + if (error instanceof InvalidTokenError) { + const wwwAuthValue = resourceMetadataUrl + ? `Bearer error="${error.errorCode}", error_description="${error.message}", resource_metadata="${resourceMetadataUrl}"` + : `Bearer error="${error.errorCode}", error_description="${error.message}"`; + res.set('WWW-Authenticate', wwwAuthValue); + res.status(401).json(error.toResponseObject()); + } else if (error instanceof InsufficientScopeError) { + const wwwAuthValue = resourceMetadataUrl + ? `Bearer error="${error.errorCode}", error_description="${error.message}", resource_metadata="${resourceMetadataUrl}"` + : `Bearer error="${error.errorCode}", error_description="${error.message}"`; + res.set('WWW-Authenticate', wwwAuthValue); + res.status(403).json(error.toResponseObject()); + } else if (error instanceof ServerError) { + res.status(500).json(error.toResponseObject()); + } else if (error instanceof OAuthError) { + res.status(400).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }; } diff --git a/src/server/auth/middleware/clientAuth.test.ts b/src/server/auth/middleware/clientAuth.test.ts index 0cfe2247f..5ad6f301f 100644 --- a/src/server/auth/middleware/clientAuth.test.ts +++ b/src/server/auth/middleware/clientAuth.test.ts @@ -5,144 +5,128 @@ import express from 'express'; import supertest from 'supertest'; describe('clientAuth middleware', () => { - // Mock client store - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; - } else if (clientId === 'expired-client') { - // Client with no secret - return { - client_id: 'expired-client', - redirect_uris: ['https://example.com/callback'] - }; - } else if (clientId === 'client-with-expired-secret') { - // Client with an expired secret - return { - client_id: 'client-with-expired-secret', - client_secret: 'expired-secret', - client_secret_expires_at: Math.floor(Date.now() / 1000) - 3600, // Expired 1 hour ago - redirect_uris: ['https://example.com/callback'] - }; - } - return undefined; - } - }; - - // Setup Express app with middleware - let app: express.Express; - let options: ClientAuthenticationMiddlewareOptions; - - beforeEach(() => { - app = express(); - app.use(express.json()); - - options = { - clientsStore: mockClientStore + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + } else if (clientId === 'expired-client') { + // Client with no secret + return { + client_id: 'expired-client', + redirect_uris: ['https://example.com/callback'] + }; + } else if (clientId === 'client-with-expired-secret') { + // Client with an expired secret + return { + client_id: 'client-with-expired-secret', + client_secret: 'expired-secret', + client_secret_expires_at: Math.floor(Date.now() / 1000) - 3600, // Expired 1 hour ago + redirect_uris: ['https://example.com/callback'] + }; + } + return undefined; + } }; - // Setup route with client auth - app.post('/protected', authenticateClient(options), (req, res) => { - res.status(200).json({ success: true, client: req.client }); + // Setup Express app with middleware + let app: express.Express; + let options: ClientAuthenticationMiddlewareOptions; + + beforeEach(() => { + app = express(); + app.use(express.json()); + + options = { + clientsStore: mockClientStore + }; + + // Setup route with client auth + app.post('/protected', authenticateClient(options), (req, res) => { + res.status(200).json({ success: true, client: req.client }); + }); + }); + + it('authenticates valid client credentials', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'valid-client', + client_secret: 'valid-secret' + }); + + expect(response.status).toBe(200); + expect(response.body.success).toBe(true); + expect(response.body.client.client_id).toBe('valid-client'); + }); + + it('rejects invalid client_id', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'non-existent-client', + client_secret: 'some-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(response.body.error_description).toBe('Invalid client_id'); + }); + + it('rejects invalid client_secret', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'valid-client', + client_secret: 'wrong-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(response.body.error_description).toBe('Invalid client_secret'); + }); + + it('rejects missing client_id', async () => { + const response = await supertest(app).post('/protected').send({ + client_secret: 'valid-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('allows missing client_secret if client has none', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'expired-client' + }); + + // Since the client has no secret, this should pass without providing one + expect(response.status).toBe(200); + }); + + it('rejects request when client secret has expired', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'client-with-expired-secret', + client_secret: 'expired-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(response.body.error_description).toBe('Client secret has expired'); + }); + + it('handles malformed request body', async () => { + const response = await supertest(app).post('/protected').send('not-json-format'); + + expect(response.status).toBe(400); + }); + + // Testing request with extra fields to ensure they're ignored + it('ignores extra fields in request', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + extra_field: 'should be ignored' + }); + + expect(response.status).toBe(200); }); - }); - - it('authenticates valid client credentials', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret' - }); - - expect(response.status).toBe(200); - expect(response.body.success).toBe(true); - expect(response.body.client.client_id).toBe('valid-client'); - }); - - it('rejects invalid client_id', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_id: 'non-existent-client', - client_secret: 'some-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Invalid client_id'); - }); - - it('rejects invalid client_secret', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_id: 'valid-client', - client_secret: 'wrong-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Invalid client_secret'); - }); - - it('rejects missing client_id', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_secret: 'valid-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_request'); - }); - - it('allows missing client_secret if client has none', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_id: 'expired-client' - }); - - // Since the client has no secret, this should pass without providing one - expect(response.status).toBe(200); - }); - - it('rejects request when client secret has expired', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_id: 'client-with-expired-secret', - client_secret: 'expired-secret' - }); - - expect(response.status).toBe(400); - expect(response.body.error).toBe('invalid_client'); - expect(response.body.error_description).toBe('Client secret has expired'); - }); - - it('handles malformed request body', async () => { - const response = await supertest(app) - .post('/protected') - .send('not-json-format'); - - expect(response.status).toBe(400); - }); - - // Testing request with extra fields to ensure they're ignored - it('ignores extra fields in request', async () => { - const response = await supertest(app) - .post('/protected') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - extra_field: 'should be ignored' - }); - - expect(response.status).toBe(200); - }); -}); \ No newline at end of file +}); diff --git a/src/server/auth/middleware/clientAuth.ts b/src/server/auth/middleware/clientAuth.ts index ecd9a7b65..9969b8724 100644 --- a/src/server/auth/middleware/clientAuth.ts +++ b/src/server/auth/middleware/clientAuth.ts @@ -1,72 +1,72 @@ -import { z } from "zod"; -import { RequestHandler } from "express"; -import { OAuthRegisteredClientsStore } from "../clients.js"; -import { OAuthClientInformationFull } from "../../../shared/auth.js"; -import { InvalidRequestError, InvalidClientError, ServerError, OAuthError } from "../errors.js"; +import { z } from 'zod'; +import { RequestHandler } from 'express'; +import { OAuthRegisteredClientsStore } from '../clients.js'; +import { OAuthClientInformationFull } from '../../../shared/auth.js'; +import { InvalidRequestError, InvalidClientError, ServerError, OAuthError } from '../errors.js'; export type ClientAuthenticationMiddlewareOptions = { - /** - * A store used to read information about registered OAuth clients. - */ - clientsStore: OAuthRegisteredClientsStore; -} + /** + * A store used to read information about registered OAuth clients. + */ + clientsStore: OAuthRegisteredClientsStore; +}; const ClientAuthenticatedRequestSchema = z.object({ - client_id: z.string(), - client_secret: z.string().optional(), + client_id: z.string(), + client_secret: z.string().optional() }); -declare module "express-serve-static-core" { - interface Request { - /** - * The authenticated client for this request, if the `authenticateClient` middleware was used. - */ - client?: OAuthClientInformationFull; - } +declare module 'express-serve-static-core' { + interface Request { + /** + * The authenticated client for this request, if the `authenticateClient` middleware was used. + */ + client?: OAuthClientInformationFull; + } } export function authenticateClient({ clientsStore }: ClientAuthenticationMiddlewareOptions): RequestHandler { - return async (req, res, next) => { - try { - const result = ClientAuthenticatedRequestSchema.safeParse(req.body); - if (!result.success) { - throw new InvalidRequestError(String(result.error)); - } + return async (req, res, next) => { + try { + const result = ClientAuthenticatedRequestSchema.safeParse(req.body); + if (!result.success) { + throw new InvalidRequestError(String(result.error)); + } - const { client_id, client_secret } = result.data; - const client = await clientsStore.getClient(client_id); - if (!client) { - throw new InvalidClientError("Invalid client_id"); - } + const { client_id, client_secret } = result.data; + const client = await clientsStore.getClient(client_id); + if (!client) { + throw new InvalidClientError('Invalid client_id'); + } - // If client has a secret, validate it - if (client.client_secret) { - // Check if client_secret is required but not provided - if (!client_secret) { - throw new InvalidClientError("Client secret is required"); - } + // If client has a secret, validate it + if (client.client_secret) { + // Check if client_secret is required but not provided + if (!client_secret) { + throw new InvalidClientError('Client secret is required'); + } - // Check if client_secret matches - if (client.client_secret !== client_secret) { - throw new InvalidClientError("Invalid client_secret"); - } + // Check if client_secret matches + if (client.client_secret !== client_secret) { + throw new InvalidClientError('Invalid client_secret'); + } - // Check if client_secret has expired - if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { - throw new InvalidClientError("Client secret has expired"); - } - } + // Check if client_secret has expired + if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { + throw new InvalidClientError('Client secret has expired'); + } + } - req.client = client; - next(); - } catch (error) { - if (error instanceof OAuthError) { - const status = error instanceof ServerError ? 500 : 400; - res.status(status).json(error.toResponseObject()); - } else { - const serverError = new ServerError("Internal Server Error"); - res.status(500).json(serverError.toResponseObject()); - } - } - } -} \ No newline at end of file + req.client = client; + next(); + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }; +} diff --git a/src/server/auth/provider.ts b/src/server/auth/provider.ts index 18beb2166..cf1c306de 100644 --- a/src/server/auth/provider.ts +++ b/src/server/auth/provider.ts @@ -1,84 +1,83 @@ -import { Response } from "express"; -import { OAuthRegisteredClientsStore } from "./clients.js"; -import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from "../../shared/auth.js"; -import { AuthInfo } from "./types.js"; +import { Response } from 'express'; +import { OAuthRegisteredClientsStore } from './clients.js'; +import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '../../shared/auth.js'; +import { AuthInfo } from './types.js'; export type AuthorizationParams = { - state?: string; - scopes?: string[]; - codeChallenge: string; - redirectUri: string; - resource?: URL; + state?: string; + scopes?: string[]; + codeChallenge: string; + redirectUri: string; + resource?: URL; }; /** * Implements an end-to-end OAuth server. */ export interface OAuthServerProvider { - /** - * A store used to read information about registered OAuth clients. - */ - get clientsStore(): OAuthRegisteredClientsStore; + /** + * A store used to read information about registered OAuth clients. + */ + get clientsStore(): OAuthRegisteredClientsStore; - /** - * Begins the authorization flow, which can either be implemented by this server itself or via redirection to a separate authorization server. - * - * This server must eventually issue a redirect with an authorization response or an error response to the given redirect URI. Per OAuth 2.1: - * - In the successful case, the redirect MUST include the `code` and `state` (if present) query parameters. - * - In the error case, the redirect MUST include the `error` query parameter, and MAY include an optional `error_description` query parameter. - */ - authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise; + /** + * Begins the authorization flow, which can either be implemented by this server itself or via redirection to a separate authorization server. + * + * This server must eventually issue a redirect with an authorization response or an error response to the given redirect URI. Per OAuth 2.1: + * - In the successful case, the redirect MUST include the `code` and `state` (if present) query parameters. + * - In the error case, the redirect MUST include the `error` query parameter, and MAY include an optional `error_description` query parameter. + */ + authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise; - /** - * Returns the `codeChallenge` that was used when the indicated authorization began. - */ - challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise; + /** + * Returns the `codeChallenge` that was used when the indicated authorization began. + */ + challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise; - /** - * Exchanges an authorization code for an access token. - */ - exchangeAuthorizationCode( - client: OAuthClientInformationFull, - authorizationCode: string, - codeVerifier?: string, - redirectUri?: string, - resource?: URL - ): Promise; + /** + * Exchanges an authorization code for an access token. + */ + exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + codeVerifier?: string, + redirectUri?: string, + resource?: URL + ): Promise; - /** - * Exchanges a refresh token for an access token. - */ - exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[], resource?: URL): Promise; + /** + * Exchanges a refresh token for an access token. + */ + exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[], resource?: URL): Promise; - /** - * Verifies an access token and returns information about it. - */ - verifyAccessToken(token: string): Promise; + /** + * Verifies an access token and returns information about it. + */ + verifyAccessToken(token: string): Promise; - /** - * Revokes an access or refresh token. If unimplemented, token revocation is not supported (not recommended). - * - * If the given token is invalid or already revoked, this method should do nothing. - */ - revokeToken?(client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest): Promise; + /** + * Revokes an access or refresh token. If unimplemented, token revocation is not supported (not recommended). + * + * If the given token is invalid or already revoked, this method should do nothing. + */ + revokeToken?(client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest): Promise; - /** - * Whether to skip local PKCE validation. - * - * If true, the server will not perform PKCE validation locally and will pass the code_verifier to the upstream server. - * - * NOTE: This should only be true if the upstream server is performing the actual PKCE validation. - */ - skipLocalPkceValidation?: boolean; + /** + * Whether to skip local PKCE validation. + * + * If true, the server will not perform PKCE validation locally and will pass the code_verifier to the upstream server. + * + * NOTE: This should only be true if the upstream server is performing the actual PKCE validation. + */ + skipLocalPkceValidation?: boolean; } - /** * Slim implementation useful for token verification */ export interface OAuthTokenVerifier { - /** - * Verifies an access token and returns information about it. - */ - verifyAccessToken(token: string): Promise; + /** + * Verifies an access token and returns information about it. + */ + verifyAccessToken(token: string): Promise; } diff --git a/src/server/auth/providers/proxyProvider.test.ts b/src/server/auth/providers/proxyProvider.test.ts index 4e98d0dc0..97069ca6b 100644 --- a/src/server/auth/providers/proxyProvider.test.ts +++ b/src/server/auth/providers/proxyProvider.test.ts @@ -1,365 +1,343 @@ -import { Response } from "express"; -import { ProxyOAuthServerProvider, ProxyOptions } from "./proxyProvider.js"; -import { AuthInfo } from "../types.js"; -import { OAuthClientInformationFull, OAuthTokens } from "../../../shared/auth.js"; -import { ServerError } from "../errors.js"; -import { InvalidTokenError } from "../errors.js"; -import { InsufficientScopeError } from "../errors.js"; - -describe("Proxy OAuth Server Provider", () => { - // Mock client data - const validClient: OAuthClientInformationFull = { - client_id: "test-client", - client_secret: "test-secret", - redirect_uris: ["https://example.com/callback"], - }; - - // Mock response object - const mockResponse = { - redirect: jest.fn(), - } as unknown as Response; - - // Mock provider functions - const mockVerifyToken = jest.fn(); - const mockGetClient = jest.fn(); - - // Base provider options - const baseOptions: ProxyOptions = { - endpoints: { - authorizationUrl: "https://auth.example.com/authorize", - tokenUrl: "https://auth.example.com/token", - revocationUrl: "https://auth.example.com/revoke", - registrationUrl: "https://auth.example.com/register", - }, - verifyAccessToken: mockVerifyToken, - getClient: mockGetClient, - }; - - let provider: ProxyOAuthServerProvider; - let originalFetch: typeof global.fetch; - - beforeEach(() => { - provider = new ProxyOAuthServerProvider(baseOptions); - originalFetch = global.fetch; - global.fetch = jest.fn(); - - // Setup mock implementations - mockVerifyToken.mockImplementation(async (token: string) => { - if (token === "valid-token") { - return { - token, - clientId: "test-client", - scopes: ["read", "write"], - expiresAt: Date.now() / 1000 + 3600, - } as AuthInfo; - } - throw new InvalidTokenError("Invalid token"); - }); +import { Response } from 'express'; +import { ProxyOAuthServerProvider, ProxyOptions } from './proxyProvider.js'; +import { AuthInfo } from '../types.js'; +import { OAuthClientInformationFull, OAuthTokens } from '../../../shared/auth.js'; +import { ServerError } from '../errors.js'; +import { InvalidTokenError } from '../errors.js'; +import { InsufficientScopeError } from '../errors.js'; + +describe('Proxy OAuth Server Provider', () => { + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'test-client', + client_secret: 'test-secret', + redirect_uris: ['https://example.com/callback'] + }; - mockGetClient.mockImplementation(async (clientId: string) => { - if (clientId === "test-client") { - return validClient; - } - return undefined; - }); - }); - - // Add helper function for failed responses - const mockFailedResponse = () => { - (global.fetch as jest.Mock).mockImplementation(() => - Promise.resolve({ - ok: false, - status: 400, - }) - ); - }; - - afterEach(() => { - global.fetch = originalFetch; - jest.clearAllMocks(); - }); - - describe("authorization", () => { - it("redirects to authorization endpoint with correct parameters", async () => { - await provider.authorize( - validClient, - { - redirectUri: "https://example.com/callback", - codeChallenge: "test-challenge", - state: "test-state", - scopes: ["read", "write"], - resource: new URL('https://api.example.com/resource'), + // Mock response object + const mockResponse = { + redirect: jest.fn() + } as unknown as Response; + + // Mock provider functions + const mockVerifyToken = jest.fn(); + const mockGetClient = jest.fn(); + + // Base provider options + const baseOptions: ProxyOptions = { + endpoints: { + authorizationUrl: 'https://auth.example.com/authorize', + tokenUrl: 'https://auth.example.com/token', + revocationUrl: 'https://auth.example.com/revoke', + registrationUrl: 'https://auth.example.com/register' }, - mockResponse - ); - - const expectedUrl = new URL("https://auth.example.com/authorize"); - expectedUrl.searchParams.set("client_id", "test-client"); - expectedUrl.searchParams.set("response_type", "code"); - expectedUrl.searchParams.set("redirect_uri", "https://example.com/callback"); - expectedUrl.searchParams.set("code_challenge", "test-challenge"); - expectedUrl.searchParams.set("code_challenge_method", "S256"); - expectedUrl.searchParams.set("state", "test-state"); - expectedUrl.searchParams.set("scope", "read write"); - expectedUrl.searchParams.set('resource', 'https://api.example.com/resource'); - - expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); - }); - }); - - describe("token exchange", () => { - const mockTokenResponse: OAuthTokens = { - access_token: "new-access-token", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "new-refresh-token", + verifyAccessToken: mockVerifyToken, + getClient: mockGetClient }; - beforeEach(() => { - (global.fetch as jest.Mock).mockImplementation(() => - Promise.resolve({ - ok: true, - json: () => Promise.resolve(mockTokenResponse), - }) - ); - }); - - it("exchanges authorization code for tokens", async () => { - const tokens = await provider.exchangeAuthorizationCode( - validClient, - "test-code", - "test-verifier" - ); - - expect(global.fetch).toHaveBeenCalledWith( - "https://auth.example.com/token", - expect.objectContaining({ - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: expect.stringContaining("grant_type=authorization_code") - }) - ); - expect(tokens).toEqual(mockTokenResponse); - }); - - it("includes redirect_uri in token request when provided", async () => { - const redirectUri = "https://example.com/callback"; - const tokens = await provider.exchangeAuthorizationCode( - validClient, - "test-code", - "test-verifier", - redirectUri - ); - - expect(global.fetch).toHaveBeenCalledWith( - "https://auth.example.com/token", - expect.objectContaining({ - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) - }) - ); - expect(tokens).toEqual(mockTokenResponse); - }); - - it('includes resource parameter in authorization code exchange', async () => { - const tokens = await provider.exchangeAuthorizationCode( - validClient, - 'test-code', - 'test-verifier', - 'https://example.com/callback', - new URL('https://api.example.com/resource') - ); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://auth.example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - }, - body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) - }) - ); - expect(tokens).toEqual(mockTokenResponse); - }); + let provider: ProxyOAuthServerProvider; + let originalFetch: typeof global.fetch; - it('handles authorization code exchange without resource parameter', async () => { - const tokens = await provider.exchangeAuthorizationCode( - validClient, - 'test-code', - 'test-verifier' - ); - - const fetchCall = (global.fetch as jest.Mock).mock.calls[0]; - const body = fetchCall[1].body as string; - expect(body).not.toContain('resource='); - expect(tokens).toEqual(mockTokenResponse); + beforeEach(() => { + provider = new ProxyOAuthServerProvider(baseOptions); + originalFetch = global.fetch; + global.fetch = jest.fn(); + + // Setup mock implementations + mockVerifyToken.mockImplementation(async (token: string) => { + if (token === 'valid-token') { + return { + token, + clientId: 'test-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + } as AuthInfo; + } + throw new InvalidTokenError('Invalid token'); + }); + + mockGetClient.mockImplementation(async (clientId: string) => { + if (clientId === 'test-client') { + return validClient; + } + return undefined; + }); }); - it("exchanges refresh token for new tokens", async () => { - const tokens = await provider.exchangeRefreshToken( - validClient, - "test-refresh-token", - ["read", "write"] - ); - - expect(global.fetch).toHaveBeenCalledWith( - "https://auth.example.com/token", - expect.objectContaining({ - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: expect.stringContaining("grant_type=refresh_token") - }) - ); - expect(tokens).toEqual(mockTokenResponse); - }); + // Add helper function for failed responses + const mockFailedResponse = () => { + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: false, + status: 400 + }) + ); + }; - it('includes resource parameter in refresh token exchange', async () => { - const tokens = await provider.exchangeRefreshToken( - validClient, - 'test-refresh-token', - ['read', 'write'], - new URL('https://api.example.com/resource') - ); - - expect(global.fetch).toHaveBeenCalledWith( - 'https://auth.example.com/token', - expect.objectContaining({ - method: 'POST', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - }, - body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) - }) - ); - expect(tokens).toEqual(mockTokenResponse); - }); - }); - - describe("client registration", () => { - it("registers new client", async () => { - const newClient: OAuthClientInformationFull = { - client_id: "new-client", - redirect_uris: ["https://new-client.com/callback"], - }; - - (global.fetch as jest.Mock).mockImplementation(() => - Promise.resolve({ - ok: true, - json: () => Promise.resolve(newClient), - }) - ); - - const result = await provider.clientsStore.registerClient!(newClient); - - expect(global.fetch).toHaveBeenCalledWith( - "https://auth.example.com/register", - expect.objectContaining({ - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(newClient), - }) - ); - expect(result).toEqual(newClient); + afterEach(() => { + global.fetch = originalFetch; + jest.clearAllMocks(); }); - it("handles registration failure", async () => { - mockFailedResponse(); - const newClient: OAuthClientInformationFull = { - client_id: "new-client", - redirect_uris: ["https://new-client.com/callback"], - }; - - await expect( - provider.clientsStore.registerClient!(newClient) - ).rejects.toThrow(ServerError); - }); - }); - - describe("token revocation", () => { - it("revokes token", async () => { - (global.fetch as jest.Mock).mockImplementation(() => - Promise.resolve({ - ok: true, - }) - ); - - await provider.revokeToken!(validClient, { - token: "token-to-revoke", - token_type_hint: "access_token", - }); - - expect(global.fetch).toHaveBeenCalledWith( - "https://auth.example.com/revoke", - expect.objectContaining({ - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: expect.stringContaining("token=token-to-revoke"), - }) - ); + describe('authorization', () => { + it('redirects to authorization endpoint with correct parameters', async () => { + await provider.authorize( + validClient, + { + redirectUri: 'https://example.com/callback', + codeChallenge: 'test-challenge', + state: 'test-state', + scopes: ['read', 'write'], + resource: new URL('https://api.example.com/resource') + }, + mockResponse + ); + + const expectedUrl = new URL('https://auth.example.com/authorize'); + expectedUrl.searchParams.set('client_id', 'test-client'); + expectedUrl.searchParams.set('response_type', 'code'); + expectedUrl.searchParams.set('redirect_uri', 'https://example.com/callback'); + expectedUrl.searchParams.set('code_challenge', 'test-challenge'); + expectedUrl.searchParams.set('code_challenge_method', 'S256'); + expectedUrl.searchParams.set('state', 'test-state'); + expectedUrl.searchParams.set('scope', 'read write'); + expectedUrl.searchParams.set('resource', 'https://api.example.com/resource'); + + expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); + }); }); - it("handles revocation failure", async () => { - mockFailedResponse(); - await expect( - provider.revokeToken!(validClient, { - token: "invalid-token", - }) - ).rejects.toThrow(ServerError); - }); - }); - - describe("token verification", () => { - it("verifies valid token", async () => { - const validAuthInfo: AuthInfo = { - token: "valid-token", - clientId: "test-client", - scopes: ["read", "write"], - expiresAt: Date.now() / 1000 + 3600, - }; - mockVerifyToken.mockResolvedValue(validAuthInfo); - - const authInfo = await provider.verifyAccessToken("valid-token"); - expect(authInfo).toEqual(validAuthInfo); - expect(mockVerifyToken).toHaveBeenCalledWith("valid-token"); + describe('token exchange', () => { + const mockTokenResponse: OAuthTokens = { + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }; + + beforeEach(() => { + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(mockTokenResponse) + }) + ); + }); + + it('exchanges authorization code for tokens', async () => { + const tokens = await provider.exchangeAuthorizationCode(validClient, 'test-code', 'test-verifier'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('grant_type=authorization_code') + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('includes redirect_uri in token request when provided', async () => { + const redirectUri = 'https://example.com/callback'; + const tokens = await provider.exchangeAuthorizationCode(validClient, 'test-code', 'test-verifier', redirectUri); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('includes resource parameter in authorization code exchange', async () => { + const tokens = await provider.exchangeAuthorizationCode( + validClient, + 'test-code', + 'test-verifier', + 'https://example.com/callback', + new URL('https://api.example.com/resource') + ); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('handles authorization code exchange without resource parameter', async () => { + const tokens = await provider.exchangeAuthorizationCode(validClient, 'test-code', 'test-verifier'); + + const fetchCall = (global.fetch as jest.Mock).mock.calls[0]; + const body = fetchCall[1].body as string; + expect(body).not.toContain('resource='); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('exchanges refresh token for new tokens', async () => { + const tokens = await provider.exchangeRefreshToken(validClient, 'test-refresh-token', ['read', 'write']); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('grant_type=refresh_token') + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('includes resource parameter in refresh token exchange', async () => { + const tokens = await provider.exchangeRefreshToken( + validClient, + 'test-refresh-token', + ['read', 'write'], + new URL('https://api.example.com/resource') + ); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); }); - it("passes through InvalidTokenError", async () => { - const error = new InvalidTokenError("Token expired"); - mockVerifyToken.mockRejectedValue(error); - - await expect(provider.verifyAccessToken("invalid-token")) - .rejects.toBe(error); - expect(mockVerifyToken).toHaveBeenCalledWith("invalid-token"); + describe('client registration', () => { + it('registers new client', async () => { + const newClient: OAuthClientInformationFull = { + client_id: 'new-client', + redirect_uris: ['https://new-client.com/callback'] + }; + + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(newClient) + }) + ); + + const result = await provider.clientsStore.registerClient!(newClient); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/register', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(newClient) + }) + ); + expect(result).toEqual(newClient); + }); + + it('handles registration failure', async () => { + mockFailedResponse(); + const newClient: OAuthClientInformationFull = { + client_id: 'new-client', + redirect_uris: ['https://new-client.com/callback'] + }; + + await expect(provider.clientsStore.registerClient!(newClient)).rejects.toThrow(ServerError); + }); }); - it("passes through InsufficientScopeError", async () => { - const error = new InsufficientScopeError("Required scopes: read, write"); - mockVerifyToken.mockRejectedValue(error); - - await expect(provider.verifyAccessToken("token-with-insufficient-scope")) - .rejects.toBe(error); - expect(mockVerifyToken).toHaveBeenCalledWith("token-with-insufficient-scope"); + describe('token revocation', () => { + it('revokes token', async () => { + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: true + }) + ); + + await provider.revokeToken!(validClient, { + token: 'token-to-revoke', + token_type_hint: 'access_token' + }); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/revoke', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('token=token-to-revoke') + }) + ); + }); + + it('handles revocation failure', async () => { + mockFailedResponse(); + await expect( + provider.revokeToken!(validClient, { + token: 'invalid-token' + }) + ).rejects.toThrow(ServerError); + }); }); - it("passes through unexpected errors", async () => { - const error = new Error("Unexpected error"); - mockVerifyToken.mockRejectedValue(error); - - await expect(provider.verifyAccessToken("valid-token")) - .rejects.toBe(error); - expect(mockVerifyToken).toHaveBeenCalledWith("valid-token"); + describe('token verification', () => { + it('verifies valid token', async () => { + const validAuthInfo: AuthInfo = { + token: 'valid-token', + clientId: 'test-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + mockVerifyToken.mockResolvedValue(validAuthInfo); + + const authInfo = await provider.verifyAccessToken('valid-token'); + expect(authInfo).toEqual(validAuthInfo); + expect(mockVerifyToken).toHaveBeenCalledWith('valid-token'); + }); + + it('passes through InvalidTokenError', async () => { + const error = new InvalidTokenError('Token expired'); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken('invalid-token')).rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith('invalid-token'); + }); + + it('passes through InsufficientScopeError', async () => { + const error = new InsufficientScopeError('Required scopes: read, write'); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken('token-with-insufficient-scope')).rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith('token-with-insufficient-scope'); + }); + + it('passes through unexpected errors', async () => { + const error = new Error('Unexpected error'); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken('valid-token')).rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith('valid-token'); + }); }); - }); -}); \ No newline at end of file +}); diff --git a/src/server/auth/providers/proxyProvider.ts b/src/server/auth/providers/proxyProvider.ts index c66a8707c..32f256450 100644 --- a/src/server/auth/providers/proxyProvider.ts +++ b/src/server/auth/providers/proxyProvider.ts @@ -1,249 +1,234 @@ -import { Response } from "express"; -import { OAuthRegisteredClientsStore } from "../clients.js"; +import { Response } from 'express'; +import { OAuthRegisteredClientsStore } from '../clients.js'; import { - OAuthClientInformationFull, - OAuthClientInformationFullSchema, - OAuthTokenRevocationRequest, - OAuthTokens, - OAuthTokensSchema, -} from "../../../shared/auth.js"; -import { AuthInfo } from "../types.js"; -import { AuthorizationParams, OAuthServerProvider } from "../provider.js"; -import { ServerError } from "../errors.js"; -import { FetchLike } from "../../../shared/transport.js"; + OAuthClientInformationFull, + OAuthClientInformationFullSchema, + OAuthTokenRevocationRequest, + OAuthTokens, + OAuthTokensSchema +} from '../../../shared/auth.js'; +import { AuthInfo } from '../types.js'; +import { AuthorizationParams, OAuthServerProvider } from '../provider.js'; +import { ServerError } from '../errors.js'; +import { FetchLike } from '../../../shared/transport.js'; export type ProxyEndpoints = { - authorizationUrl: string; - tokenUrl: string; - revocationUrl?: string; - registrationUrl?: string; + authorizationUrl: string; + tokenUrl: string; + revocationUrl?: string; + registrationUrl?: string; }; export type ProxyOptions = { - /** - * Individual endpoint URLs for proxying specific OAuth operations - */ - endpoints: ProxyEndpoints; - - /** - * Function to verify access tokens and return auth info - */ - verifyAccessToken: (token: string) => Promise; - - /** - * Function to fetch client information from the upstream server - */ - getClient: (clientId: string) => Promise; - - /** - * Custom fetch implementation used for all network requests. - */ - fetch?: FetchLike; + /** + * Individual endpoint URLs for proxying specific OAuth operations + */ + endpoints: ProxyEndpoints; + + /** + * Function to verify access tokens and return auth info + */ + verifyAccessToken: (token: string) => Promise; + + /** + * Function to fetch client information from the upstream server + */ + getClient: (clientId: string) => Promise; + + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; }; /** * Implements an OAuth server that proxies requests to another OAuth server. */ export class ProxyOAuthServerProvider implements OAuthServerProvider { - protected readonly _endpoints: ProxyEndpoints; - protected readonly _verifyAccessToken: (token: string) => Promise; - protected readonly _getClient: (clientId: string) => Promise; - protected readonly _fetch?: FetchLike; - - skipLocalPkceValidation = true; - - revokeToken?: ( - client: OAuthClientInformationFull, - request: OAuthTokenRevocationRequest - ) => Promise; - - constructor(options: ProxyOptions) { - this._endpoints = options.endpoints; - this._verifyAccessToken = options.verifyAccessToken; - this._getClient = options.getClient; - this._fetch = options.fetch; - if (options.endpoints?.revocationUrl) { - this.revokeToken = async ( - client: OAuthClientInformationFull, - request: OAuthTokenRevocationRequest - ) => { - const revocationUrl = this._endpoints.revocationUrl; - - if (!revocationUrl) { - throw new Error("No revocation endpoint configured"); + protected readonly _endpoints: ProxyEndpoints; + protected readonly _verifyAccessToken: (token: string) => Promise; + protected readonly _getClient: (clientId: string) => Promise; + protected readonly _fetch?: FetchLike; + + skipLocalPkceValidation = true; + + revokeToken?: (client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest) => Promise; + + constructor(options: ProxyOptions) { + this._endpoints = options.endpoints; + this._verifyAccessToken = options.verifyAccessToken; + this._getClient = options.getClient; + this._fetch = options.fetch; + if (options.endpoints?.revocationUrl) { + this.revokeToken = async (client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest) => { + const revocationUrl = this._endpoints.revocationUrl; + + if (!revocationUrl) { + throw new Error('No revocation endpoint configured'); + } + + const params = new URLSearchParams(); + params.set('token', request.token); + params.set('client_id', client.client_id); + if (client.client_secret) { + params.set('client_secret', client.client_secret); + } + if (request.token_type_hint) { + params.set('token_type_hint', request.token_type_hint); + } + + const response = await (this._fetch ?? fetch)(revocationUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: params.toString() + }); + + if (!response.ok) { + throw new ServerError(`Token revocation failed: ${response.status}`); + } + }; } + } + + get clientsStore(): OAuthRegisteredClientsStore { + const registrationUrl = this._endpoints.registrationUrl; + return { + getClient: this._getClient, + ...(registrationUrl && { + registerClient: async (client: OAuthClientInformationFull) => { + const response = await (this._fetch ?? fetch)(registrationUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(client) + }); + + if (!response.ok) { + throw new ServerError(`Client registration failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthClientInformationFullSchema.parse(data); + } + }) + }; + } + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + // Start with required OAuth parameters + const targetUrl = new URL(this._endpoints.authorizationUrl); + const searchParams = new URLSearchParams({ + client_id: client.client_id, + response_type: 'code', + redirect_uri: params.redirectUri, + code_challenge: params.codeChallenge, + code_challenge_method: 'S256' + }); + + // Add optional standard OAuth parameters + if (params.state) searchParams.set('state', params.state); + if (params.scopes?.length) searchParams.set('scope', params.scopes.join(' ')); + if (params.resource) searchParams.set('resource', params.resource.href); + + targetUrl.search = searchParams.toString(); + res.redirect(targetUrl.toString()); + } + + async challengeForAuthorizationCode(_client: OAuthClientInformationFull, _authorizationCode: string): Promise { + // In a proxy setup, we don't store the code challenge ourselves + // Instead, we proxy the token request and let the upstream server validate it + return ''; + } + + async exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + codeVerifier?: string, + redirectUri?: string, + resource?: URL + ): Promise { + const params = new URLSearchParams({ + grant_type: 'authorization_code', + client_id: client.client_id, + code: authorizationCode + }); - const params = new URLSearchParams(); - params.set("token", request.token); - params.set("client_id", client.client_id); if (client.client_secret) { - params.set("client_secret", client.client_secret); + params.append('client_secret', client.client_secret); } - if (request.token_type_hint) { - params.set("token_type_hint", request.token_type_hint); + + if (codeVerifier) { + params.append('code_verifier', codeVerifier); } - const response = await (this._fetch ?? fetch)(revocationUrl, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: params.toString(), - }); + if (redirectUri) { + params.append('redirect_uri', redirectUri); + } - if (!response.ok) { - throw new ServerError(`Token revocation failed: ${response.status}`); + if (resource) { + params.append('resource', resource.href); } - } - } - } - - get clientsStore(): OAuthRegisteredClientsStore { - const registrationUrl = this._endpoints.registrationUrl; - return { - getClient: this._getClient, - ...(registrationUrl && { - registerClient: async (client: OAuthClientInformationFull) => { - const response = await (this._fetch ?? fetch)(registrationUrl, { - method: "POST", + + const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { + method: 'POST', headers: { - "Content-Type": "application/json", + 'Content-Type': 'application/x-www-form-urlencoded' }, - body: JSON.stringify(client), - }); - - if (!response.ok) { - throw new ServerError(`Client registration failed: ${response.status}`); - } + body: params.toString() + }); - const data = await response.json(); - return OAuthClientInformationFullSchema.parse(data); + if (!response.ok) { + throw new ServerError(`Token exchange failed: ${response.status}`); } - }) - } - } - - async authorize( - client: OAuthClientInformationFull, - params: AuthorizationParams, - res: Response - ): Promise { - // Start with required OAuth parameters - const targetUrl = new URL(this._endpoints.authorizationUrl); - const searchParams = new URLSearchParams({ - client_id: client.client_id, - response_type: "code", - redirect_uri: params.redirectUri, - code_challenge: params.codeChallenge, - code_challenge_method: "S256" - }); - - // Add optional standard OAuth parameters - if (params.state) searchParams.set("state", params.state); - if (params.scopes?.length) searchParams.set("scope", params.scopes.join(" ")); - if (params.resource) searchParams.set("resource", params.resource.href); - - targetUrl.search = searchParams.toString(); - res.redirect(targetUrl.toString()); - } - - async challengeForAuthorizationCode( - _client: OAuthClientInformationFull, - _authorizationCode: string - ): Promise { - // In a proxy setup, we don't store the code challenge ourselves - // Instead, we proxy the token request and let the upstream server validate it - return ""; - } - - async exchangeAuthorizationCode( - client: OAuthClientInformationFull, - authorizationCode: string, - codeVerifier?: string, - redirectUri?: string, - resource?: URL - ): Promise { - const params = new URLSearchParams({ - grant_type: "authorization_code", - client_id: client.client_id, - code: authorizationCode, - }); - - if (client.client_secret) { - params.append("client_secret", client.client_secret); - } - if (codeVerifier) { - params.append("code_verifier", codeVerifier); + const data = await response.json(); + return OAuthTokensSchema.parse(data); } - if (redirectUri) { - params.append("redirect_uri", redirectUri); - } - - if (resource) { - params.append("resource", resource.href); - } + async exchangeRefreshToken( + client: OAuthClientInformationFull, + refreshToken: string, + scopes?: string[], + resource?: URL + ): Promise { + const params = new URLSearchParams({ + grant_type: 'refresh_token', + client_id: client.client_id, + refresh_token: refreshToken + }); - const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: params.toString(), - }); + if (client.client_secret) { + params.set('client_secret', client.client_secret); + } + if (scopes?.length) { + params.set('scope', scopes.join(' ')); + } - if (!response.ok) { - throw new ServerError(`Token exchange failed: ${response.status}`); - } + if (resource) { + params.set('resource', resource.href); + } - const data = await response.json(); - return OAuthTokensSchema.parse(data); - } - - async exchangeRefreshToken( - client: OAuthClientInformationFull, - refreshToken: string, - scopes?: string[], - resource?: URL - ): Promise { - - const params = new URLSearchParams({ - grant_type: "refresh_token", - client_id: client.client_id, - refresh_token: refreshToken, - }); - - if (client.client_secret) { - params.set("client_secret", client.client_secret); - } + const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: params.toString() + }); - if (scopes?.length) { - params.set("scope", scopes.join(" ")); - } + if (!response.ok) { + throw new ServerError(`Token refresh failed: ${response.status}`); + } - if (resource) { - params.set("resource", resource.href); + const data = await response.json(); + return OAuthTokensSchema.parse(data); } - const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { - method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - body: params.toString(), - }); - - if (!response.ok) { - throw new ServerError(`Token refresh failed: ${response.status}`); + async verifyAccessToken(token: string): Promise { + return this._verifyAccessToken(token); } - - const data = await response.json(); - return OAuthTokensSchema.parse(data); - } - - async verifyAccessToken(token: string): Promise { - return this._verifyAccessToken(token); - } -} \ No newline at end of file +} diff --git a/src/server/auth/router.test.ts b/src/server/auth/router.test.ts index bcf0a51af..f2091bcbe 100644 --- a/src/server/auth/router.test.ts +++ b/src/server/auth/router.test.ts @@ -7,476 +7,457 @@ import supertest from 'supertest'; import { AuthInfo } from './types.js'; import { InvalidTokenError } from './errors.js'; - describe('MCP Auth Router', () => { - // Setup mock provider with full capabilities - const mockClientStore: OAuthRegisteredClientsStore = { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; - } - return undefined; - }, - - async registerClient(client: OAuthClientInformationFull): Promise { - return client; - } - }; - - const mockProvider: OAuthServerProvider = { - clientsStore: mockClientStore, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - const redirectUrl = new URL(params.redirectUri); - redirectUrl.searchParams.set('code', 'mock_auth_code'); - if (params.state) { - redirectUrl.searchParams.set('state', params.state); - } - res.redirect(302, redirectUrl.toString()); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read', 'write'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - }, - - async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { - // Success - do nothing in mock - } - }; - - // Provider without registration and revocation - const mockProviderMinimal: OAuthServerProvider = { - clientsStore: { - async getClient(clientId: string): Promise { - if (clientId === 'valid-client') { - return { - client_id: 'valid-client', - client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] - }; + // Setup mock provider with full capabilities + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + } + return undefined; + }, + + async registerClient(client: OAuthClientInformationFull): Promise { + return client; } - return undefined; - } - }, - - async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - const redirectUrl = new URL(params.redirectUri); - redirectUrl.searchParams.set('code', 'mock_auth_code'); - if (params.state) { - redirectUrl.searchParams.set('state', params.state); - } - res.redirect(302, redirectUrl.toString()); - }, - - async challengeForAuthorizationCode(): Promise { - return 'mock_challenge'; - }, - - async exchangeAuthorizationCode(): Promise { - return { - access_token: 'mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'mock_refresh_token' - }; - }, - - async exchangeRefreshToken(): Promise { - return { - access_token: 'new_mock_access_token', - token_type: 'bearer', - expires_in: 3600, - refresh_token: 'new_mock_refresh_token' - }; - }, - - async verifyAccessToken(token: string): Promise { - if (token === 'valid_token') { - return { - token, - clientId: 'valid-client', - scopes: ['read'], - expiresAt: Date.now() / 1000 + 3600 - }; - } - throw new InvalidTokenError('Token is invalid or expired'); - } - }; - - describe('Router creation', () => { - it('throws error for non-HTTPS issuer URL', () => { - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('http://auth.example.com') - }; - - expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must be HTTPS'); - }); + }; + + const mockProvider: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + const redirectUrl = new URL(params.redirectUri); + redirectUrl.searchParams.set('code', 'mock_auth_code'); + if (params.state) { + redirectUrl.searchParams.set('state', params.state); + } + res.redirect(302, redirectUrl.toString()); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // Success - do nothing in mock + } + }; + + // Provider without registration and revocation + const mockProviderMinimal: OAuthServerProvider = { + clientsStore: { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + } + return undefined; + } + }, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + const redirectUrl = new URL(params.redirectUri); + redirectUrl.searchParams.set('code', 'mock_auth_code'); + if (params.state) { + redirectUrl.searchParams.set('state', params.state); + } + res.redirect(302, redirectUrl.toString()); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + } + }; - it('allows localhost HTTP for development', () => { - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('http://localhost:3000') - }; + describe('Router creation', () => { + it('throws error for non-HTTPS issuer URL', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('http://auth.example.com') + }; - expect(() => mcpAuthRouter(options)).not.toThrow(); - }); + expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must be HTTPS'); + }); - it('throws error for issuer URL with fragment', () => { - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('https://auth.example.com#fragment') - }; + it('allows localhost HTTP for development', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('http://localhost:3000') + }; - expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a fragment'); - }); + expect(() => mcpAuthRouter(options)).not.toThrow(); + }); - it('throws error for issuer URL with query string', () => { - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('https://auth.example.com?param=value') - }; + it('throws error for issuer URL with fragment', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com#fragment') + }; - expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a query string'); - }); + expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a fragment'); + }); - it('successfully creates router with valid options', () => { - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('https://auth.example.com') - }; + it('throws error for issuer URL with query string', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com?param=value') + }; - expect(() => mcpAuthRouter(options)).not.toThrow(); - }); - }); - - describe('Metadata endpoint', () => { - let app: express.Express; - - beforeEach(() => { - // Setup full-featured router - app = express(); - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('https://auth.example.com'), - serviceDocumentationUrl: new URL('https://docs.example.com') - }; - app.use(mcpAuthRouter(options)); - }); + expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a query string'); + }); - it('returns complete metadata for full-featured router', async () => { - const response = await supertest(app) - .get('/.well-known/oauth-authorization-server'); + it('successfully creates router with valid options', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com') + }; - expect(response.status).toBe(200); + expect(() => mcpAuthRouter(options)).not.toThrow(); + }); + }); - // Verify essential fields - expect(response.body.issuer).toBe('https://auth.example.com/'); - expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); - expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); - expect(response.body.registration_endpoint).toBe('https://auth.example.com/register'); - expect(response.body.revocation_endpoint).toBe('https://auth.example.com/revoke'); + describe('Metadata endpoint', () => { + let app: express.Express; + + beforeEach(() => { + // Setup full-featured router + app = express(); + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com'), + serviceDocumentationUrl: new URL('https://docs.example.com') + }; + app.use(mcpAuthRouter(options)); + }); - // Verify supported features - expect(response.body.response_types_supported).toEqual(['code']); - expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']); - expect(response.body.code_challenge_methods_supported).toEqual(['S256']); - expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post']); - expect(response.body.revocation_endpoint_auth_methods_supported).toEqual(['client_secret_post']); + it('returns complete metadata for full-featured router', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server'); - // Verify optional fields - expect(response.body.service_documentation).toBe('https://docs.example.com/'); - }); + expect(response.status).toBe(200); - it('returns minimal metadata for minimal router', async () => { - // Setup minimal router - const minimalApp = express(); - const options: AuthRouterOptions = { - provider: mockProviderMinimal, - issuerUrl: new URL('https://auth.example.com') - }; - minimalApp.use(mcpAuthRouter(options)); - - const response = await supertest(minimalApp) - .get('/.well-known/oauth-authorization-server'); - - expect(response.status).toBe(200); - - // Verify essential endpoints - expect(response.body.issuer).toBe('https://auth.example.com/'); - expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); - expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); - - // Verify missing optional endpoints - expect(response.body.registration_endpoint).toBeUndefined(); - expect(response.body.revocation_endpoint).toBeUndefined(); - expect(response.body.revocation_endpoint_auth_methods_supported).toBeUndefined(); - expect(response.body.service_documentation).toBeUndefined(); - }); + // Verify essential fields + expect(response.body.issuer).toBe('https://auth.example.com/'); + expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); + expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); + expect(response.body.registration_endpoint).toBe('https://auth.example.com/register'); + expect(response.body.revocation_endpoint).toBe('https://auth.example.com/revoke'); - it('provides protected resource metadata', async () => { - // Setup router with draft protocol version - const draftApp = express(); - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('https://mcp.example.com'), - scopesSupported: ['read', 'write'], - resourceName: 'Test API' - }; - draftApp.use(mcpAuthRouter(options)); - - const response = await supertest(draftApp) - .get('/.well-known/oauth-protected-resource'); - - expect(response.status).toBe(200); - - // Verify protected resource metadata - expect(response.body.resource).toBe('https://mcp.example.com/'); - expect(response.body.authorization_servers).toContain('https://mcp.example.com/'); - expect(response.body.scopes_supported).toEqual(['read', 'write']); - expect(response.body.resource_name).toBe('Test API'); - }); - }); - - describe('Endpoint routing', () => { - let app: express.Express; - - beforeEach(() => { - // Setup full-featured router - app = express(); - const options: AuthRouterOptions = { - provider: mockProvider, - issuerUrl: new URL('https://auth.example.com') - }; - app.use(mcpAuthRouter(options)); - jest.spyOn(console, 'error').mockImplementation(() => {}); - }); + // Verify supported features + expect(response.body.response_types_supported).toEqual(['code']); + expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']); + expect(response.body.code_challenge_methods_supported).toEqual(['S256']); + expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post']); + expect(response.body.revocation_endpoint_auth_methods_supported).toEqual(['client_secret_post']); - afterEach(() => { - jest.restoreAllMocks(); - }); + // Verify optional fields + expect(response.body.service_documentation).toBe('https://docs.example.com/'); + }); - it('routes to authorization endpoint', async () => { - const response = await supertest(app) - .get('/authorize') - .query({ - client_id: 'valid-client', - response_type: 'code', - code_challenge: 'challenge123', - code_challenge_method: 'S256' + it('returns minimal metadata for minimal router', async () => { + // Setup minimal router + const minimalApp = express(); + const options: AuthRouterOptions = { + provider: mockProviderMinimal, + issuerUrl: new URL('https://auth.example.com') + }; + minimalApp.use(mcpAuthRouter(options)); + + const response = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + + expect(response.status).toBe(200); + + // Verify essential endpoints + expect(response.body.issuer).toBe('https://auth.example.com/'); + expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); + expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); + + // Verify missing optional endpoints + expect(response.body.registration_endpoint).toBeUndefined(); + expect(response.body.revocation_endpoint).toBeUndefined(); + expect(response.body.revocation_endpoint_auth_methods_supported).toBeUndefined(); + expect(response.body.service_documentation).toBeUndefined(); }); - expect(response.status).toBe(302); - const location = new URL(response.header.location); - expect(location.searchParams.has('code')).toBe(true); + it('provides protected resource metadata', async () => { + // Setup router with draft protocol version + const draftApp = express(); + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://mcp.example.com'), + scopesSupported: ['read', 'write'], + resourceName: 'Test API' + }; + draftApp.use(mcpAuthRouter(options)); + + const response = await supertest(draftApp).get('/.well-known/oauth-protected-resource'); + + expect(response.status).toBe(200); + + // Verify protected resource metadata + expect(response.body.resource).toBe('https://mcp.example.com/'); + expect(response.body.authorization_servers).toContain('https://mcp.example.com/'); + expect(response.body.scopes_supported).toEqual(['read', 'write']); + expect(response.body.resource_name).toBe('Test API'); + }); }); - it('routes to token endpoint', async () => { - // Setup verifyChallenge mock for token handler - jest.mock('pkce-challenge', () => ({ - verifyChallenge: jest.fn().mockResolvedValue(true) - })); - - const response = await supertest(app) - .post('/token') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - grant_type: 'authorization_code', - code: 'valid_code', - code_verifier: 'valid_verifier' + describe('Endpoint routing', () => { + let app: express.Express; + + beforeEach(() => { + // Setup full-featured router + app = express(); + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com') + }; + app.use(mcpAuthRouter(options)); + jest.spyOn(console, 'error').mockImplementation(() => {}); }); - // The request will fail in testing due to mocking limitations, - // but we can verify the route was matched - expect(response.status).not.toBe(404); - }); + afterEach(() => { + jest.restoreAllMocks(); + }); - it('routes to registration endpoint', async () => { - const response = await supertest(app) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] + it('routes to authorization endpoint', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location); + expect(location.searchParams.has('code')).toBe(true); }); - // The request will fail in testing due to mocking limitations, - // but we can verify the route was matched - expect(response.status).not.toBe(404); - }); + it('routes to token endpoint', async () => { + // Setup verifyChallenge mock for token handler + jest.mock('pkce-challenge', () => ({ + verifyChallenge: jest.fn().mockResolvedValue(true) + })); + + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + // The request will fail in testing due to mocking limitations, + // but we can verify the route was matched + expect(response.status).not.toBe(404); + }); + + it('routes to registration endpoint', async () => { + const response = await supertest(app) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); - it('routes to revocation endpoint', async () => { - const response = await supertest(app) - .post('/revoke') - .type('form') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' + // The request will fail in testing due to mocking limitations, + // but we can verify the route was matched + expect(response.status).not.toBe(404); }); - // The request will fail in testing due to mocking limitations, - // but we can verify the route was matched - expect(response.status).not.toBe(404); - }); + it('routes to revocation endpoint', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); - it('excludes endpoints for unsupported features', async () => { - // Setup minimal router - const minimalApp = express(); - const options: AuthRouterOptions = { - provider: mockProviderMinimal, - issuerUrl: new URL('https://auth.example.com') - }; - minimalApp.use(mcpAuthRouter(options)); - - // Registration should not be available - const regResponse = await supertest(minimalApp) - .post('/register') - .send({ - redirect_uris: ['https://example.com/callback'] + // The request will fail in testing due to mocking limitations, + // but we can verify the route was matched + expect(response.status).not.toBe(404); }); - expect(regResponse.status).toBe(404); - - // Revocation should not be available - const revokeResponse = await supertest(minimalApp) - .post('/revoke') - .send({ - client_id: 'valid-client', - client_secret: 'valid-secret', - token: 'token_to_revoke' + + it('excludes endpoints for unsupported features', async () => { + // Setup minimal router + const minimalApp = express(); + const options: AuthRouterOptions = { + provider: mockProviderMinimal, + issuerUrl: new URL('https://auth.example.com') + }; + minimalApp.use(mcpAuthRouter(options)); + + // Registration should not be available + const regResponse = await supertest(minimalApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + expect(regResponse.status).toBe(404); + + // Revocation should not be available + const revokeResponse = await supertest(minimalApp).post('/revoke').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + expect(revokeResponse.status).toBe(404); }); - expect(revokeResponse.status).toBe(404); }); - }); }); describe('MCP Auth Metadata Router', () => { - - const mockOAuthMetadata : OAuthMetadata = { - issuer: 'https://auth.example.com/', - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - grant_types_supported: ["authorization_code", "refresh_token"], - code_challenge_methods_supported: ["S256"], - token_endpoint_auth_methods_supported: ["client_secret_post"], - } - - describe('Router creation', () => { - it('successfully creates router with valid options', () => { - const options: AuthMetadataOptions = { - oauthMetadata: mockOAuthMetadata, - resourceServerUrl: new URL('https://api.example.com'), - }; - - expect(() => mcpAuthMetadataRouter(options)).not.toThrow(); - }); - }); - - describe('Metadata endpoints', () => { - let app: express.Express; - - beforeEach(() => { - app = express(); - const options: AuthMetadataOptions = { - oauthMetadata: mockOAuthMetadata, - resourceServerUrl: new URL('https://api.example.com'), - serviceDocumentationUrl: new URL('https://docs.example.com'), - scopesSupported: ['read', 'write'], - resourceName: 'Test API' - }; - app.use(mcpAuthMetadataRouter(options)); + const mockOAuthMetadata: OAuthMetadata = { + issuer: 'https://auth.example.com/', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + grant_types_supported: ['authorization_code', 'refresh_token'], + code_challenge_methods_supported: ['S256'], + token_endpoint_auth_methods_supported: ['client_secret_post'] + }; + + describe('Router creation', () => { + it('successfully creates router with valid options', () => { + const options: AuthMetadataOptions = { + oauthMetadata: mockOAuthMetadata, + resourceServerUrl: new URL('https://api.example.com') + }; + + expect(() => mcpAuthMetadataRouter(options)).not.toThrow(); + }); }); - it('returns OAuth authorization server metadata', async () => { - const response = await supertest(app) - .get('/.well-known/oauth-authorization-server'); + describe('Metadata endpoints', () => { + let app: express.Express; + + beforeEach(() => { + app = express(); + const options: AuthMetadataOptions = { + oauthMetadata: mockOAuthMetadata, + resourceServerUrl: new URL('https://api.example.com'), + serviceDocumentationUrl: new URL('https://docs.example.com'), + scopesSupported: ['read', 'write'], + resourceName: 'Test API' + }; + app.use(mcpAuthMetadataRouter(options)); + }); - expect(response.status).toBe(200); + it('returns OAuth authorization server metadata', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server'); - // Verify metadata points to authorization server - expect(response.body.issuer).toBe('https://auth.example.com/'); - expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); - expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); - expect(response.body.response_types_supported).toEqual(['code']); - expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']); - expect(response.body.code_challenge_methods_supported).toEqual(['S256']); - expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post']); - }); + expect(response.status).toBe(200); - it('returns OAuth protected resource metadata', async () => { - const response = await supertest(app) - .get('/.well-known/oauth-protected-resource'); + // Verify metadata points to authorization server + expect(response.body.issuer).toBe('https://auth.example.com/'); + expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); + expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); + expect(response.body.response_types_supported).toEqual(['code']); + expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']); + expect(response.body.code_challenge_methods_supported).toEqual(['S256']); + expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post']); + }); - expect(response.status).toBe(200); + it('returns OAuth protected resource metadata', async () => { + const response = await supertest(app).get('/.well-known/oauth-protected-resource'); - // Verify protected resource metadata - expect(response.body.resource).toBe('https://api.example.com/'); - expect(response.body.authorization_servers).toEqual(['https://auth.example.com/']); - expect(response.body.scopes_supported).toEqual(['read', 'write']); - expect(response.body.resource_name).toBe('Test API'); - expect(response.body.resource_documentation).toBe('https://docs.example.com/'); - }); + expect(response.status).toBe(200); - it('works with minimal configuration', async () => { - const minimalApp = express(); - const options: AuthMetadataOptions = { - oauthMetadata: mockOAuthMetadata, - resourceServerUrl: new URL('https://api.example.com'), - }; - minimalApp.use(mcpAuthMetadataRouter(options)); - - const authResponse = await supertest(minimalApp) - .get('/.well-known/oauth-authorization-server'); - - expect(authResponse.status).toBe(200); - expect(authResponse.body.issuer).toBe('https://auth.example.com/'); - expect(authResponse.body.service_documentation).toBeUndefined(); - expect(authResponse.body.scopes_supported).toBeUndefined(); - - const resourceResponse = await supertest(minimalApp) - .get('/.well-known/oauth-protected-resource'); - - expect(resourceResponse.status).toBe(200); - expect(resourceResponse.body.resource).toBe('https://api.example.com/'); - expect(resourceResponse.body.authorization_servers).toEqual(['https://auth.example.com/']); - expect(resourceResponse.body.scopes_supported).toBeUndefined(); - expect(resourceResponse.body.resource_name).toBeUndefined(); - expect(resourceResponse.body.resource_documentation).toBeUndefined(); + // Verify protected resource metadata + expect(response.body.resource).toBe('https://api.example.com/'); + expect(response.body.authorization_servers).toEqual(['https://auth.example.com/']); + expect(response.body.scopes_supported).toEqual(['read', 'write']); + expect(response.body.resource_name).toBe('Test API'); + expect(response.body.resource_documentation).toBe('https://docs.example.com/'); + }); + + it('works with minimal configuration', async () => { + const minimalApp = express(); + const options: AuthMetadataOptions = { + oauthMetadata: mockOAuthMetadata, + resourceServerUrl: new URL('https://api.example.com') + }; + minimalApp.use(mcpAuthMetadataRouter(options)); + + const authResponse = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + + expect(authResponse.status).toBe(200); + expect(authResponse.body.issuer).toBe('https://auth.example.com/'); + expect(authResponse.body.service_documentation).toBeUndefined(); + expect(authResponse.body.scopes_supported).toBeUndefined(); + + const resourceResponse = await supertest(minimalApp).get('/.well-known/oauth-protected-resource'); + + expect(resourceResponse.status).toBe(200); + expect(resourceResponse.body.resource).toBe('https://api.example.com/'); + expect(resourceResponse.body.authorization_servers).toEqual(['https://auth.example.com/']); + expect(resourceResponse.body.scopes_supported).toBeUndefined(); + expect(resourceResponse.body.resource_name).toBeUndefined(); + expect(resourceResponse.body.resource_documentation).toBeUndefined(); + }); }); - }); }); diff --git a/src/server/auth/router.ts b/src/server/auth/router.ts index a06bf73a1..c46832784 100644 --- a/src/server/auth/router.ts +++ b/src/server/auth/router.ts @@ -1,105 +1,104 @@ -import express, { RequestHandler } from "express"; -import { clientRegistrationHandler, ClientRegistrationHandlerOptions } from "./handlers/register.js"; -import { tokenHandler, TokenHandlerOptions } from "./handlers/token.js"; -import { authorizationHandler, AuthorizationHandlerOptions } from "./handlers/authorize.js"; -import { revocationHandler, RevocationHandlerOptions } from "./handlers/revoke.js"; -import { metadataHandler } from "./handlers/metadata.js"; -import { OAuthServerProvider } from "./provider.js"; -import { OAuthMetadata, OAuthProtectedResourceMetadata } from "../../shared/auth.js"; +import express, { RequestHandler } from 'express'; +import { clientRegistrationHandler, ClientRegistrationHandlerOptions } from './handlers/register.js'; +import { tokenHandler, TokenHandlerOptions } from './handlers/token.js'; +import { authorizationHandler, AuthorizationHandlerOptions } from './handlers/authorize.js'; +import { revocationHandler, RevocationHandlerOptions } from './handlers/revoke.js'; +import { metadataHandler } from './handlers/metadata.js'; +import { OAuthServerProvider } from './provider.js'; +import { OAuthMetadata, OAuthProtectedResourceMetadata } from '../../shared/auth.js'; export type AuthRouterOptions = { - /** - * A provider implementing the actual authorization logic for this router. - */ - provider: OAuthServerProvider; - - /** - * The authorization server's issuer identifier, which is a URL that uses the "https" scheme and has no query or fragment components. - */ - issuerUrl: URL; - - /** - * The base URL of the authorization server to use for the metadata endpoints. - * - * If not provided, the issuer URL will be used as the base URL. - */ - baseUrl?: URL; - - /** - * An optional URL of a page containing human-readable information that developers might want or need to know when using the authorization server. - */ - serviceDocumentationUrl?: URL; - - /** - * An optional list of scopes supported by this authorization server - */ - scopesSupported?: string[]; - - - /** - * The resource name to be displayed in protected resource metadata - */ - resourceName?: string; - - // Individual options per route - authorizationOptions?: Omit; - clientRegistrationOptions?: Omit; - revocationOptions?: Omit; - tokenOptions?: Omit; + /** + * A provider implementing the actual authorization logic for this router. + */ + provider: OAuthServerProvider; + + /** + * The authorization server's issuer identifier, which is a URL that uses the "https" scheme and has no query or fragment components. + */ + issuerUrl: URL; + + /** + * The base URL of the authorization server to use for the metadata endpoints. + * + * If not provided, the issuer URL will be used as the base URL. + */ + baseUrl?: URL; + + /** + * An optional URL of a page containing human-readable information that developers might want or need to know when using the authorization server. + */ + serviceDocumentationUrl?: URL; + + /** + * An optional list of scopes supported by this authorization server + */ + scopesSupported?: string[]; + + /** + * The resource name to be displayed in protected resource metadata + */ + resourceName?: string; + + // Individual options per route + authorizationOptions?: Omit; + clientRegistrationOptions?: Omit; + revocationOptions?: Omit; + tokenOptions?: Omit; }; const checkIssuerUrl = (issuer: URL): void => { - // Technically RFC 8414 does not permit a localhost HTTPS exemption, but this will be necessary for ease of testing - if (issuer.protocol !== "https:" && issuer.hostname !== "localhost" && issuer.hostname !== "127.0.0.1") { - throw new Error("Issuer URL must be HTTPS"); - } - if (issuer.hash) { - throw new Error(`Issuer URL must not have a fragment: ${issuer}`); - } - if (issuer.search) { - throw new Error(`Issuer URL must not have a query string: ${issuer}`); - } -} + // Technically RFC 8414 does not permit a localhost HTTPS exemption, but this will be necessary for ease of testing + if (issuer.protocol !== 'https:' && issuer.hostname !== 'localhost' && issuer.hostname !== '127.0.0.1') { + throw new Error('Issuer URL must be HTTPS'); + } + if (issuer.hash) { + throw new Error(`Issuer URL must not have a fragment: ${issuer}`); + } + if (issuer.search) { + throw new Error(`Issuer URL must not have a query string: ${issuer}`); + } +}; export const createOAuthMetadata = (options: { - provider: OAuthServerProvider, - issuerUrl: URL, - baseUrl?: URL - serviceDocumentationUrl?: URL, - scopesSupported?: string[]; + provider: OAuthServerProvider; + issuerUrl: URL; + baseUrl?: URL; + serviceDocumentationUrl?: URL; + scopesSupported?: string[]; }): OAuthMetadata => { - const issuer = options.issuerUrl; - const baseUrl = options.baseUrl; + const issuer = options.issuerUrl; + const baseUrl = options.baseUrl; - checkIssuerUrl(issuer); + checkIssuerUrl(issuer); - const authorization_endpoint = "/authorize"; - const token_endpoint = "/token"; - const registration_endpoint = options.provider.clientsStore.registerClient ? "/register" : undefined; - const revocation_endpoint = options.provider.revokeToken ? "/revoke" : undefined; + const authorization_endpoint = '/authorize'; + const token_endpoint = '/token'; + const registration_endpoint = options.provider.clientsStore.registerClient ? '/register' : undefined; + const revocation_endpoint = options.provider.revokeToken ? '/revoke' : undefined; - const metadata: OAuthMetadata = { - issuer: issuer.href, - service_documentation: options.serviceDocumentationUrl?.href, + const metadata: OAuthMetadata = { + issuer: issuer.href, + service_documentation: options.serviceDocumentationUrl?.href, - authorization_endpoint: new URL(authorization_endpoint, baseUrl || issuer).href, - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], + authorization_endpoint: new URL(authorization_endpoint, baseUrl || issuer).href, + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'], - token_endpoint: new URL(token_endpoint, baseUrl || issuer).href, - token_endpoint_auth_methods_supported: ["client_secret_post"], - grant_types_supported: ["authorization_code", "refresh_token"], + token_endpoint: new URL(token_endpoint, baseUrl || issuer).href, + token_endpoint_auth_methods_supported: ['client_secret_post'], + grant_types_supported: ['authorization_code', 'refresh_token'], - scopes_supported: options.scopesSupported, + scopes_supported: options.scopesSupported, - revocation_endpoint: revocation_endpoint ? new URL(revocation_endpoint, baseUrl || issuer).href : undefined, - revocation_endpoint_auth_methods_supported: revocation_endpoint ? ["client_secret_post"] : undefined, + revocation_endpoint: revocation_endpoint ? new URL(revocation_endpoint, baseUrl || issuer).href : undefined, + revocation_endpoint_auth_methods_supported: revocation_endpoint ? ['client_secret_post'] : undefined, - registration_endpoint: registration_endpoint ? new URL(registration_endpoint, baseUrl || issuer).href : undefined, - }; + registration_endpoint: registration_endpoint ? new URL(registration_endpoint, baseUrl || issuer).href : undefined + }; - return metadata -} + return metadata; +}; /** * Installs standard MCP authorization server endpoints, including dynamic client registration and token revocation (if supported). @@ -114,100 +113,97 @@ export const createOAuthMetadata = (options: { * app.use(mcpAuthRouter(...)); */ export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { - const oauthMetadata = createOAuthMetadata(options); - - const router = express.Router(); - - router.use( - new URL(oauthMetadata.authorization_endpoint).pathname, - authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) - ); - - router.use( - new URL(oauthMetadata.token_endpoint).pathname, - tokenHandler({ provider: options.provider, ...options.tokenOptions }) - ); - - router.use(mcpAuthMetadataRouter({ - oauthMetadata, - // This router is used for AS+RS combo's, so the issuer is also the resource server - resourceServerUrl: new URL(oauthMetadata.issuer), - serviceDocumentationUrl: options.serviceDocumentationUrl, - scopesSupported: options.scopesSupported, - resourceName: options.resourceName - })); - - if (oauthMetadata.registration_endpoint) { + const oauthMetadata = createOAuthMetadata(options); + + const router = express.Router(); + router.use( - new URL(oauthMetadata.registration_endpoint).pathname, - clientRegistrationHandler({ - clientsStore: options.provider.clientsStore, - ...options.clientRegistrationOptions, - }) + new URL(oauthMetadata.authorization_endpoint).pathname, + authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) ); - } - if (oauthMetadata.revocation_endpoint) { + router.use(new URL(oauthMetadata.token_endpoint).pathname, tokenHandler({ provider: options.provider, ...options.tokenOptions })); + router.use( - new URL(oauthMetadata.revocation_endpoint).pathname, - revocationHandler({ provider: options.provider, ...options.revocationOptions }) + mcpAuthMetadataRouter({ + oauthMetadata, + // This router is used for AS+RS combo's, so the issuer is also the resource server + resourceServerUrl: new URL(oauthMetadata.issuer), + serviceDocumentationUrl: options.serviceDocumentationUrl, + scopesSupported: options.scopesSupported, + resourceName: options.resourceName + }) ); - } - return router; + if (oauthMetadata.registration_endpoint) { + router.use( + new URL(oauthMetadata.registration_endpoint).pathname, + clientRegistrationHandler({ + clientsStore: options.provider.clientsStore, + ...options.clientRegistrationOptions + }) + ); + } + + if (oauthMetadata.revocation_endpoint) { + router.use( + new URL(oauthMetadata.revocation_endpoint).pathname, + revocationHandler({ provider: options.provider, ...options.revocationOptions }) + ); + } + + return router; } export type AuthMetadataOptions = { - /** - * OAuth Metadata as would be returned from the authorization server - * this MCP server relies on - */ - oauthMetadata: OAuthMetadata; - - /** - * The url of the MCP server, for use in protected resource metadata - */ - resourceServerUrl: URL; - - /** - * The url for documentation for the MCP server - */ - serviceDocumentationUrl?: URL; - - /** - * An optional list of scopes supported by this MCP server - */ - scopesSupported?: string[]; - - /** - * An optional resource name to display in resource metadata - */ - resourceName?: string; -} + /** + * OAuth Metadata as would be returned from the authorization server + * this MCP server relies on + */ + oauthMetadata: OAuthMetadata; + + /** + * The url of the MCP server, for use in protected resource metadata + */ + resourceServerUrl: URL; + + /** + * The url for documentation for the MCP server + */ + serviceDocumentationUrl?: URL; + + /** + * An optional list of scopes supported by this MCP server + */ + scopesSupported?: string[]; + + /** + * An optional resource name to display in resource metadata + */ + resourceName?: string; +}; export function mcpAuthMetadataRouter(options: AuthMetadataOptions) { - checkIssuerUrl(new URL(options.oauthMetadata.issuer)); + checkIssuerUrl(new URL(options.oauthMetadata.issuer)); - const router = express.Router(); + const router = express.Router(); - const protectedResourceMetadata: OAuthProtectedResourceMetadata = { - resource: options.resourceServerUrl.href, + const protectedResourceMetadata: OAuthProtectedResourceMetadata = { + resource: options.resourceServerUrl.href, - authorization_servers: [ - options.oauthMetadata.issuer - ], + authorization_servers: [options.oauthMetadata.issuer], - scopes_supported: options.scopesSupported, - resource_name: options.resourceName, - resource_documentation: options.serviceDocumentationUrl?.href, - }; + scopes_supported: options.scopesSupported, + resource_name: options.resourceName, + resource_documentation: options.serviceDocumentationUrl?.href + }; - router.use("/.well-known/oauth-protected-resource", metadataHandler(protectedResourceMetadata)); + router.use('/.well-known/oauth-protected-resource', metadataHandler(protectedResourceMetadata)); - // Always add this for backwards compatibility - router.use("/.well-known/oauth-authorization-server", metadataHandler(options.oauthMetadata)); + // Always add this for backwards compatibility + router.use('/.well-known/oauth-authorization-server', metadataHandler(options.oauthMetadata)); - return router; + return router; } /** @@ -222,5 +218,5 @@ export function mcpAuthMetadataRouter(options: AuthMetadataOptions) { * // Returns: 'https://api.example.com/.well-known/oauth-protected-resource' */ export function getOAuthProtectedResourceMetadataUrl(serverUrl: URL): string { - return new URL('/.well-known/oauth-protected-resource', serverUrl).href; + return new URL('/.well-known/oauth-protected-resource', serverUrl).href; } diff --git a/src/server/auth/types.ts b/src/server/auth/types.ts index 0189e9ed8..a38a7e750 100644 --- a/src/server/auth/types.ts +++ b/src/server/auth/types.ts @@ -2,35 +2,35 @@ * Information about a validated access token, provided to request handlers. */ export interface AuthInfo { - /** - * The access token. - */ - token: string; + /** + * The access token. + */ + token: string; - /** - * The client ID associated with this token. - */ - clientId: string; + /** + * The client ID associated with this token. + */ + clientId: string; - /** - * Scopes associated with this token. - */ - scopes: string[]; + /** + * Scopes associated with this token. + */ + scopes: string[]; - /** - * When the token expires (in seconds since epoch). - */ - expiresAt?: number; + /** + * When the token expires (in seconds since epoch). + */ + expiresAt?: number; - /** - * The RFC 8707 resource server identifier for which this token is valid. - * If set, this MUST match the MCP server's resource identifier (minus hash fragment). - */ - resource?: URL; + /** + * The RFC 8707 resource server identifier for which this token is valid. + * If set, this MUST match the MCP server's resource identifier (minus hash fragment). + */ + resource?: URL; - /** - * Additional data associated with the token. - * This field should be used for any additional data that needs to be attached to the auth info. - */ - extra?: Record; -} \ No newline at end of file + /** + * Additional data associated with the token. + * This field should be used for any additional data that needs to be attached to the auth info. + */ + extra?: Record; +} diff --git a/src/server/completable.test.ts b/src/server/completable.test.ts index 6040ff3f6..b5effc272 100644 --- a/src/server/completable.test.ts +++ b/src/server/completable.test.ts @@ -1,46 +1,46 @@ -import { z } from "zod"; -import { completable } from "./completable.js"; +import { z } from 'zod'; +import { completable } from './completable.js'; -describe("completable", () => { - it("preserves types and values of underlying schema", () => { - const baseSchema = z.string(); - const schema = completable(baseSchema, () => []); +describe('completable', () => { + it('preserves types and values of underlying schema', () => { + const baseSchema = z.string(); + const schema = completable(baseSchema, () => []); - expect(schema.parse("test")).toBe("test"); - expect(() => schema.parse(123)).toThrow(); - }); + expect(schema.parse('test')).toBe('test'); + expect(() => schema.parse(123)).toThrow(); + }); - it("provides access to completion function", async () => { - const completions = ["foo", "bar", "baz"]; - const schema = completable(z.string(), () => completions); + it('provides access to completion function', async () => { + const completions = ['foo', 'bar', 'baz']; + const schema = completable(z.string(), () => completions); - expect(await schema._def.complete("")).toEqual(completions); - }); + expect(await schema._def.complete('')).toEqual(completions); + }); - it("allows async completion functions", async () => { - const completions = ["foo", "bar", "baz"]; - const schema = completable(z.string(), async () => completions); + it('allows async completion functions', async () => { + const completions = ['foo', 'bar', 'baz']; + const schema = completable(z.string(), async () => completions); - expect(await schema._def.complete("")).toEqual(completions); - }); + expect(await schema._def.complete('')).toEqual(completions); + }); - it("passes current value to completion function", async () => { - const schema = completable(z.string(), (value) => [value + "!"]); + it('passes current value to completion function', async () => { + const schema = completable(z.string(), value => [value + '!']); - expect(await schema._def.complete("test")).toEqual(["test!"]); - }); + expect(await schema._def.complete('test')).toEqual(['test!']); + }); - it("works with number schemas", async () => { - const schema = completable(z.number(), () => [1, 2, 3]); + it('works with number schemas', async () => { + const schema = completable(z.number(), () => [1, 2, 3]); - expect(schema.parse(1)).toBe(1); - expect(await schema._def.complete(0)).toEqual([1, 2, 3]); - }); + expect(schema.parse(1)).toBe(1); + expect(await schema._def.complete(0)).toEqual([1, 2, 3]); + }); - it("preserves schema description", () => { - const desc = "test description"; - const schema = completable(z.string().describe(desc), () => []); + it('preserves schema description', () => { + const desc = 'test description'; + const schema = completable(z.string().describe(desc), () => []); - expect(schema.description).toBe(desc); - }); + expect(schema.description).toBe(desc); + }); }); diff --git a/src/server/completable.ts b/src/server/completable.ts index 652eaf72e..67d91c383 100644 --- a/src/server/completable.ts +++ b/src/server/completable.ts @@ -1,98 +1,79 @@ -import { - ZodTypeAny, - ZodTypeDef, - ZodType, - ParseInput, - ParseReturnType, - RawCreateParams, - ZodErrorMap, - ProcessedCreateParams, -} from "zod"; +import { ZodTypeAny, ZodTypeDef, ZodType, ParseInput, ParseReturnType, RawCreateParams, ZodErrorMap, ProcessedCreateParams } from 'zod'; export enum McpZodTypeKind { - Completable = "McpCompletable", + Completable = 'McpCompletable' } export type CompleteCallback = ( - value: T["_input"], - context?: { - arguments?: Record; - }, -) => T["_input"][] | Promise; + value: T['_input'], + context?: { + arguments?: Record; + } +) => T['_input'][] | Promise; -export interface CompletableDef - extends ZodTypeDef { - type: T; - complete: CompleteCallback; - typeName: McpZodTypeKind.Completable; +export interface CompletableDef extends ZodTypeDef { + type: T; + complete: CompleteCallback; + typeName: McpZodTypeKind.Completable; } -export class Completable extends ZodType< - T["_output"], - CompletableDef, - T["_input"] -> { - _parse(input: ParseInput): ParseReturnType { - const { ctx } = this._processInputParams(input); - const data = ctx.data; - return this._def.type._parse({ - data, - path: ctx.path, - parent: ctx, - }); - } +export class Completable extends ZodType, T['_input']> { + _parse(input: ParseInput): ParseReturnType { + const { ctx } = this._processInputParams(input); + const data = ctx.data; + return this._def.type._parse({ + data, + path: ctx.path, + parent: ctx + }); + } - unwrap() { - return this._def.type; - } + unwrap() { + return this._def.type; + } - static create = ( - type: T, - params: RawCreateParams & { - complete: CompleteCallback; - }, - ): Completable => { - return new Completable({ - type, - typeName: McpZodTypeKind.Completable, - complete: params.complete, - ...processCreateParams(params), - }); - }; + static create = ( + type: T, + params: RawCreateParams & { + complete: CompleteCallback; + } + ): Completable => { + return new Completable({ + type, + typeName: McpZodTypeKind.Completable, + complete: params.complete, + ...processCreateParams(params) + }); + }; } /** * Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP. */ -export function completable( - schema: T, - complete: CompleteCallback, -): Completable { - return Completable.create(schema, { ...schema._def, complete }); +export function completable(schema: T, complete: CompleteCallback): Completable { + return Completable.create(schema, { ...schema._def, complete }); } // Not sure why this isn't exported from Zod: // https://github.com/colinhacks/zod/blob/f7ad26147ba291cb3fb257545972a8e00e767470/src/types.ts#L130 function processCreateParams(params: RawCreateParams): ProcessedCreateParams { - if (!params) return {}; - const { errorMap, invalid_type_error, required_error, description } = params; - if (errorMap && (invalid_type_error || required_error)) { - throw new Error( - `Can't use "invalid_type_error" or "required_error" in conjunction with custom error map.`, - ); - } - if (errorMap) return { errorMap: errorMap, description }; - const customMap: ZodErrorMap = (iss, ctx) => { - const { message } = params; - - if (iss.code === "invalid_enum_value") { - return { message: message ?? ctx.defaultError }; - } - if (typeof ctx.data === "undefined") { - return { message: message ?? required_error ?? ctx.defaultError }; + if (!params) return {}; + const { errorMap, invalid_type_error, required_error, description } = params; + if (errorMap && (invalid_type_error || required_error)) { + throw new Error(`Can't use "invalid_type_error" or "required_error" in conjunction with custom error map.`); } - if (iss.code !== "invalid_type") return { message: ctx.defaultError }; - return { message: message ?? invalid_type_error ?? ctx.defaultError }; - }; - return { errorMap: customMap, description }; + if (errorMap) return { errorMap: errorMap, description }; + const customMap: ZodErrorMap = (iss, ctx) => { + const { message } = params; + + if (iss.code === 'invalid_enum_value') { + return { message: message ?? ctx.defaultError }; + } + if (typeof ctx.data === 'undefined') { + return { message: message ?? required_error ?? ctx.defaultError }; + } + if (iss.code !== 'invalid_type') return { message: ctx.defaultError }; + return { message: message ?? invalid_type_error ?? ctx.defaultError }; + }; + return { errorMap: customMap, description }; } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 664ed4520..d056707fe 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1,927 +1,879 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable no-constant-binary-expression */ /* eslint-disable @typescript-eslint/no-unused-expressions */ -import { Server } from "./index.js"; -import { z } from "zod"; +import { Server } from './index.js'; +import { z } from 'zod'; import { - RequestSchema, - NotificationSchema, - ResultSchema, - LATEST_PROTOCOL_VERSION, - SUPPORTED_PROTOCOL_VERSIONS, - CreateMessageRequestSchema, - ElicitRequestSchema, - ListPromptsRequestSchema, - ListResourcesRequestSchema, - ListToolsRequestSchema, - SetLevelRequestSchema, - ErrorCode, - LoggingMessageNotification -} from "../types.js"; -import { Transport } from "../shared/transport.js"; -import { InMemoryTransport } from "../inMemory.js"; -import { Client } from "../client/index.js"; - -test("should accept latest protocol version", async () => { - let sendPromiseResolve: (value: unknown) => void; - const sendPromise = new Promise((resolve) => { - sendPromiseResolve = resolve; - }); - - const serverTransport: Transport = { - start: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - send: jest.fn().mockImplementation((message) => { - if (message.id === 1 && message.result) { - expect(message.result).toEqual({ - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: expect.any(Object), - serverInfo: { - name: "test server", - version: "1.0", - }, - instructions: "Test instructions", - }); - sendPromiseResolve(undefined); - } - return Promise.resolve(); - }), - }; - - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - instructions: "Test instructions", - }, - ); - - await server.connect(serverTransport); - - // Simulate initialize request with latest version - serverTransport.onmessage?.({ - jsonrpc: "2.0", - id: 1, - method: "initialize", - params: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: {}, - clientInfo: { - name: "test client", - version: "1.0", - }, - }, - }); - - await expect(sendPromise).resolves.toBeUndefined(); -}); + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, + CreateMessageRequestSchema, + ElicitRequestSchema, + ListPromptsRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + SetLevelRequestSchema, + ErrorCode, + LoggingMessageNotification +} from '../types.js'; +import { Transport } from '../shared/transport.js'; +import { InMemoryTransport } from '../inMemory.js'; +import { Client } from '../client/index.js'; + +test('should accept latest protocol version', async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise(resolve => { + sendPromiseResolve = resolve; + }); -test("should accept supported older protocol version", async () => { - const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; - let sendPromiseResolve: (value: unknown) => void; - const sendPromise = new Promise((resolve) => { - sendPromiseResolve = resolve; - }); - - const serverTransport: Transport = { - start: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - send: jest.fn().mockImplementation((message) => { - if (message.id === 1 && message.result) { - expect(message.result).toEqual({ - protocolVersion: OLD_VERSION, - capabilities: expect.any(Object), - serverInfo: { - name: "test server", - version: "1.0", - }, - }); - sendPromiseResolve(undefined); - } - return Promise.resolve(); - }), - }; - - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - }, - ); - - await server.connect(serverTransport); - - // Simulate initialize request with older version - serverTransport.onmessage?.({ - jsonrpc: "2.0", - id: 1, - method: "initialize", - params: { - protocolVersion: OLD_VERSION, - capabilities: {}, - clientInfo: { - name: "test client", - version: "1.0", - }, - }, - }); - - await expect(sendPromise).resolves.toBeUndefined(); -}); + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation(message => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: 'test server', + version: '1.0' + }, + instructions: 'Test instructions' + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }) + }; -test("should handle unsupported protocol version", async () => { - let sendPromiseResolve: (value: unknown) => void; - const sendPromise = new Promise((resolve) => { - sendPromiseResolve = resolve; - }); - - const serverTransport: Transport = { - start: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - send: jest.fn().mockImplementation((message) => { - if (message.id === 1 && message.result) { - expect(message.result).toEqual({ - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: expect.any(Object), - serverInfo: { - name: "test server", - version: "1.0", - }, - }); - sendPromiseResolve(undefined); - } - return Promise.resolve(); - }), - }; - - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - }, - ); - - await server.connect(serverTransport); - - // Simulate initialize request with unsupported version - serverTransport.onmessage?.({ - jsonrpc: "2.0", - id: 1, - method: "initialize", - params: { - protocolVersion: "invalid-version", - capabilities: {}, - clientInfo: { - name: "test client", - version: "1.0", - }, - }, - }); - - await expect(sendPromise).resolves.toBeUndefined(); + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + instructions: 'Test instructions' + } + ); + + await server.connect(serverTransport); + + // Simulate initialize request with latest version + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + clientInfo: { + name: 'test client', + version: '1.0' + } + } + }); + + await expect(sendPromise).resolves.toBeUndefined(); }); -test("should respect client capabilities", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - // Implement request handler for sampling/createMessage - client.setRequestHandler(CreateMessageRequestSchema, async (request) => { - // Mock implementation of createMessage - return { - model: "test-model", - role: "assistant", - content: { - type: "text", - text: "This is a test response", - }, +test('should accept supported older protocol version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise(resolve => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation(message => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: OLD_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: 'test server', + version: '1.0' + } + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }) }; - }); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + await server.connect(serverTransport); + + // Simulate initialize request with older version + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: OLD_VERSION, + capabilities: {}, + clientInfo: { + name: 'test client', + version: '1.0' + } + } + }); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + await expect(sendPromise).resolves.toBeUndefined(); +}); - expect(server.getClientCapabilities()).toEqual({ sampling: {} }); +test('should handle unsupported protocol version', async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise(resolve => { + sendPromiseResolve = resolve; + }); - // This should work because sampling is supported by the client - await expect( - server.createMessage({ - messages: [], - maxTokens: 10, - }), - ).resolves.not.toThrow(); + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation(message => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: 'test server', + version: '1.0' + } + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }) + }; - // This should still throw because roots are not supported by the client - await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + await server.connect(serverTransport); + + // Simulate initialize request with unsupported version + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: 'invalid-version', + capabilities: {}, + clientInfo: { + name: 'test client', + version: '1.0' + } + } + }); + + await expect(sendPromise).resolves.toBeUndefined(); }); -test("should respect client elicitation capabilities", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - elicitation: {}, - }, - }, - ); - - client.setRequestHandler(ElicitRequestSchema, (params) => ({ - action: "accept", - content: { - username: params.params.message.includes("username") ? "test-user" : undefined, - confirmed: true, - }, - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - expect(server.getClientCapabilities()).toEqual({ elicitation: {} }); - - // This should work because elicitation is supported by the client - await expect( - server.elicitInput({ - message: "Please provide your username", - requestedSchema: { - type: "object", - properties: { - username: { - type: "string", - title: "Username", - description: "Your username", - }, - confirmed: { - type: "boolean", - title: "Confirm", - description: "Please confirm", - default: false, - }, +test('should respect client capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' }, - required: ["username"], - }, - }), - ).resolves.toEqual({ - action: "accept", - content: { - username: "test-user", - confirmed: true, - }, - }); - - // This should still throw because sampling is not supported by the client - await expect( - server.createMessage({ - messages: [], - maxTokens: 10, - }), - ).rejects.toThrow(/^Client does not support/); + { + capabilities: { + sampling: {} + } + } + ); + + // Implement request handler for sampling/createMessage + client.setRequestHandler(CreateMessageRequestSchema, async request => { + // Mock implementation of createMessage + return { + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'This is a test response' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(server.getClientCapabilities()).toEqual({ sampling: {} }); + + // This should work because sampling is supported by the client + await expect( + server.createMessage({ + messages: [], + maxTokens: 10 + }) + ).resolves.not.toThrow(); + + // This should still throw because roots are not supported by the client + await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); }); -test("should validate elicitation response against requested schema", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - elicitation: {}, - }, - }, - ); - - // Set up client to return valid response - client.setRequestHandler(ElicitRequestSchema, (request) => ({ - action: "accept", - content: { - name: "John Doe", - email: "john@example.com", - age: 30, - }, - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Test with valid response - await expect( - server.elicitInput({ - message: "Please provide your information", - requestedSchema: { - type: "object", - properties: { - name: { - type: "string", - minLength: 1, - }, - email: { - type: "string", - minLength: 1, - }, - age: { - type: "integer", - minimum: 0, - maximum: 150, - }, +test('should respect client elicitation capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' }, - required: ["name", "email"], - }, - }), - ).resolves.toEqual({ - action: "accept", - content: { - name: "John Doe", - email: "john@example.com", - age: 30, - }, - }); + { + capabilities: { + elicitation: {} + } + } + ); + + client.setRequestHandler(ElicitRequestSchema, params => ({ + action: 'accept', + content: { + username: params.params.message.includes('username') ? 'test-user' : undefined, + confirmed: true + } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(server.getClientCapabilities()).toEqual({ elicitation: {} }); + + // This should work because elicitation is supported by the client + await expect( + server.elicitInput({ + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { + type: 'string', + title: 'Username', + description: 'Your username' + }, + confirmed: { + type: 'boolean', + title: 'Confirm', + description: 'Please confirm', + default: false + } + }, + required: ['username'] + } + }) + ).resolves.toEqual({ + action: 'accept', + content: { + username: 'test-user', + confirmed: true + } + }); + + // This should still throw because sampling is not supported by the client + await expect( + server.createMessage({ + messages: [], + maxTokens: 10 + }) + ).rejects.toThrow(/^Client does not support/); }); -test("should reject elicitation response with invalid data", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - elicitation: {}, - }, - }, - ); - - // Set up client to return invalid response (missing required field, invalid age) - client.setRequestHandler(ElicitRequestSchema, (request) => ({ - action: "accept", - content: { - email: "", // Invalid - too short - age: -5, // Invalid age - }, - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Test with invalid response - await expect( - server.elicitInput({ - message: "Please provide your information", - requestedSchema: { - type: "object", - properties: { - name: { - type: "string", - minLength: 1, - }, - email: { - type: "string", - minLength: 1, - }, - age: { - type: "integer", - minimum: 0, - maximum: 150, - }, +test('should validate elicitation response against requested schema', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' }, - required: ["name", "email"], - }, - }), - ).rejects.toThrow(/does not match requested schema/); + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up client to return valid response + client.setRequestHandler(ElicitRequestSchema, request => ({ + action: 'accept', + content: { + name: 'John Doe', + email: 'john@example.com', + age: 30 + } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Test with valid response + await expect( + server.elicitInput({ + message: 'Please provide your information', + requestedSchema: { + type: 'object', + properties: { + name: { + type: 'string', + minLength: 1 + }, + email: { + type: 'string', + minLength: 1 + }, + age: { + type: 'integer', + minimum: 0, + maximum: 150 + } + }, + required: ['name', 'email'] + } + }) + ).resolves.toEqual({ + action: 'accept', + content: { + name: 'John Doe', + email: 'john@example.com', + age: 30 + } + }); }); -test("should allow elicitation reject and cancel without validation", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - elicitation: {}, - }, - }, - ); - - let requestCount = 0; - client.setRequestHandler(ElicitRequestSchema, (request) => { - requestCount++; - if (requestCount === 1) { - return { action: "decline" }; - } else { - return { action: "cancel" }; - } - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const schema = { - type: "object" as const, - properties: { - name: { type: "string" as const }, - }, - required: ["name"], - }; - - // Test reject - should not validate - await expect( - server.elicitInput({ - message: "Please provide your name", - requestedSchema: schema, - }), - ).resolves.toEqual({ - action: "decline", - }); - - // Test cancel - should not validate - await expect( - server.elicitInput({ - message: "Please provide your name", - requestedSchema: schema, - }), - ).resolves.toEqual({ - action: "cancel", - }); +test('should reject elicitation response with invalid data', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up client to return invalid response (missing required field, invalid age) + client.setRequestHandler(ElicitRequestSchema, request => ({ + action: 'accept', + content: { + email: '', // Invalid - too short + age: -5 // Invalid age + } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Test with invalid response + await expect( + server.elicitInput({ + message: 'Please provide your information', + requestedSchema: { + type: 'object', + properties: { + name: { + type: 'string', + minLength: 1 + }, + email: { + type: 'string', + minLength: 1 + }, + age: { + type: 'integer', + minimum: 0, + maximum: 150 + } + }, + required: ['name', 'email'] + } + }) + ).rejects.toThrow(/does not match requested schema/); }); -test("should respect server notification capabilities", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - logging: {}, - }, - enforceStrictCapabilities: true, - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await server.connect(serverTransport); - - // This should work because logging is supported by the server - await expect( - server.sendLoggingMessage({ - level: "info", - data: "Test log message", - }), - ).resolves.not.toThrow(); - - // This should throw because resource notificaitons are not supported by the server - await expect( - server.sendResourceUpdated({ uri: "test://resource" }), - ).rejects.toThrow(/^Server does not support/); +test('should allow elicitation reject and cancel without validation', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + let requestCount = 0; + client.setRequestHandler(ElicitRequestSchema, request => { + requestCount++; + if (requestCount === 1) { + return { action: 'decline' }; + } else { + return { action: 'cancel' }; + } + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const schema = { + type: 'object' as const, + properties: { + name: { type: 'string' as const } + }, + required: ['name'] + }; + + // Test reject - should not validate + await expect( + server.elicitInput({ + message: 'Please provide your name', + requestedSchema: schema + }) + ).resolves.toEqual({ + action: 'decline' + }); + + // Test cancel - should not validate + await expect( + server.elicitInput({ + message: 'Please provide your name', + requestedSchema: schema + }) + ).resolves.toEqual({ + action: 'cancel' + }); }); -test("should only allow setRequestHandler for declared capabilities", () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - }, - }, - ); - - // These should work because the capabilities are declared - expect(() => { - server.setRequestHandler(ListPromptsRequestSchema, () => ({ prompts: [] })); - }).not.toThrow(); - - expect(() => { - server.setRequestHandler(ListResourcesRequestSchema, () => ({ - resources: [], - })); - }).not.toThrow(); +test('should respect server notification capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await server.connect(serverTransport); + + // This should work because logging is supported by the server + await expect( + server.sendLoggingMessage({ + level: 'info', + data: 'Test log message' + }) + ).resolves.not.toThrow(); - // These should throw because the capabilities are not declared - expect(() => { - server.setRequestHandler(ListToolsRequestSchema, () => ({ tools: [] })); - }).toThrow(/^Server does not support tools/); + // This should throw because resource notificaitons are not supported by the server + await expect(server.sendResourceUpdated({ uri: 'test://resource' })).rejects.toThrow(/^Server does not support/); +}); + +test('should only allow setRequestHandler for declared capabilities', () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {} + } + } + ); - expect(() => { - server.setRequestHandler(SetLevelRequestSchema, () => ({})); - }).toThrow(/^Server does not support logging/); + // These should work because the capabilities are declared + expect(() => { + server.setRequestHandler(ListPromptsRequestSchema, () => ({ prompts: [] })); + }).not.toThrow(); + + expect(() => { + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + }).not.toThrow(); + + // These should throw because the capabilities are not declared + expect(() => { + server.setRequestHandler(ListToolsRequestSchema, () => ({ tools: [] })); + }).toThrow(/^Server does not support tools/); + + expect(() => { + server.setRequestHandler(SetLevelRequestSchema, () => ({})); + }).toThrow(/^Server does not support logging/); }); /* Test that custom request/notification/result schemas can be used with the Server class. */ -test("should typecheck", () => { - const GetWeatherRequestSchema = RequestSchema.extend({ - method: z.literal("weather/get"), - params: z.object({ - city: z.string(), - }), - }); - - const GetForecastRequestSchema = RequestSchema.extend({ - method: z.literal("weather/forecast"), - params: z.object({ - city: z.string(), - days: z.number(), - }), - }); - - const WeatherForecastNotificationSchema = NotificationSchema.extend({ - method: z.literal("weather/alert"), - params: z.object({ - severity: z.enum(["warning", "watch"]), - message: z.string(), - }), - }); - - const WeatherRequestSchema = GetWeatherRequestSchema.or( - GetForecastRequestSchema, - ); - const WeatherNotificationSchema = WeatherForecastNotificationSchema; - const WeatherResultSchema = ResultSchema.extend({ - temperature: z.number(), - conditions: z.string(), - }); - - type WeatherRequest = z.infer; - type WeatherNotification = z.infer; - type WeatherResult = z.infer; - - // Create a typed Server for weather data - const weatherServer = new Server< - WeatherRequest, - WeatherNotification, - WeatherResult - >( - { - name: "WeatherServer", - version: "1.0.0", - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {}, - logging: {}, - }, - }, - ); - - // Typecheck that only valid weather requests/notifications/results are allowed - weatherServer.setRequestHandler(GetWeatherRequestSchema, (request) => { - return { - temperature: 72, - conditions: "sunny", - }; - }); - - weatherServer.setNotificationHandler( - WeatherForecastNotificationSchema, - (notification) => { - console.log(`Weather alert: ${notification.params.message}`); - }, - ); +test('should typecheck', () => { + const GetWeatherRequestSchema = RequestSchema.extend({ + method: z.literal('weather/get'), + params: z.object({ + city: z.string() + }) + }); + + const GetForecastRequestSchema = RequestSchema.extend({ + method: z.literal('weather/forecast'), + params: z.object({ + city: z.string(), + days: z.number() + }) + }); + + const WeatherForecastNotificationSchema = NotificationSchema.extend({ + method: z.literal('weather/alert'), + params: z.object({ + severity: z.enum(['warning', 'watch']), + message: z.string() + }) + }); + + const WeatherRequestSchema = GetWeatherRequestSchema.or(GetForecastRequestSchema); + const WeatherNotificationSchema = WeatherForecastNotificationSchema; + const WeatherResultSchema = ResultSchema.extend({ + temperature: z.number(), + conditions: z.string() + }); + + type WeatherRequest = z.infer; + type WeatherNotification = z.infer; + type WeatherResult = z.infer; + + // Create a typed Server for weather data + const weatherServer = new Server( + { + name: 'WeatherServer', + version: '1.0.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + // Typecheck that only valid weather requests/notifications/results are allowed + weatherServer.setRequestHandler(GetWeatherRequestSchema, request => { + return { + temperature: 72, + conditions: 'sunny' + }; + }); + + weatherServer.setNotificationHandler(WeatherForecastNotificationSchema, notification => { + console.log(`Weather alert: ${notification.params.message}`); + }); }); -test("should handle server cancelling a request", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - // Set up client to delay responding to createMessage - client.setRequestHandler( - CreateMessageRequestSchema, - async (_request, extra) => { - await new Promise((resolve) => setTimeout(resolve, 1000)); - return { - model: "test", - role: "assistant", - content: { - type: "text", - text: "Test response", +test('should handle server cancelling a request', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + // Set up client to delay responding to createMessage + client.setRequestHandler(CreateMessageRequestSchema, async (_request, extra) => { + await new Promise(resolve => setTimeout(resolve, 1000)); + return { + model: 'test', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const createMessagePromise = server.createMessage( + { + messages: [], + maxTokens: 10 }, - }; - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Set up abort controller - const controller = new AbortController(); - - // Issue request but cancel it immediately - const createMessagePromise = server.createMessage( - { - messages: [], - maxTokens: 10, - }, - { - signal: controller.signal, - }, - ); - controller.abort("Cancelled by test"); - - // Request should be rejected - await expect(createMessagePromise).rejects.toBe("Cancelled by test"); + { + signal: controller.signal + } + ); + controller.abort('Cancelled by test'); + + // Request should be rejected + await expect(createMessagePromise).rejects.toBe('Cancelled by test'); }); -test("should handle request timeout", async () => { - const server = new Server( - { - name: "test server", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - // Set up client that delays responses - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - sampling: {}, - }, - }, - ); - - client.setRequestHandler( - CreateMessageRequestSchema, - async (_request, extra) => { - await new Promise((resolve, reject) => { - const timeout = setTimeout(resolve, 100); - extra.signal.addEventListener("abort", () => { - clearTimeout(timeout); - reject(extra.signal.reason); - }); - }); +test('should handle request timeout', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); - return { - model: "test", - role: "assistant", - content: { - type: "text", - text: "Test response", + // Set up client that delays responses + const client = new Client( + { + name: 'test client', + version: '1.0' }, - }; - }, - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - // Request with 0 msec timeout should fail immediately - await expect( - server.createMessage( - { - messages: [], - maxTokens: 10, - }, - { timeout: 0 }, - ), - ).rejects.toMatchObject({ - code: ErrorCode.RequestTimeout, - }); + { + capabilities: { + sampling: {} + } + } + ); + + client.setRequestHandler(CreateMessageRequestSchema, async (_request, extra) => { + await new Promise((resolve, reject) => { + const timeout = setTimeout(resolve, 100); + extra.signal.addEventListener('abort', () => { + clearTimeout(timeout); + reject(extra.signal.reason); + }); + }); + + return { + model: 'test', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Request with 0 msec timeout should fail immediately + await expect( + server.createMessage( + { + messages: [], + maxTokens: 10 + }, + { timeout: 0 } + ) + ).rejects.toMatchObject({ + code: ErrorCode.RequestTimeout + }); }); /* Test automatic log level handling for transports with and without sessionId */ -test("should respect log level for transport without sessionId", async () => { - +test('should respect log level for transport without sessionId', async () => { const server = new Server( { - name: "test server", - version: "1.0", + name: 'test server', + version: '1.0' }, { capabilities: { prompts: {}, resources: {}, tools: {}, - logging: {}, + logging: {} }, - enforceStrictCapabilities: true, - }, + enforceStrictCapabilities: true + } ); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - ); + const client = new Client({ + name: 'test client', + version: '1.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); expect(clientTransport.sessionId).toEqual(undefined); // Client sets logging level to warning - await client.setLoggingLevel("warning"); + await client.setLoggingLevel('warning'); // This one will make it through - const warningParams: LoggingMessageNotification["params"] = { - level: "warning", - logger: "test server", - data: "Warning message", + const warningParams: LoggingMessageNotification['params'] = { + level: 'warning', + logger: 'test server', + data: 'Warning message' }; // This one will not - const debugParams: LoggingMessageNotification["params"] = { - level: "debug", - logger: "test server", - data: "Debug message", + const debugParams: LoggingMessageNotification['params'] = { + level: 'debug', + logger: 'test server', + data: 'Debug message' }; // Test the one that makes it through - clientTransport.onmessage = jest.fn().mockImplementation((message) => { + clientTransport.onmessage = jest.fn().mockImplementation(message => { expect(message).toEqual({ - jsonrpc: "2.0", - method: "notifications/message", + jsonrpc: '2.0', + method: 'notifications/message', params: warningParams }); }); @@ -933,72 +885,64 @@ test("should respect log level for transport without sessionId", async () => { // This one will, triggering the above test in clientTransport.onmessage await server.sendLoggingMessage(warningParams); expect(clientTransport.onmessage).toHaveBeenCalled(); - }); -test("should respect log level for transport with sessionId", async () => { - +test('should respect log level for transport with sessionId', async () => { const server = new Server( { - name: "test server", - version: "1.0", + name: 'test server', + version: '1.0' }, { capabilities: { prompts: {}, resources: {}, tools: {}, - logging: {}, + logging: {} }, - enforceStrictCapabilities: true, - }, + enforceStrictCapabilities: true + } ); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - ); + const client = new Client({ + name: 'test client', + version: '1.0' + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); // Add a session id to the transports - const SESSION_ID = "test-session-id"; + const SESSION_ID = 'test-session-id'; clientTransport.sessionId = SESSION_ID; serverTransport.sessionId = SESSION_ID; expect(clientTransport.sessionId).toBeDefined(); expect(serverTransport.sessionId).toBeDefined(); - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); // Client sets logging level to warning - await client.setLoggingLevel("warning"); + await client.setLoggingLevel('warning'); // This one will make it through - const warningParams: LoggingMessageNotification["params"] = { - level: "warning", - logger: "test server", - data: "Warning message", + const warningParams: LoggingMessageNotification['params'] = { + level: 'warning', + logger: 'test server', + data: 'Warning message' }; // This one will not - const debugParams: LoggingMessageNotification["params"] = { - level: "debug", - logger: "test server", - data: "Debug message", + const debugParams: LoggingMessageNotification['params'] = { + level: 'debug', + logger: 'test server', + data: 'Debug message' }; // Test the one that makes it through - clientTransport.onmessage = jest.fn().mockImplementation((message) => { + clientTransport.onmessage = jest.fn().mockImplementation(message => { expect(message).toEqual({ - jsonrpc: "2.0", - method: "notifications/message", + jsonrpc: '2.0', + method: 'notifications/message', params: warningParams }); }); @@ -1010,6 +954,4 @@ test("should respect log level for transport with sessionId", async () => { // This one will, triggering the above test in clientTransport.onmessage await server.sendLoggingMessage(warningParams, SESSION_ID); expect(clientTransport.onmessage).toHaveBeenCalled(); - }); - diff --git a/src/server/index.ts b/src/server/index.ts index 970657358..3eb0ba0d4 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,53 +1,48 @@ +import { mergeCapabilities, Protocol, ProtocolOptions, RequestOptions } from '../shared/protocol.js'; import { - mergeCapabilities, - Protocol, - ProtocolOptions, - RequestOptions, -} from "../shared/protocol.js"; -import { - ClientCapabilities, - CreateMessageRequest, - CreateMessageResultSchema, - ElicitRequest, - ElicitResult, - ElicitResultSchema, - EmptyResultSchema, - Implementation, - InitializedNotificationSchema, - InitializeRequest, - InitializeRequestSchema, - InitializeResult, - LATEST_PROTOCOL_VERSION, - ListRootsRequest, - ListRootsResultSchema, - LoggingMessageNotification, - McpError, - ErrorCode, - Notification, - Request, - ResourceUpdatedNotification, - Result, - ServerCapabilities, - ServerNotification, - ServerRequest, - ServerResult, - SUPPORTED_PROTOCOL_VERSIONS, - LoggingLevel, - SetLevelRequestSchema, - LoggingLevelSchema -} from "../types.js"; -import Ajv from "ajv"; + ClientCapabilities, + CreateMessageRequest, + CreateMessageResultSchema, + ElicitRequest, + ElicitResult, + ElicitResultSchema, + EmptyResultSchema, + Implementation, + InitializedNotificationSchema, + InitializeRequest, + InitializeRequestSchema, + InitializeResult, + LATEST_PROTOCOL_VERSION, + ListRootsRequest, + ListRootsResultSchema, + LoggingMessageNotification, + McpError, + ErrorCode, + Notification, + Request, + ResourceUpdatedNotification, + Result, + ServerCapabilities, + ServerNotification, + ServerRequest, + ServerResult, + SUPPORTED_PROTOCOL_VERSIONS, + LoggingLevel, + SetLevelRequestSchema, + LoggingLevelSchema +} from '../types.js'; +import Ajv from 'ajv'; export type ServerOptions = ProtocolOptions & { - /** - * Capabilities to advertise as being supported by this server. - */ - capabilities?: ServerCapabilities; - - /** - * Optional instructions describing how to use the server and its features. - */ - instructions?: string; + /** + * Capabilities to advertise as being supported by this server. + */ + capabilities?: ServerCapabilities; + + /** + * Optional instructions describing how to use the server and its features. + */ + instructions?: string; }; /** @@ -76,318 +71,251 @@ export type ServerOptions = ProtocolOptions & { * ``` */ export class Server< - RequestT extends Request = Request, - NotificationT extends Notification = Notification, - ResultT extends Result = Result, -> extends Protocol< - ServerRequest | RequestT, - ServerNotification | NotificationT, - ServerResult | ResultT -> { - private _clientCapabilities?: ClientCapabilities; - private _clientVersion?: Implementation; - private _capabilities: ServerCapabilities; - private _instructions?: string; - - /** - * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). - */ - oninitialized?: () => void; - - /** - * Initializes this server with the given name and version information. - */ - constructor( - private _serverInfo: Implementation, - options?: ServerOptions, - ) { - super(options); - this._capabilities = options?.capabilities ?? {}; - this._instructions = options?.instructions; - - this.setRequestHandler(InitializeRequestSchema, (request) => - this._oninitialize(request), - ); - this.setNotificationHandler(InitializedNotificationSchema, () => - this.oninitialized?.(), - ); - - if (this._capabilities.logging) { - this.setRequestHandler(SetLevelRequestSchema, async (request, extra) => { - const transportSessionId: string | undefined = extra.sessionId || extra.requestInfo?.headers['mcp-session-id'] as string || undefined; - const { level } = request.params; - const parseResult = LoggingLevelSchema.safeParse(level); - if (parseResult.success) { - this._loggingLevels.set(transportSessionId, parseResult.data); - } - return {}; - }) - } - } - - // Map log levels by session id - private _loggingLevels = new Map(); - - // Map LogLevelSchema to severity index - private readonly LOG_LEVEL_SEVERITY = new Map( - LoggingLevelSchema.options.map((level, index) => [level, index]) - ); - - // Is a message with the given level ignored in the log level set for the given session id? - private isMessageIgnored = (level: LoggingLevel, sessionId?: string): boolean => { - const currentLevel = this._loggingLevels.get(sessionId); - return (currentLevel) - ? this.LOG_LEVEL_SEVERITY.get(level)! < this.LOG_LEVEL_SEVERITY.get(currentLevel)! - : false; - }; - - /** - * Registers new capabilities. This can only be called before connecting to a transport. - * - * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). - */ - public registerCapabilities(capabilities: ServerCapabilities): void { - if (this.transport) { - throw new Error( - "Cannot register capabilities after connecting to transport", - ); - } - this._capabilities = mergeCapabilities(this._capabilities, capabilities); - } - - protected assertCapabilityForMethod(method: RequestT["method"]): void { - switch (method as ServerRequest["method"]) { - case "sampling/createMessage": - if (!this._clientCapabilities?.sampling) { - throw new Error( - `Client does not support sampling (required for ${method})`, - ); - } - break; + RequestT extends Request = Request, + NotificationT extends Notification = Notification, + ResultT extends Result = Result +> extends Protocol { + private _clientCapabilities?: ClientCapabilities; + private _clientVersion?: Implementation; + private _capabilities: ServerCapabilities; + private _instructions?: string; - case "elicitation/create": - if (!this._clientCapabilities?.elicitation) { - throw new Error( - `Client does not support elicitation (required for ${method})`, - ); - } - break; + /** + * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). + */ + oninitialized?: () => void; - case "roots/list": - if (!this._clientCapabilities?.roots) { - throw new Error( - `Client does not support listing roots (required for ${method})`, - ); + /** + * Initializes this server with the given name and version information. + */ + constructor( + private _serverInfo: Implementation, + options?: ServerOptions + ) { + super(options); + this._capabilities = options?.capabilities ?? {}; + this._instructions = options?.instructions; + + this.setRequestHandler(InitializeRequestSchema, request => this._oninitialize(request)); + this.setNotificationHandler(InitializedNotificationSchema, () => this.oninitialized?.()); + + if (this._capabilities.logging) { + this.setRequestHandler(SetLevelRequestSchema, async (request, extra) => { + const transportSessionId: string | undefined = + extra.sessionId || (extra.requestInfo?.headers['mcp-session-id'] as string) || undefined; + const { level } = request.params; + const parseResult = LoggingLevelSchema.safeParse(level); + if (parseResult.success) { + this._loggingLevels.set(transportSessionId, parseResult.data); + } + return {}; + }); } - break; - - case "ping": - // No specific capability required for ping - break; } - } - - protected assertNotificationCapability( - method: (ServerNotification | NotificationT)["method"], - ): void { - switch (method as ServerNotification["method"]) { - case "notifications/message": - if (!this._capabilities.logging) { - throw new Error( - `Server does not support logging (required for ${method})`, - ); + + // Map log levels by session id + private _loggingLevels = new Map(); + + // Map LogLevelSchema to severity index + private readonly LOG_LEVEL_SEVERITY = new Map(LoggingLevelSchema.options.map((level, index) => [level, index])); + + // Is a message with the given level ignored in the log level set for the given session id? + private isMessageIgnored = (level: LoggingLevel, sessionId?: string): boolean => { + const currentLevel = this._loggingLevels.get(sessionId); + return currentLevel ? this.LOG_LEVEL_SEVERITY.get(level)! < this.LOG_LEVEL_SEVERITY.get(currentLevel)! : false; + }; + + /** + * Registers new capabilities. This can only be called before connecting to a transport. + * + * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). + */ + public registerCapabilities(capabilities: ServerCapabilities): void { + if (this.transport) { + throw new Error('Cannot register capabilities after connecting to transport'); } - break; - - case "notifications/resources/updated": - case "notifications/resources/list_changed": - if (!this._capabilities.resources) { - throw new Error( - `Server does not support notifying about resources (required for ${method})`, - ); + this._capabilities = mergeCapabilities(this._capabilities, capabilities); + } + + protected assertCapabilityForMethod(method: RequestT['method']): void { + switch (method as ServerRequest['method']) { + case 'sampling/createMessage': + if (!this._clientCapabilities?.sampling) { + throw new Error(`Client does not support sampling (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!this._clientCapabilities?.elicitation) { + throw new Error(`Client does not support elicitation (required for ${method})`); + } + break; + + case 'roots/list': + if (!this._clientCapabilities?.roots) { + throw new Error(`Client does not support listing roots (required for ${method})`); + } + break; + + case 'ping': + // No specific capability required for ping + break; } - break; + } - case "notifications/tools/list_changed": - if (!this._capabilities.tools) { - throw new Error( - `Server does not support notifying of tool list changes (required for ${method})`, - ); + protected assertNotificationCapability(method: (ServerNotification | NotificationT)['method']): void { + switch (method as ServerNotification['method']) { + case 'notifications/message': + if (!this._capabilities.logging) { + throw new Error(`Server does not support logging (required for ${method})`); + } + break; + + case 'notifications/resources/updated': + case 'notifications/resources/list_changed': + if (!this._capabilities.resources) { + throw new Error(`Server does not support notifying about resources (required for ${method})`); + } + break; + + case 'notifications/tools/list_changed': + if (!this._capabilities.tools) { + throw new Error(`Server does not support notifying of tool list changes (required for ${method})`); + } + break; + + case 'notifications/prompts/list_changed': + if (!this._capabilities.prompts) { + throw new Error(`Server does not support notifying of prompt list changes (required for ${method})`); + } + break; + + case 'notifications/cancelled': + // Cancellation notifications are always allowed + break; + + case 'notifications/progress': + // Progress notifications are always allowed + break; } - break; + } - case "notifications/prompts/list_changed": - if (!this._capabilities.prompts) { - throw new Error( - `Server does not support notifying of prompt list changes (required for ${method})`, - ); + protected assertRequestHandlerCapability(method: string): void { + switch (method) { + case 'sampling/createMessage': + if (!this._capabilities.sampling) { + throw new Error(`Server does not support sampling (required for ${method})`); + } + break; + + case 'logging/setLevel': + if (!this._capabilities.logging) { + throw new Error(`Server does not support logging (required for ${method})`); + } + break; + + case 'prompts/get': + case 'prompts/list': + if (!this._capabilities.prompts) { + throw new Error(`Server does not support prompts (required for ${method})`); + } + break; + + case 'resources/list': + case 'resources/templates/list': + case 'resources/read': + if (!this._capabilities.resources) { + throw new Error(`Server does not support resources (required for ${method})`); + } + break; + + case 'tools/call': + case 'tools/list': + if (!this._capabilities.tools) { + throw new Error(`Server does not support tools (required for ${method})`); + } + break; + + case 'ping': + case 'initialize': + // No specific capability required for these methods + break; } - break; + } + + private async _oninitialize(request: InitializeRequest): Promise { + const requestedVersion = request.params.protocolVersion; + + this._clientCapabilities = request.params.capabilities; + this._clientVersion = request.params.clientInfo; - case "notifications/cancelled": - // Cancellation notifications are always allowed - break; + const protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) ? requestedVersion : LATEST_PROTOCOL_VERSION; - case "notifications/progress": - // Progress notifications are always allowed - break; + return { + protocolVersion, + capabilities: this.getCapabilities(), + serverInfo: this._serverInfo, + ...(this._instructions && { instructions: this._instructions }) + }; } - } - - protected assertRequestHandlerCapability(method: string): void { - switch (method) { - case "sampling/createMessage": - if (!this._capabilities.sampling) { - throw new Error( - `Server does not support sampling (required for ${method})`, - ); - } - break; - case "logging/setLevel": - if (!this._capabilities.logging) { - throw new Error( - `Server does not support logging (required for ${method})`, - ); - } - break; - - case "prompts/get": - case "prompts/list": - if (!this._capabilities.prompts) { - throw new Error( - `Server does not support prompts (required for ${method})`, - ); - } - break; - - case "resources/list": - case "resources/templates/list": - case "resources/read": - if (!this._capabilities.resources) { - throw new Error( - `Server does not support resources (required for ${method})`, - ); - } - break; - - case "tools/call": - case "tools/list": - if (!this._capabilities.tools) { - throw new Error( - `Server does not support tools (required for ${method})`, - ); - } - break; + /** + * After initialization has completed, this will be populated with the client's reported capabilities. + */ + getClientCapabilities(): ClientCapabilities | undefined { + return this._clientCapabilities; + } - case "ping": - case "initialize": - // No specific capability required for these methods - break; + /** + * After initialization has completed, this will be populated with information about the client's name and version. + */ + getClientVersion(): Implementation | undefined { + return this._clientVersion; } - } - private async _oninitialize( - request: InitializeRequest, - ): Promise { - const requestedVersion = request.params.protocolVersion; + private getCapabilities(): ServerCapabilities { + return this._capabilities; + } - this._clientCapabilities = request.params.capabilities; - this._clientVersion = request.params.clientInfo; + async ping() { + return this.request({ method: 'ping' }, EmptyResultSchema); + } - const protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) - ? requestedVersion - : LATEST_PROTOCOL_VERSION; + async createMessage(params: CreateMessageRequest['params'], options?: RequestOptions) { + return this.request({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options); + } - return { - protocolVersion, - capabilities: this.getCapabilities(), - serverInfo: this._serverInfo, - ...(this._instructions && { instructions: this._instructions }), - }; - } - - /** - * After initialization has completed, this will be populated with the client's reported capabilities. - */ - getClientCapabilities(): ClientCapabilities | undefined { - return this._clientCapabilities; - } - - /** - * After initialization has completed, this will be populated with information about the client's name and version. - */ - getClientVersion(): Implementation | undefined { - return this._clientVersion; - } - - private getCapabilities(): ServerCapabilities { - return this._capabilities; - } - - async ping() { - return this.request({ method: "ping" }, EmptyResultSchema); - } - - async createMessage( - params: CreateMessageRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "sampling/createMessage", params }, - CreateMessageResultSchema, - options, - ); - } - - async elicitInput( - params: ElicitRequest["params"], - options?: RequestOptions, - ): Promise { - const result = await this.request( - { method: "elicitation/create", params }, - ElicitResultSchema, - options, - ); - - // Validate the response content against the requested schema if action is "accept" - if (result.action === "accept" && result.content) { - try { - const ajv = new Ajv(); - - const validate = ajv.compile(params.requestedSchema); - const isValid = validate(result.content); - - if (!isValid) { - throw new McpError( - ErrorCode.InvalidParams, - `Elicitation response content does not match requested schema: ${ajv.errorsText(validate.errors)}`, - ); - } - } catch (error) { - if (error instanceof McpError) { - throw error; + async elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise { + const result = await this.request({ method: 'elicitation/create', params }, ElicitResultSchema, options); + + // Validate the response content against the requested schema if action is "accept" + if (result.action === 'accept' && result.content) { + try { + const ajv = new Ajv(); + + const validate = ajv.compile(params.requestedSchema); + const isValid = validate(result.content); + + if (!isValid) { + throw new McpError( + ErrorCode.InvalidParams, + `Elicitation response content does not match requested schema: ${ajv.errorsText(validate.errors)}` + ); + } + } catch (error) { + if (error instanceof McpError) { + throw error; + } + throw new McpError(ErrorCode.InternalError, `Error validating elicitation response: ${error}`); + } } - throw new McpError( - ErrorCode.InternalError, - `Error validating elicitation response: ${error}`, - ); - } - } - return result; - } + return result; + } - async listRoots( - params?: ListRootsRequest["params"], - options?: RequestOptions, - ) { - return this.request( - { method: "roots/list", params }, - ListRootsResultSchema, - options, - ); - } + async listRoots(params?: ListRootsRequest['params'], options?: RequestOptions) { + return this.request({ method: 'roots/list', params }, ListRootsResultSchema, options); + } /** * Sends a logging message to the client, if connected. @@ -396,32 +324,32 @@ export class Server< * @param params * @param sessionId optional for stateless and backward compatibility */ - async sendLoggingMessage(params: LoggingMessageNotification["params"], sessionId?: string) { - if (this._capabilities.logging) { - if (!this.isMessageIgnored(params.level, sessionId)) { - return this.notification({method: "notifications/message", params}) - } - } - } - - async sendResourceUpdated(params: ResourceUpdatedNotification["params"]) { - return this.notification({ - method: "notifications/resources/updated", - params, - }); - } - - async sendResourceListChanged() { - return this.notification({ - method: "notifications/resources/list_changed", - }); - } - - async sendToolListChanged() { - return this.notification({ method: "notifications/tools/list_changed" }); - } - - async sendPromptListChanged() { - return this.notification({ method: "notifications/prompts/list_changed" }); - } + async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { + if (this._capabilities.logging) { + if (!this.isMessageIgnored(params.level, sessionId)) { + return this.notification({ method: 'notifications/message', params }); + } + } + } + + async sendResourceUpdated(params: ResourceUpdatedNotification['params']) { + return this.notification({ + method: 'notifications/resources/updated', + params + }); + } + + async sendResourceListChanged() { + return this.notification({ + method: 'notifications/resources/list_changed' + }); + } + + async sendToolListChanged() { + return this.notification({ method: 'notifications/tools/list_changed' }); + } + + async sendPromptListChanged() { + return this.notification({ method: 'notifications/prompts/list_changed' }); + } } diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index d9142702f..4bb42d7fc 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1,4386 +1,4088 @@ -import { McpServer } from "./mcp.js"; -import { Client } from "../client/index.js"; -import { InMemoryTransport } from "../inMemory.js"; -import { z } from "zod"; +import { McpServer } from './mcp.js'; +import { Client } from '../client/index.js'; +import { InMemoryTransport } from '../inMemory.js'; +import { z } from 'zod'; import { - ListToolsResultSchema, - CallToolResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ReadResourceResultSchema, - ListPromptsResultSchema, - GetPromptResultSchema, - CompleteResultSchema, - LoggingMessageNotificationSchema, - Notification, - TextContent, - ElicitRequestSchema -} from "../types.js"; -import { ResourceTemplate } from "./mcp.js"; -import { completable } from "./completable.js"; -import { UriTemplate } from "../shared/uriTemplate.js"; -import { getDisplayName } from "../shared/metadataUtils.js"; - -describe("McpServer", () => { - /*** - * Test: Basic Server Instance - */ - test("should expose underlying Server instance", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + ListToolsResultSchema, + CallToolResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ListPromptsResultSchema, + GetPromptResultSchema, + CompleteResultSchema, + LoggingMessageNotificationSchema, + Notification, + TextContent, + ElicitRequestSchema +} from '../types.js'; +import { ResourceTemplate } from './mcp.js'; +import { completable } from './completable.js'; +import { UriTemplate } from '../shared/uriTemplate.js'; +import { getDisplayName } from '../shared/metadataUtils.js'; + +describe('McpServer', () => { + /*** + * Test: Basic Server Instance + */ + test('should expose underlying Server instance', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + expect(mcpServer.server).toBeDefined(); + }); + + /*** + * Test: Notification Sending via Server + */ + test('should allow sending notifications via Server', async () => { + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { capabilities: { logging: {} } } + ); + + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - expect(mcpServer.server).toBeDefined(); - }); - - /*** - * Test: Notification Sending via Server - */ - test("should allow sending notifications via Server", async () => { - const mcpServer = new McpServer( - { - name: "test server", - version: "1.0", - }, - { capabilities: { logging: {} } }, - ); - - const notifications: Notification[] = [] - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification) - } - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // This should work because we're using the underlying server - await expect( - mcpServer.server.sendLoggingMessage({ - level: "info", - data: "Test log message", - }), - ).resolves.not.toThrow(); - - expect(notifications).toMatchObject([ - { - "method": "notifications/message", - params: { - level: "info", - data: "Test log message", - } - } - ]) - }); - - /*** - * Test: Progress Notification with Message Field - */ - test("should send progress notifications with message field", async () => { - const mcpServer = new McpServer( - { - name: "test server", - version: "1.0", - } - ); - - // Create a tool that sends progress updates - mcpServer.tool( - "long-operation", - "A long running operation with progress updates", - { - steps: z.number().min(1).describe("Number of steps to perform"), - }, - async ({ steps }, { sendNotification, _meta }) => { - const progressToken = _meta?.progressToken; - - if (progressToken) { - // Send progress notification for each step - for (let i = 1; i <= steps; i++) { - await sendNotification({ - method: "notifications/progress", - params: { - progressToken, - progress: i, - total: steps, - message: `Completed step ${i} of ${steps}`, - }, - }); - } - } - - return { content: [{ type: "text" as const, text: `Operation completed with ${steps} steps` }] }; - } - ); - - const progressUpdates: Array<{ progress: number, total?: number, message?: string }> = []; - - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Call the tool with progress tracking - await client.request( - { - method: "tools/call", - params: { - name: "long-operation", - arguments: { steps: 3 }, - _meta: { - progressToken: "progress-test-1" - } - } - }, - CallToolResultSchema, - { - onprogress: (progress) => { - progressUpdates.push(progress); - } - } - ); - - // Verify progress notifications were received with message field - expect(progressUpdates).toHaveLength(3); - expect(progressUpdates[0]).toMatchObject({ - progress: 1, - total: 3, - message: "Completed step 1 of 3", - }); - expect(progressUpdates[1]).toMatchObject({ - progress: 2, - total: 3, - message: "Completed step 2 of 3", - }); - expect(progressUpdates[2]).toMatchObject({ - progress: 3, - total: 3, - message: "Completed step 3 of 3", - }); - }); -}); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); -describe("ResourceTemplate", () => { - /*** - * Test: ResourceTemplate Creation with String Pattern - */ - test("should create ResourceTemplate with string pattern", () => { - const template = new ResourceTemplate("test://{category}/{id}", { - list: undefined, - }); - expect(template.uriTemplate.toString()).toBe("test://{category}/{id}"); - expect(template.listCallback).toBeUndefined(); - }); - - /*** - * Test: ResourceTemplate Creation with UriTemplate Instance - */ - test("should create ResourceTemplate with UriTemplate", () => { - const uriTemplate = new UriTemplate("test://{category}/{id}"); - const template = new ResourceTemplate(uriTemplate, { list: undefined }); - expect(template.uriTemplate).toBe(uriTemplate); - expect(template.listCallback).toBeUndefined(); - }); - - /*** - * Test: ResourceTemplate with List Callback - */ - test("should create ResourceTemplate with list callback", async () => { - const list = jest.fn().mockResolvedValue({ - resources: [{ name: "Test", uri: "test://example" }], - }); + // This should work because we're using the underlying server + await expect( + mcpServer.server.sendLoggingMessage({ + level: 'info', + data: 'Test log message' + }) + ).resolves.not.toThrow(); + + expect(notifications).toMatchObject([ + { + method: 'notifications/message', + params: { + level: 'info', + data: 'Test log message' + } + } + ]); + }); + + /*** + * Test: Progress Notification with Message Field + */ + test('should send progress notifications with message field', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // Create a tool that sends progress updates + mcpServer.tool( + 'long-operation', + 'A long running operation with progress updates', + { + steps: z.number().min(1).describe('Number of steps to perform') + }, + async ({ steps }, { sendNotification, _meta }) => { + const progressToken = _meta?.progressToken; + + if (progressToken) { + // Send progress notification for each step + for (let i = 1; i <= steps; i++) { + await sendNotification({ + method: 'notifications/progress', + params: { + progressToken, + progress: i, + total: steps, + message: `Completed step ${i} of ${steps}` + } + }); + } + } + + return { content: [{ type: 'text' as const, text: `Operation completed with ${steps} steps` }] }; + } + ); - const template = new ResourceTemplate("test://{id}", { list }); - expect(template.listCallback).toBe(list); + const progressUpdates: Array<{ progress: number; total?: number; message?: string }> = []; - const abortController = new AbortController(); - const result = await template.listCallback?.({ - signal: abortController.signal, - requestId: 'not-implemented', - sendRequest: () => { throw new Error("Not implemented") }, - sendNotification: () => { throw new Error("Not implemented") } + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool with progress tracking + await client.request( + { + method: 'tools/call', + params: { + name: 'long-operation', + arguments: { steps: 3 }, + _meta: { + progressToken: 'progress-test-1' + } + } + }, + CallToolResultSchema, + { + onprogress: progress => { + progressUpdates.push(progress); + } + } + ); + + // Verify progress notifications were received with message field + expect(progressUpdates).toHaveLength(3); + expect(progressUpdates[0]).toMatchObject({ + progress: 1, + total: 3, + message: 'Completed step 1 of 3' + }); + expect(progressUpdates[1]).toMatchObject({ + progress: 2, + total: 3, + message: 'Completed step 2 of 3' + }); + expect(progressUpdates[2]).toMatchObject({ + progress: 3, + total: 3, + message: 'Completed step 3 of 3' + }); }); - expect(result?.resources).toHaveLength(1); - expect(list).toHaveBeenCalled(); - }); }); -describe("tool()", () => { - /*** - * Test: Zero-Argument Tool Registration - */ - test("should register zero-argument tool", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = [] - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification) - } - - mcpServer.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].inputSchema).toEqual({ - type: "object", - properties: {}, +describe('ResourceTemplate', () => { + /*** + * Test: ResourceTemplate Creation with String Pattern + */ + test('should create ResourceTemplate with string pattern', () => { + const template = new ResourceTemplate('test://{category}/{id}', { + list: undefined + }); + expect(template.uriTemplate.toString()).toBe('test://{category}/{id}'); + expect(template.listCallback).toBeUndefined(); + }); + + /*** + * Test: ResourceTemplate Creation with UriTemplate Instance + */ + test('should create ResourceTemplate with UriTemplate', () => { + const uriTemplate = new UriTemplate('test://{category}/{id}'); + const template = new ResourceTemplate(uriTemplate, { list: undefined }); + expect(template.uriTemplate).toBe(uriTemplate); + expect(template.listCallback).toBeUndefined(); + }); + + /*** + * Test: ResourceTemplate with List Callback + */ + test('should create ResourceTemplate with list callback', async () => { + const list = jest.fn().mockResolvedValue({ + resources: [{ name: 'Test', uri: 'test://example' }] + }); + + const template = new ResourceTemplate('test://{id}', { list }); + expect(template.listCallback).toBe(list); + + const abortController = new AbortController(); + const result = await template.listCallback?.({ + signal: abortController.signal, + requestId: 'not-implemented', + sendRequest: () => { + throw new Error('Not implemented'); + }, + sendNotification: () => { + throw new Error('Not implemented'); + } + }); + expect(result?.resources).toHaveLength(1); + expect(list).toHaveBeenCalled(); }); +}); - // Adding the tool before the connection was established means no notification was sent - expect(notifications).toHaveLength(0) - - // Adding another tool triggers the update notification - mcpServer.tool("test2", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - // Yield event loop to let the notification fly - await new Promise(process.nextTick) - - expect(notifications).toMatchObject([ - { - method: "notifications/tools/list_changed", - } - ]) - }); - - /*** - * Test: Updating Existing Tool - */ - test("should update existing tool", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = [] - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification) - } - - // Register initial tool - const tool = mcpServer.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Initial response", - }, - ], - })); - - // Update the tool - tool.update({ - callback: async () => ({ - content: [ - { - type: "text", - text: "Updated response", - }, - ], - }) - }); +describe('tool()', () => { + /*** + * Test: Zero-Argument Tool Registration + */ + test('should register zero-argument tool', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Call the tool and verify we get the updated response - const result = await client.request( - { - method: "tools/call", - params: { - name: "test", - }, - }, - CallToolResultSchema, - ); - - expect(result.content).toEqual([ - { - type: "text", - text: "Updated response", - }, - ]); - - // Update happened before transport was connected, so no notifications should be expected - expect(notifications).toHaveLength(0) - }); - - /*** - * Test: Updating Tool with Schema - */ - test("should update tool with schema", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = [] - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification) - } - - // Register initial tool - const tool = mcpServer.tool( - "test", - { - name: z.string(), - }, - async ({ name }) => ({ - content: [ - { - type: "text", - text: `Initial: ${name}`, - }, - ], - }), - ); - - // Update the tool with a different schema - tool.update({ - paramsSchema: { - name: z.string(), - value: z.number(), - }, - callback: async ({ name, value }) => ({ - content: [ - { - type: "text", - text: `Updated: ${name}, ${value}`, - }, - ], - }) - }); + mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Verify the schema was updated - const listResult = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(listResult.tools[0].inputSchema).toMatchObject({ - properties: { - name: { type: "string" }, - value: { type: "number" }, - }, - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Call the tool with the new schema - const callResult = await client.request( - { - method: "tools/call", - params: { - name: "test", - arguments: { - name: "test", - value: 42, - }, - }, - }, - CallToolResultSchema, - ); - - expect(callResult.content).toEqual([ - { - type: "text", - text: "Updated: test, 42", - }, - ]); - - // Update happened before transport was connected, so no notifications should be expected - expect(notifications).toHaveLength(0) - }); - - /*** - * Test: Tool List Changed Notifications - */ - test("should send tool list changed notifications when connected", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = [] - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification) - } - - // Register initial tool - const tool = mcpServer.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - expect(notifications).toHaveLength(0) - - // Now update the tool - tool.update({ - callback: async () => ({ - content: [ - { - type: "text", - text: "Updated response", - }, - ], - }) - }); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - // Yield event loop to let the notification fly - await new Promise(process.nextTick) + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].inputSchema).toEqual({ + type: 'object', + properties: {} + }); + + // Adding the tool before the connection was established means no notification was sent + expect(notifications).toHaveLength(0); + + // Adding another tool triggers the update notification + mcpServer.tool('test2', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); - expect(notifications).toMatchObject([ - { method: "notifications/tools/list_changed" } - ]) + // Yield event loop to let the notification fly + await new Promise(process.nextTick); - // Now delete the tool - tool.remove(); + expect(notifications).toMatchObject([ + { + method: 'notifications/tools/list_changed' + } + ]); + }); + + /*** + * Test: Updating Existing Tool + */ + test('should update existing tool', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - // Yield event loop to let the notification fly - await new Promise(process.nextTick) + // Register initial tool + const tool = mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Initial response' + } + ] + })); + + // Update the tool + tool.update({ + callback: async () => ({ + content: [ + { + type: 'text', + text: 'Updated response' + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool and verify we get the updated response + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test' + } + }, + CallToolResultSchema + ); - expect(notifications).toMatchObject([ - { method: "notifications/tools/list_changed" }, - { method: "notifications/tools/list_changed" }, - ]) - }); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Updated response' + } + ]); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Updating Tool with Schema + */ + test('should update tool with schema', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - /*** - * Test: Tool Registration with Parameters - */ - test("should register tool with params", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Register initial tool + const tool = mcpServer.tool( + 'test', + { + name: z.string() + }, + async ({ name }) => ({ + content: [ + { + type: 'text', + text: `Initial: ${name}` + } + ] + }) + ); + + // Update the tool with a different schema + tool.update({ + paramsSchema: { + name: z.string(), + value: z.number() + }, + callback: async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `Updated: ${name}, ${value}` + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify the schema was updated + const listResult = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); - // old api - mcpServer.tool( - "test", - { - name: z.string(), - value: z.number(), - }, - async ({ name, value }) => ({ - content: [ - { - type: "text", - text: `${name}: ${value}`, - }, - ], - }), - ); - - // new api - mcpServer.registerTool( - "test (new api)", - { - inputSchema: { name: z.string(), value: z.number() }, - }, - async ({ name, value }) => ({ - content: [{ type: "text", text: `${name}: ${value}` }], - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].inputSchema).toMatchObject({ - type: "object", - properties: { - name: { type: "string" }, - value: { type: "number" }, - }, - }); - expect(result.tools[1].name).toBe("test (new api)"); - expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); - }); - - /*** - * Test: Tool Registration with Description - */ - test("should register tool with description", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + expect(listResult.tools[0].inputSchema).toMatchObject({ + properties: { + name: { type: 'string' }, + value: { type: 'number' } + } + }); - // old api - mcpServer.tool("test", "Test description", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - // new api - mcpServer.registerTool( - "test (new api)", - { - description: "Test description", - }, - async () => ({ - content: [ - { - type: "text" as const, - text: "Test response", - }, - ], - }) - ); - - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].description).toBe("Test description"); - expect(result.tools[1].name).toBe("test (new api)"); - expect(result.tools[1].description).toBe("Test description"); - }); - - /*** - * Test: Tool Registration with Annotations - */ - test("should register tool with annotations", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Call the tool with the new schema + const callResult = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + name: 'test', + value: 42 + } + } + }, + CallToolResultSchema + ); - mcpServer.tool("test", { title: "Test Tool", readOnlyHint: true }, async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - mcpServer.registerTool( - "test (new api)", - { - annotations: { title: "Test Tool", readOnlyHint: true }, - }, - async () => ({ - content: [ - { - type: "text" as const, - text: "Test response", - }, - ], - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].annotations).toEqual({ title: "Test Tool", readOnlyHint: true }); - expect(result.tools[1].name).toBe("test (new api)"); - expect(result.tools[1].annotations).toEqual({ title: "Test Tool", readOnlyHint: true }); - }); - - /*** - * Test: Tool Registration with Parameters and Annotations - */ - test("should register tool with params and annotations", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + expect(callResult.content).toEqual([ + { + type: 'text', + text: 'Updated: test, 42' + } + ]); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Tool List Changed Notifications + */ + test('should send tool list changed notifications when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - mcpServer.tool( - "test", - { name: z.string() }, - { title: "Test Tool", readOnlyHint: true }, - async ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }] - }) - ); - - mcpServer.registerTool( - "test (new api)", - { - inputSchema: { name: z.string() }, - annotations: { title: "Test Tool", readOnlyHint: true }, - }, - async ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }] - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].inputSchema).toMatchObject({ - type: "object", - properties: { name: { type: "string" } } - }); - expect(result.tools[0].annotations).toEqual({ title: "Test Tool", readOnlyHint: true }); - expect(result.tools[1].name).toBe("test (new api)"); - expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); - expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); - }); - - /*** - * Test: Tool Registration with Description, Parameters, and Annotations - */ - test("should register tool with description, params, and annotations", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Register initial tool + const tool = mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + expect(notifications).toHaveLength(0); + + // Now update the tool + tool.update({ + callback: async () => ({ + content: [ + { + type: 'text', + text: 'Updated response' + } + ] + }) + }); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/tools/list_changed' }]); + + // Now delete the tool + tool.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([ + { method: 'notifications/tools/list_changed' }, + { method: 'notifications/tools/list_changed' } + ]); + }); + + /*** + * Test: Tool Registration with Parameters + */ + test('should register tool with params', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // old api + mcpServer.tool( + 'test', + { + name: z.string(), + value: z.number() + }, + async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `${name}: ${value}` + } + ] + }) + ); + + // new api + mcpServer.registerTool( + 'test (new api)', + { + inputSchema: { name: z.string(), value: z.number() } + }, + async ({ name, value }) => ({ + content: [{ type: 'text', text: `${name}: ${value}` }] + }) + ); - mcpServer.tool( - "test", - "A tool with everything", - { name: z.string() }, - { title: "Complete Test Tool", readOnlyHint: true, openWorldHint: false }, - async ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }] - }) - ); - - mcpServer.registerTool( - "test (new api)", - { - description: "A tool with everything", - inputSchema: { name: z.string() }, - annotations: { title: "Complete Test Tool", readOnlyHint: true, openWorldHint: false }, - }, - async ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }] - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].description).toBe("A tool with everything"); - expect(result.tools[0].inputSchema).toMatchObject({ - type: "object", - properties: { name: { type: "string" } } - }); - expect(result.tools[0].annotations).toEqual({ - title: "Complete Test Tool", - readOnlyHint: true, - openWorldHint: false - }); - expect(result.tools[1].name).toBe("test (new api)"); - expect(result.tools[1].description).toBe("A tool with everything"); - expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); - expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); - }); - - /*** - * Test: Tool Registration with Description, Empty Parameters, and Annotations - */ - test("should register tool with description, empty params, and annotations", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - mcpServer.tool( - "test", - "A tool with everything but empty params", - {}, - { title: "Complete Test Tool with empty params", readOnlyHint: true, openWorldHint: false }, - async () => ({ - content: [{ type: "text", text: "Test response" }] - }) - ); - - mcpServer.registerTool( - "test (new api)", - { - description: "A tool with everything but empty params", - inputSchema: {}, - annotations: { title: "Complete Test Tool with empty params", readOnlyHint: true, openWorldHint: false }, - }, - async () => ({ - content: [{ type: "text" as const, text: "Test response" }] - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].description).toBe("A tool with everything but empty params"); - expect(result.tools[0].inputSchema).toMatchObject({ - type: "object", - properties: {} - }); - expect(result.tools[0].annotations).toEqual({ - title: "Complete Test Tool with empty params", - readOnlyHint: true, - openWorldHint: false - }); - expect(result.tools[1].name).toBe("test (new api)"); - expect(result.tools[1].description).toBe("A tool with everything but empty params"); - expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); - expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); - }); - - /*** - * Test: Tool Argument Validation - */ - test("should validate tool args", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: { + name: { type: 'string' }, + value: { type: 'number' } + } + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + }); + + /*** + * Test: Tool Registration with Description + */ + test('should register tool with description', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // old api + mcpServer.tool('test', 'Test description', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + // new api + mcpServer.registerTool( + 'test (new api)', + { + description: 'Test description' + }, + async () => ({ + content: [ + { + type: 'text' as const, + text: 'Test response' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].description).toBe('Test description'); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].description).toBe('Test description'); + }); + + /*** + * Test: Tool Registration with Annotations + */ + test('should register tool with annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('test', { title: 'Test Tool', readOnlyHint: true }, async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + mcpServer.registerTool( + 'test (new api)', + { + annotations: { title: 'Test Tool', readOnlyHint: true } + }, + async () => ({ + content: [ + { + type: 'text' as const, + text: 'Test response' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].annotations).toEqual({ title: 'Test Tool', readOnlyHint: true }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].annotations).toEqual({ title: 'Test Tool', readOnlyHint: true }); + }); + + /*** + * Test: Tool Registration with Parameters and Annotations + */ + test('should register tool with params and annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('test', { name: z.string() }, { title: 'Test Tool', readOnlyHint: true }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + })); + + mcpServer.registerTool( + 'test (new api)', + { + inputSchema: { name: z.string() }, + annotations: { title: 'Test Tool', readOnlyHint: true } + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: { name: { type: 'string' } } + }); + expect(result.tools[0].annotations).toEqual({ title: 'Test Tool', readOnlyHint: true }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); + }); + + /*** + * Test: Tool Registration with Description, Parameters, and Annotations + */ + test('should register tool with description, params, and annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + 'A tool with everything', + { name: z.string() }, + { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + mcpServer.registerTool( + 'test (new api)', + { + description: 'A tool with everything', + inputSchema: { name: z.string() }, + annotations: { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false } + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].description).toBe('A tool with everything'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: { name: { type: 'string' } } + }); + expect(result.tools[0].annotations).toEqual({ + title: 'Complete Test Tool', + readOnlyHint: true, + openWorldHint: false + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].description).toBe('A tool with everything'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); + }); + + /*** + * Test: Tool Registration with Description, Empty Parameters, and Annotations + */ + test('should register tool with description, empty params, and annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + 'A tool with everything but empty params', + {}, + { title: 'Complete Test Tool with empty params', readOnlyHint: true, openWorldHint: false }, + async () => ({ + content: [{ type: 'text', text: 'Test response' }] + }) + ); + + mcpServer.registerTool( + 'test (new api)', + { + description: 'A tool with everything but empty params', + inputSchema: {}, + annotations: { title: 'Complete Test Tool with empty params', readOnlyHint: true, openWorldHint: false } + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Test response' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].description).toBe('A tool with everything but empty params'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: {} + }); + expect(result.tools[0].annotations).toEqual({ + title: 'Complete Test Tool with empty params', + readOnlyHint: true, + openWorldHint: false + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].description).toBe('A tool with everything but empty params'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); + }); + + /*** + * Test: Tool Argument Validation + */ + test('should validate tool args', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + { + name: z.string(), + value: z.number() + }, + async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `${name}: ${value}` + } + ] + }) + ); + + mcpServer.registerTool( + 'test (new api)', + { + inputSchema: { + name: z.string(), + value: z.number() + } + }, + async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `${name}: ${value}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + name: 'test', + value: 'not a number' + } + } + }, + CallToolResultSchema + ) + ).rejects.toThrow(/Invalid arguments/); + + await expect( + client.request( + { + method: 'tools/call', + params: { + name: 'test (new api)', + arguments: { + name: 'test', + value: 'not a number' + } + } + }, + CallToolResultSchema + ) + ).rejects.toThrow(/Invalid arguments/); + }); + + /*** + * Test: Preventing Duplicate Tool Registration + */ + test('should prevent duplicate tool registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + expect(() => { + mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response 2' + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Multiple Tool Registration + */ + test('should allow registering multiple tools', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.tool('tool1', () => ({ content: [] })); + + // This should also succeed and not throw about request handlers + mcpServer.tool('tool2', () => ({ content: [] })); + }); + + /*** + * Test: Tool with Output Schema and Structured Content + */ + test('should support tool with outputSchema and structuredContent', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with outputSchema + mcpServer.registerTool( + 'test', + { + description: 'Test tool with structured output', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string(), + timestamp: z.string() + } + }, + async ({ input }) => ({ + structuredContent: { + processedInput: input, + resultType: 'structured', + timestamp: '2023-01-01T00:00:00Z' + }, + content: [ + { + type: 'text', + text: JSON.stringify({ + processedInput: input, + resultType: 'structured', + timestamp: '2023-01-01T00:00:00Z' + }) + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Verify the tool registration includes outputSchema + const listResult = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(listResult.tools).toHaveLength(1); + expect(listResult.tools[0].outputSchema).toMatchObject({ + type: 'object', + properties: { + processedInput: { type: 'string' }, + resultType: { type: 'string' }, + timestamp: { type: 'string' } + }, + required: ['processedInput', 'resultType', 'timestamp'] + }); + + // Call the tool and verify it returns valid structuredContent + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + input: 'hello' + } + } + }, + CallToolResultSchema + ); + + expect(result.structuredContent).toBeDefined(); + const structuredContent = result.structuredContent as { + processedInput: string; + resultType: string; + timestamp: string; + }; + expect(structuredContent.processedInput).toBe('hello'); + expect(structuredContent.resultType).toBe('structured'); + expect(structuredContent.timestamp).toBe('2023-01-01T00:00:00Z'); + + // For backward compatibility, content is auto-generated from structuredContent + expect(result.content).toBeDefined(); + expect(result.content!).toHaveLength(1); + expect(result.content![0]).toMatchObject({ type: 'text' }); + const textContent = result.content![0] as TextContent; + expect(JSON.parse(textContent.text)).toEqual(result.structuredContent); + }); + + /*** + * Test: Tool with Output Schema Must Provide Structured Content + */ + test('should throw error when tool with outputSchema returns no structuredContent', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with outputSchema that returns only content without structuredContent + mcpServer.registerTool( + 'test', + { + description: 'Test tool with output schema but missing structured content', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string() + } + }, + async ({ input }) => ({ + // Only return content without structuredContent + content: [ + { + type: 'text', + text: `Processed: ${input}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool and expect it to throw an error + await expect( + client.callTool({ + name: 'test', + arguments: { + input: 'hello' + } + }) + ).rejects.toThrow(/Tool test has an output schema but no structured content was provided/); + }); + /*** + * Test: Tool with Output Schema Must Provide Structured Content + */ + test('should skip outputSchema validation when isError is true', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerTool( + 'test', + { + description: 'Test tool with output schema but missing structured content', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string() + } + }, + async ({ input }) => ({ + content: [ + { + type: 'text', + text: `Processed: ${input}` + } + ], + isError: true + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.callTool({ + name: 'test', + arguments: { + input: 'hello' + } + }) + ).resolves.toStrictEqual({ + content: [ + { + type: 'text', + text: `Processed: hello` + } + ], + isError: true + }); + }); + + /*** + * Test: Schema Validation Failure for Invalid Structured Content + */ + test('should fail schema validation when tool returns invalid structuredContent', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with outputSchema that returns invalid data + mcpServer.registerTool( + 'test', + { + description: 'Test tool with invalid structured output', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string(), + timestamp: z.string() + } + }, + async ({ input }) => ({ + content: [ + { + type: 'text', + text: JSON.stringify({ + processedInput: input, + resultType: 'structured', + // Missing required 'timestamp' field + someExtraField: 'unexpected' // Extra field not in schema + }) + } + ], + structuredContent: { + processedInput: input, + resultType: 'structured', + // Missing required 'timestamp' field + someExtraField: 'unexpected' // Extra field not in schema + } + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool and expect it to throw a server-side validation error + await expect( + client.callTool({ + name: 'test', + arguments: { + input: 'hello' + } + }) + ).rejects.toThrow(/Invalid structured content for tool test/); + }); + + /*** + * Test: Pass Session ID to Tool Callback + */ + test('should pass sessionId to tool callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedSessionId: string | undefined; + mcpServer.tool('test-tool', async extra => { + receivedSessionId = extra.sessionId; + return { + content: [ + { + type: 'text', + text: 'Test response' + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + // Set a test sessionId on the server transport + serverTransport.sessionId = 'test-session-123'; + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await client.request( + { + method: 'tools/call', + params: { + name: 'test-tool' + } + }, + CallToolResultSchema + ); + + expect(receivedSessionId).toBe('test-session-123'); }); - const client = new Client({ - name: "test client", - version: "1.0", + + /*** + * Test: Pass Request ID to Tool Callback + */ + test('should pass requestId to tool callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedRequestId: string | number | undefined; + mcpServer.tool('request-id-test', async extra => { + receivedRequestId = extra.requestId; + return { + content: [ + { + type: 'text', + text: `Received request ID: ${extra.requestId}` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'request-id-test' + } + }, + CallToolResultSchema + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.content && result.content[0].text).toContain('Received request ID:'); }); - mcpServer.tool( - "test", - { - name: z.string(), - value: z.number(), - }, - async ({ name, value }) => ({ - content: [ - { - type: "text", - text: `${name}: ${value}`, - }, - ], - }), - ); - - mcpServer.registerTool( - "test (new api)", - { - inputSchema: { - name: z.string(), - value: z.number(), - }, - }, - async ({ name, value }) => ({ - content: [ - { - type: "text", - text: `${name}: ${value}`, - }, - ], - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "tools/call", - params: { - name: "test", - arguments: { - name: "test", - value: "not a number", + /*** + * Test: Send Notification within Tool Call + */ + test('should provide sendNotification within tool call', async () => { + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' }, - }, - }, - CallToolResultSchema, - ), - ).rejects.toThrow(/Invalid arguments/); - - await expect( - client.request( - { - method: "tools/call", - params: { - name: "test (new api)", - arguments: { - name: "test", - value: "not a number", + { capabilities: { logging: {} } } + ); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedLogMessage: string | undefined; + const loggingMessage = 'hello here is log message 1'; + + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + receivedLogMessage = notification.params.data as string; + }); + + mcpServer.tool('test-tool', async ({ sendNotification }) => { + await sendNotification({ method: 'notifications/message', params: { level: 'debug', data: loggingMessage } }); + return { + content: [ + { + type: 'text', + text: 'Test response' + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + await client.request( + { + method: 'tools/call', + params: { + name: 'test-tool' + } }, - }, - }, - CallToolResultSchema, - ), - ).rejects.toThrow(/Invalid arguments/); - }); - - /*** - * Test: Preventing Duplicate Tool Registration - */ - test("should prevent duplicate tool registration", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + CallToolResultSchema + ); + expect(receivedLogMessage).toBe(loggingMessage); + }); + + /*** + * Test: Client to Server Tool Call + */ + test('should allow client to call server tools', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + 'Test tool', + { + input: z.string() + }, + async ({ input }) => ({ + content: [ + { + type: 'text', + text: `Processed: ${input}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - mcpServer.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - expect(() => { - mcpServer.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response 2", - }, - ], - })); - }).toThrow(/already registered/); - }); - - /*** - * Test: Multiple Tool Registration - */ - test("should allow registering multiple tools", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + input: 'hello' + } + } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'Processed: hello' + } + ]); }); - // This should succeed - mcpServer.tool("tool1", () => ({ content: [] })); + /*** + * Test: Graceful Tool Error Handling + */ + test('should handle server tool errors gracefully', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('error-test', async () => { + throw new Error('Tool execution failed'); + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'error-test' + } + }, + CallToolResultSchema + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Tool execution failed' + } + ]); + }); + + /*** + * Test: McpError for Invalid Tool Name + */ + test('should throw McpError for invalid tool name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('test-tool', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'tools/call', + params: { + name: 'nonexistent-tool' + } + }, + CallToolResultSchema + ) + ).rejects.toThrow(/Tool nonexistent-tool not found/); + }); + + /*** + * Test: Tool Registration with _meta field + */ + test('should register tool with _meta field and include it in list response', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + const metaData = { + author: 'test-author', + version: '1.2.3', + category: 'utility', + tags: ['test', 'example'] + }; + + mcpServer.registerTool( + 'test-with-meta', + { + description: 'A tool with _meta field', + inputSchema: { name: z.string() }, + _meta: metaData + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('test-with-meta'); + expect(result.tools[0].description).toBe('A tool with _meta field'); + expect(result.tools[0]._meta).toEqual(metaData); + }); + + /*** + * Test: Tool Registration without _meta field should have undefined _meta + */ + test('should register tool without _meta field and have undefined _meta in response', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerTool( + 'test-without-meta', + { + description: 'A tool without _meta field', + inputSchema: { name: z.string() } + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - // This should also succeed and not throw about request handlers - mcpServer.tool("tool2", () => ({ content: [] })); - }); + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); - /*** - * Test: Tool with Output Schema and Structured Content - */ - test("should support tool with outputSchema and structuredContent", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('test-without-meta'); + expect(result.tools[0]._meta).toBeUndefined(); }); +}); + +describe('resource()', () => { + /*** + * Test: Resource Registration with URI and Read Callback + */ + test('should register resource with uri and readCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe('test'); + expect(result.resources[0].uri).toBe('test://resource'); + }); + + /*** + * Test: Update Resource with URI + */ + test('should update resource with uri', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource + const resource = mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Initial content' + } + ] + })); + + // Update the resource + resource.update({ + callback: async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Updated content' + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Read the resource and verify we get the updated content + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents).toHaveLength(1); + expect(result.contents[0].text).toBe('Updated content'); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Update Resource Template + */ + test('should update resource template', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource template + const resourceTemplate = mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { list: undefined }), + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Initial content' + } + ] + }) + ); + + // Update the resource template + resourceTemplate.update({ + callback: async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Updated content' + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Read the resource and verify we get the updated content + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource/123' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents).toHaveLength(1); + expect(result.contents[0].text).toBe('Updated content'); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Resource List Changed Notification + */ + test('should send resource list changed notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource + const resource = mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + expect(notifications).toHaveLength(0); + + // Now update the resource while connected + resource.update({ + callback: async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Updated content' + } + ] + }) + }); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); + }); + + /*** + * Test: Remove Resource and Send Notification + */ + test('should remove resource and send notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resources + const resource1 = mcpServer.resource('resource1', 'test://resource1', async () => ({ + contents: [{ uri: 'test://resource1', text: 'Resource 1 content' }] + })); + + mcpServer.resource('resource2', 'test://resource2', async () => ({ + contents: [{ uri: 'test://resource2', text: 'Resource 2 content' }] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify both resources are registered + let result = await client.request({ method: 'resources/list' }, ListResourcesResultSchema); + + expect(result.resources).toHaveLength(2); + + expect(notifications).toHaveLength(0); + + // Remove a resource + resource1.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + // Should have sent notification + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); - const client = new Client({ - name: "test client", - version: "1.0", + // Verify the resource was removed + result = await client.request({ method: 'resources/list' }, ListResourcesResultSchema); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].uri).toBe('test://resource2'); }); - // Register a tool with outputSchema - mcpServer.registerTool( - "test", - { - description: "Test tool with structured output", - inputSchema: { - input: z.string(), - }, - outputSchema: { - processedInput: z.string(), - resultType: z.string(), - timestamp: z.string() - }, - }, - async ({ input }) => ({ - structuredContent: { - processedInput: input, - resultType: "structured", - timestamp: "2023-01-01T00:00:00Z" - }, - content: [ - { - type: "text", - text: JSON.stringify({ - processedInput: input, - resultType: "structured", - timestamp: "2023-01-01T00:00:00Z" + /*** + * Test: Remove Resource Template and Send Notification + */ + test('should remove resource template and send notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register resource template + const resourceTemplate = mcpServer.resource( + 'template', + new ResourceTemplate('test://resource/{id}', { list: undefined }), + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Template content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify template is registered + const result = await client.request({ method: 'resources/templates/list' }, ListResourceTemplatesResultSchema); + + expect(result.resourceTemplates).toHaveLength(1); + expect(notifications).toHaveLength(0); + + // Remove the template + resourceTemplate.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + // Should have sent notification + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); + + // Verify the template was removed + const result2 = await client.request({ method: 'resources/templates/list' }, ListResourceTemplatesResultSchema); + + expect(result2.resourceTemplates).toHaveLength(0); + }); + + /*** + * Test: Resource Registration with Metadata + */ + test('should register resource with metadata', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + 'test://resource', + { + description: 'Test resource', + mimeType: 'text/plain' + }, + async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].description).toBe('Test resource'); + expect(result.resources[0].mimeType).toBe('text/plain'); + }); + + /*** + * Test: Resource Template Registration + */ + test('should register resource template', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', new ResourceTemplate('test://resource/{id}', { list: undefined }), async () => ({ + contents: [ + { + uri: 'test://resource/123', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/templates/list' + }, + ListResourceTemplatesResultSchema + ); + + expect(result.resourceTemplates).toHaveLength(1); + expect(result.resourceTemplates[0].name).toBe('test'); + expect(result.resourceTemplates[0].uriTemplate).toBe('test://resource/{id}'); + }); + + /*** + * Test: Resource Template with List Callback + */ + test('should register resource template with listCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Resource 1', + uri: 'test://resource/1' + }, + { + name: 'Resource 2', + uri: 'test://resource/2' + } + ] + }) }), - }, - ] - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Verify the tool registration includes outputSchema - const listResult = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(listResult.tools).toHaveLength(1); - expect(listResult.tools[0].outputSchema).toMatchObject({ - type: "object", - properties: { - processedInput: { type: "string" }, - resultType: { type: "string" }, - timestamp: { type: "string" } - }, - required: ["processedInput", "resultType", "timestamp"] - }); + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); - // Call the tool and verify it returns valid structuredContent - const result = await client.request( - { - method: "tools/call", - params: { - name: "test", - arguments: { - input: "hello", - }, - }, - }, - CallToolResultSchema, - ); - - expect(result.structuredContent).toBeDefined(); - const structuredContent = result.structuredContent as { - processedInput: string; - resultType: string; - timestamp: string; - }; - expect(structuredContent.processedInput).toBe("hello"); - expect(structuredContent.resultType).toBe("structured"); - expect(structuredContent.timestamp).toBe("2023-01-01T00:00:00Z"); - - // For backward compatibility, content is auto-generated from structuredContent - expect(result.content).toBeDefined(); - expect(result.content!).toHaveLength(1); - expect(result.content![0]).toMatchObject({ type: "text" }); - const textContent = result.content![0] as TextContent; - expect(JSON.parse(textContent.text)).toEqual(result.structuredContent); - }); - - /*** - * Test: Tool with Output Schema Must Provide Structured Content - */ - test("should throw error when tool with outputSchema returns no structuredContent", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const client = new Client({ - name: "test client", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - // Register a tool with outputSchema that returns only content without structuredContent - mcpServer.registerTool( - "test", - { - description: "Test tool with output schema but missing structured content", - inputSchema: { - input: z.string(), - }, - outputSchema: { - processedInput: z.string(), - resultType: z.string(), - }, - }, - async ({ input }) => ({ - // Only return content without structuredContent - content: [ - { - type: "text", - text: `Processed: ${input}`, - }, - ], - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Call the tool and expect it to throw an error - await expect( - client.callTool({ - name: "test", - arguments: { - input: "hello", - }, - }), - ).rejects.toThrow(/Tool test has an output schema but no structured content was provided/); - }); - /*** - * Test: Tool with Output Schema Must Provide Structured Content - */ - test("should skip outputSchema validation when isError is true", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(2); + expect(result.resources[0].name).toBe('Resource 1'); + expect(result.resources[0].uri).toBe('test://resource/1'); + expect(result.resources[1].name).toBe('Resource 2'); + expect(result.resources[1].uri).toBe('test://resource/2'); + }); + + /*** + * Test: Template Variables to Read Callback + */ + test('should pass template variables to readCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}/{id}', { + list: undefined + }), + async (uri, { category, id }) => ({ + contents: [ + { + uri: uri.href, + text: `Category: ${category}, ID: ${id}` + } + ] + }) + ); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - mcpServer.registerTool( - "test", - { - description: "Test tool with output schema but missing structured content", - inputSchema: { - input: z.string(), - }, - outputSchema: { - processedInput: z.string(), - resultType: z.string(), - }, - }, - async ({ input }) => ({ - content: [ - { - type: "text", - text: `Processed: ${input}`, - }, - ], - isError: true, - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.callTool({ - name: "test", - arguments: { - input: "hello", - }, - }), - ).resolves.toStrictEqual({ - content: [ - { - type: "text", - text: `Processed: hello`, - }, - ], - isError: true, - }); - }); - - /*** - * Test: Schema Validation Failure for Invalid Structured Content - */ - test("should fail schema validation when tool returns invalid structuredContent", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource/books/123' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents[0].text).toBe('Category: books, ID: 123'); + }); + + /*** + * Test: Preventing Duplicate Resource Registration + */ + test('should prevent duplicate resource registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + expect(() => { + mcpServer.resource('test2', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content 2' + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Multiple Resource Registration + */ + test('should allow registering multiple resources', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.resource('resource1', 'test://resource1', async () => ({ + contents: [ + { + uri: 'test://resource1', + text: 'Test content 1' + } + ] + })); + + // This should also succeed and not throw about request handlers + mcpServer.resource('resource2', 'test://resource2', async () => ({ + contents: [ + { + uri: 'test://resource2', + text: 'Test content 2' + } + ] + })); + }); + + /*** + * Test: Preventing Duplicate Resource Template Registration + */ + test('should prevent duplicate resource template registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.resource('test', new ResourceTemplate('test://resource/{id}', { list: undefined }), async () => ({ + contents: [ + { + uri: 'test://resource/123', + text: 'Test content' + } + ] + })); + + expect(() => { + mcpServer.resource('test', new ResourceTemplate('test://resource/{id}', { list: undefined }), async () => ({ + contents: [ + { + uri: 'test://resource/123', + text: 'Test content 2' + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Graceful Resource Read Error Handling + */ + test('should handle resource read errors gracefully', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('error-test', 'test://error', async () => { + throw new Error('Resource read failed'); + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'resources/read', + params: { + uri: 'test://error' + } + }, + ReadResourceResultSchema + ) + ).rejects.toThrow(/Resource read failed/); + }); + + /*** + * Test: McpError for Invalid Resource URI + */ + test('should throw McpError for invalid resource URI', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'resources/read', + params: { + uri: 'test://nonexistent' + } + }, + ReadResourceResultSchema + ) + ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); + }); + + /*** + * Test: Registering a resource template with a complete callback should update server capabilities to advertise support for completion + */ + test('should advertise support for completion when a resource template with a complete callback is defined', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: () => ['books', 'movies', 'music'] + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + expect(client.getServerCapabilities()).toMatchObject({ completions: {} }); + }); + + /*** + * Test: Resource Template Parameter Completion + */ + test('should support completion of resource template parameters', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: () => ['books', 'movies', 'music'] + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'test://resource/{category}' + }, + argument: { + name: 'category', + value: '' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['books', 'movies', 'music']); + expect(result.completion.total).toBe(3); + }); + + /*** + * Test: Filtered Resource Template Parameter Completion + */ + test('should support filtered completion of resource template parameters', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: (test: string) => ['books', 'movies', 'music'].filter(value => value.startsWith(test)) + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'test://resource/{category}' + }, + argument: { + name: 'category', + value: 'm' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['movies', 'music']); + expect(result.completion.total).toBe(2); + }); + + /*** + * Test: Pass Request ID to Resource Callback + */ + test('should pass requestId to resource callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedRequestId: string | number | undefined; + mcpServer.resource('request-id-test', 'test://resource', async (_uri, extra) => { + receivedRequestId = extra.requestId; + return { + contents: [ + { + uri: 'test://resource', + text: `Received request ID: ${extra.requestId}` + } + ] + }; + }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Register a tool with outputSchema that returns invalid data - mcpServer.registerTool( - "test", - { - description: "Test tool with invalid structured output", - inputSchema: { - input: z.string(), - }, - outputSchema: { - processedInput: z.string(), - resultType: z.string(), - timestamp: z.string() - }, - }, - async ({ input }) => ({ - content: [ - { - type: "text", - text: JSON.stringify({ - processedInput: input, - resultType: "structured", - // Missing required 'timestamp' field - someExtraField: "unexpected" // Extra field not in schema - }), - }, - ], - structuredContent: { - processedInput: input, - resultType: "structured", - // Missing required 'timestamp' field - someExtraField: "unexpected" // Extra field not in schema - }, - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Call the tool and expect it to throw a server-side validation error - await expect( - client.callTool({ - name: "test", - arguments: { - input: "hello", - }, - }), - ).rejects.toThrow(/Invalid structured content for tool test/); - }); - - /*** - * Test: Pass Session ID to Tool Callback - */ - test("should pass sessionId to tool callback via RequestHandlerExtra", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource' + } + }, + ReadResourceResultSchema + ); - let receivedSessionId: string | undefined; - mcpServer.tool("test-tool", async (extra) => { - receivedSessionId = extra.sessionId; - return { - content: [ - { - type: "text", - text: "Test response", - }, - ], - }; + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.contents[0].text).toContain('Received request ID:'); }); +}); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Set a test sessionId on the server transport - serverTransport.sessionId = "test-session-123"; - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await client.request( - { - method: "tools/call", - params: { - name: "test-tool", - }, - }, - CallToolResultSchema, - ); - - expect(receivedSessionId).toBe("test-session-123"); - }); - - /*** - * Test: Pass Request ID to Tool Callback - */ - test("should pass requestId to tool callback via RequestHandlerExtra", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); +describe('prompt()', () => { + /*** + * Test: Zero-Argument Prompt Registration + */ + test('should register zero-argument prompt', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - let receivedRequestId: string | number | undefined; - mcpServer.tool("request-id-test", async (extra) => { - receivedRequestId = extra.requestId; - return { - content: [ - { - type: "text", - text: `Received request ID: ${extra.requestId}`, - }, - ], - }; - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/call", - params: { - name: "request-id-test", - }, - }, - CallToolResultSchema, - ); - - expect(receivedRequestId).toBeDefined(); - expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); - expect(result.content && result.content[0].text).toContain("Received request ID:"); - }); - - /*** - * Test: Send Notification within Tool Call - */ - test("should provide sendNotification within tool call", async () => { - const mcpServer = new McpServer( - { - name: "test server", - version: "1.0", - }, - { capabilities: { logging: {} } }, - ); - - const client = new Client({ - name: "test client", - version: "1.0", - }); + const result = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('test'); + expect(result.prompts[0].arguments).toBeUndefined(); + }); + /*** + * Test: Updating Existing Prompt + */ + test('should update existing prompt', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - let receivedLogMessage: string | undefined; - const loggingMessage = "hello here is log message 1"; + // Register initial prompt + const prompt = mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Initial response' + } + } + ] + })); + + // Update the prompt + prompt.update({ + callback: async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Updated response' + } + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the prompt and verify we get the updated response + const result = await client.request( + { + method: 'prompts/get', + params: { + name: 'test' + } + }, + GetPromptResultSchema + ); + + expect(result.messages).toHaveLength(1); + expect(result.messages[0].content.text).toBe('Updated response'); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Updating Prompt with Schema + */ + test('should update prompt with schema', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - receivedLogMessage = notification.params.data as string; - }); + // Register initial prompt + const prompt = mcpServer.prompt( + 'test', + { + name: z.string() + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Initial: ${name}` + } + } + ] + }) + ); + + // Update the prompt with a different schema + prompt.update({ + argsSchema: { + name: z.string(), + value: z.string() + }, + callback: async ({ name, value }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Updated: ${name}, ${value}` + } + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify the schema was updated + const listResult = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); - mcpServer.tool("test-tool", async ({ sendNotification }) => { - await sendNotification({ method: "notifications/message", params: { level: "debug", data: loggingMessage } }); - return { - content: [ - { - type: "text", - text: "Test response", - }, - ], - }; - }); + expect(listResult.prompts[0].arguments).toHaveLength(2); + expect(listResult.prompts[0].arguments?.map(a => a.name).sort()).toEqual(['name', 'value']); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - await client.request( - { - method: "tools/call", - params: { - name: "test-tool", - }, - }, - CallToolResultSchema, - ); - expect(receivedLogMessage).toBe(loggingMessage); - }); - - /*** - * Test: Client to Server Tool Call - */ - test("should allow client to call server tools", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + // Call the prompt with the new schema + const getResult = await client.request( + { + method: 'prompts/get', + params: { + name: 'test', + arguments: { + name: 'test', + value: 'value' + } + } + }, + GetPromptResultSchema + ); + + expect(getResult.messages).toHaveLength(1); + expect(getResult.messages[0].content.text).toBe('Updated: test, value'); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Prompt List Changed Notification + */ + test('should send prompt list changed notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Register initial prompt + const prompt = mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + expect(notifications).toHaveLength(0); + + // Now update the prompt while connected + prompt.update({ + callback: async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Updated response' + } + } + ] + }) + }); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/prompts/list_changed' }]); + }); + + /*** + * Test: Remove Prompt and Send Notification + */ + test('should remove prompt and send notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; - mcpServer.tool( - "test", - "Test tool", - { - input: z.string(), - }, - async ({ input }) => ({ - content: [ - { - type: "text", - text: `Processed: ${input}`, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/call", - params: { - name: "test", - arguments: { - input: "hello", - }, - }, - }, - CallToolResultSchema, - ); - - expect(result.content).toEqual([ - { - type: "text", - text: "Processed: hello", - }, - ]); - }); - - /*** - * Test: Graceful Tool Error Handling - */ - test("should handle server tool errors gracefully", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + // Register initial prompts + const prompt1 = mcpServer.prompt('prompt1', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Prompt 1 response' + } + } + ] + })); + + mcpServer.prompt('prompt2', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Prompt 2 response' + } + } + ] + })); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - mcpServer.tool("error-test", async () => { - throw new Error("Tool execution failed"); - }); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/call", - params: { - name: "error-test", - }, - }, - CallToolResultSchema, - ); - - expect(result.isError).toBe(true); - expect(result.content).toEqual([ - { - type: "text", - text: "Tool execution failed", - }, - ]); - }); - - /*** - * Test: McpError for Invalid Tool Name - */ - test("should throw McpError for invalid tool name", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + // Verify both prompts are registered + let result = await client.request({ method: 'prompts/list' }, ListPromptsResultSchema); - const client = new Client({ - name: "test client", - version: "1.0", - }); + expect(result.prompts).toHaveLength(2); + expect(result.prompts.map(p => p.name).sort()).toEqual(['prompt1', 'prompt2']); - mcpServer.tool("test-tool", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "tools/call", - params: { - name: "nonexistent-tool", - }, - }, - CallToolResultSchema, - ), - ).rejects.toThrow(/Tool nonexistent-tool not found/); - }); - - /*** - * Test: Tool Registration with _meta field - */ - test("should register tool with _meta field and include it in list response", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + expect(notifications).toHaveLength(0); - const metaData = { - author: "test-author", - version: "1.2.3", - category: "utility", - tags: ["test", "example"] - }; - - mcpServer.registerTool( - "test-with-meta", - { - description: "A tool with _meta field", - inputSchema: { name: z.string() }, - _meta: metaData, - }, - async ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }] - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("test-with-meta"); - expect(result.tools[0].description).toBe("A tool with _meta field"); - expect(result.tools[0]._meta).toEqual(metaData); - }); - - /*** - * Test: Tool Registration without _meta field should have undefined _meta - */ - test("should register tool without _meta field and have undefined _meta in response", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Remove a prompt + prompt1.remove(); - mcpServer.registerTool( - "test-without-meta", - { - description: "A tool without _meta field", - inputSchema: { name: z.string() }, - }, - async ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }] - }) - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("test-without-meta"); - expect(result.tools[0]._meta).toBeUndefined(); - }); -}); + // Yield event loop to let the notification fly + await new Promise(process.nextTick); -describe("resource()", () => { - /*** - * Test: Resource Registration with URI and Read Callback - */ - test("should register resource with uri and readCallback", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Should have sent notification + expect(notifications).toMatchObject([{ method: 'notifications/prompts/list_changed' }]); - mcpServer.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(1); - expect(result.resources[0].name).toBe("test"); - expect(result.resources[0].uri).toBe("test://resource"); - }); - - /*** - * Test: Update Resource with URI - */ - test("should update resource with uri", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial resource - const resource = mcpServer.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Initial content", - }, - ], - })); - - // Update the resource - resource.update({ - callback: async () => ({ - contents: [ - { - uri: "test://resource", - text: "Updated content", - }, - ], - }) - }); + // Verify the prompt was removed + result = await client.request({ method: 'prompts/list' }, ListPromptsResultSchema); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Read the resource and verify we get the updated content - const result = await client.request( - { - method: "resources/read", - params: { - uri: "test://resource", - }, - }, - ReadResourceResultSchema, - ); - - expect(result.contents).toHaveLength(1); - expect(result.contents[0].text).toBe("Updated content"); - - // Update happened before transport was connected, so no notifications should be expected - expect(notifications).toHaveLength(0); - }); - - /*** - * Test: Update Resource Template - */ - test("should update resource template", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial resource template - const resourceTemplate = mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Initial content", - }, - ], - }), - ); - - // Update the resource template - resourceTemplate.update({ - callback: async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Updated content", - }, - ], - }) + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('prompt2'); }); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Read the resource and verify we get the updated content - const result = await client.request( - { - method: "resources/read", - params: { - uri: "test://resource/123", - }, - }, - ReadResourceResultSchema, - ); - - expect(result.contents).toHaveLength(1); - expect(result.contents[0].text).toBe("Updated content"); - - // Update happened before transport was connected, so no notifications should be expected - expect(notifications).toHaveLength(0); - }); - - /*** - * Test: Resource List Changed Notification - */ - test("should send resource list changed notification when connected", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial resource - const resource = mcpServer.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - expect(notifications).toHaveLength(0); - - // Now update the resource while connected - resource.update({ - callback: async () => ({ - contents: [ - { - uri: "test://resource", - text: "Updated content", - }, - ], - }) - }); + /*** + * Test: Prompt Registration with Arguments Schema + */ + test('should register prompt with args schema', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); - // Yield event loop to let the notification fly - await new Promise(process.nextTick); - - expect(notifications).toMatchObject([ - { method: "notifications/resources/list_changed" } - ]); - }); - - /*** - * Test: Remove Resource and Send Notification - */ - test("should remove resource and send notification when connected", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial resources - const resource1 = mcpServer.resource("resource1", "test://resource1", async () => ({ - contents: [{ uri: "test://resource1", text: "Resource 1 content" }], - })); - - mcpServer.resource("resource2", "test://resource2", async () => ({ - contents: [{ uri: "test://resource2", text: "Resource 2 content" }], - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Verify both resources are registered - let result = await client.request( - { method: "resources/list" }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(2); - - expect(notifications).toHaveLength(0); - - // Remove a resource - resource1.remove() - - // Yield event loop to let the notification fly - await new Promise(process.nextTick); - - // Should have sent notification - expect(notifications).toMatchObject([ - { method: "notifications/resources/list_changed" } - ]); - - // Verify the resource was removed - result = await client.request( - { method: "resources/list" }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(1); - expect(result.resources[0].uri).toBe("test://resource2"); - }); - - /*** - * Test: Remove Resource Template and Send Notification - */ - test("should remove resource template and send notification when connected", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register resource template - const resourceTemplate = mcpServer.resource( - "template", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Template content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Verify template is registered - const result = await client.request( - { method: "resources/templates/list" }, - ListResourceTemplatesResultSchema, - ); - - expect(result.resourceTemplates).toHaveLength(1); - expect(notifications).toHaveLength(0); - - // Remove the template - resourceTemplate.remove() - - // Yield event loop to let the notification fly - await new Promise(process.nextTick); - - // Should have sent notification - expect(notifications).toMatchObject([ - { method: "notifications/resources/list_changed" } - ]); - - // Verify the template was removed - const result2 = await client.request( - { method: "resources/templates/list" }, - ListResourceTemplatesResultSchema, - ); - - expect(result2.resourceTemplates).toHaveLength(0); - }); - - /*** - * Test: Resource Registration with Metadata - */ - test("should register resource with metadata", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + mcpServer.prompt( + 'test', + { + name: z.string(), + value: z.string() + }, + async ({ name, value }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `${name}: ${value}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('test'); + expect(result.prompts[0].arguments).toEqual([ + { name: 'name', required: true }, + { name: 'value', required: true } + ]); + }); + + /*** + * Test: Prompt Registration with Description + */ + test('should register prompt with description', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt('test', 'Test description', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); - mcpServer.resource( - "test", - "test://resource", - { - description: "Test resource", - mimeType: "text/plain", - }, - async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(1); - expect(result.resources[0].description).toBe("Test resource"); - expect(result.resources[0].mimeType).toBe("text/plain"); - }); - - /*** - * Test: Resource Template Registration - */ - test("should register resource template", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async () => ({ - contents: [ - { - uri: "test://resource/123", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/templates/list", - }, - ListResourceTemplatesResultSchema, - ); - - expect(result.resourceTemplates).toHaveLength(1); - expect(result.resourceTemplates[0].name).toBe("test"); - expect(result.resourceTemplates[0].uriTemplate).toBe( - "test://resource/{id}", - ); - }); - - /*** - * Test: Resource Template with List Callback - */ - test("should register resource template with listCallback", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { - list: async () => ({ - resources: [ + const result = await client.request( { - name: "Resource 1", - uri: "test://resource/1", + method: 'prompts/list' }, + ListPromptsResultSchema + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('test'); + expect(result.prompts[0].description).toBe('Test description'); + }); + + /*** + * Test: Prompt Argument Validation + */ + test('should validate prompt args', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test', { - name: "Resource 2", - uri: "test://resource/2", + name: z.string(), + value: z.string().min(3) }, - ], - }), - }), - async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(2); - expect(result.resources[0].name).toBe("Resource 1"); - expect(result.resources[0].uri).toBe("test://resource/1"); - expect(result.resources[1].name).toBe("Resource 2"); - expect(result.resources[1].uri).toBe("test://resource/2"); - }); - - /*** - * Test: Template Variables to Read Callback - */ - test("should pass template variables to readCallback", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{category}/{id}", { - list: undefined, - }), - async (uri, { category, id }) => ({ - contents: [ - { - uri: uri.href, - text: `Category: ${category}, ID: ${id}`, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/read", - params: { - uri: "test://resource/books/123", - }, - }, - ReadResourceResultSchema, - ); - - expect(result.contents[0].text).toBe("Category: books, ID: 123"); - }); - - /*** - * Test: Preventing Duplicate Resource Registration - */ - test("should prevent duplicate resource registration", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + async ({ name, value }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `${name}: ${value}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'prompts/get', + params: { + name: 'test', + arguments: { + name: 'test', + value: 'ab' // Too short + } + } + }, + GetPromptResultSchema + ) + ).rejects.toThrow(/Invalid arguments/); + }); + + /*** + * Test: Preventing Duplicate Prompt Registration + */ + test('should prevent duplicate prompt registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + expect(() => { + mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response 2' + } + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Multiple Prompt Registration + */ + test('should allow registering multiple prompts', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.prompt('prompt1', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response 1' + } + } + ] + })); + + // This should also succeed and not throw about request handlers + mcpServer.prompt('prompt2', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response 2' + } + } + ] + })); + }); + + /*** + * Test: Prompt Registration with Arguments + */ + test('should allow registering prompts with arguments', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.prompt('echo', { message: z.string() }, ({ message }) => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please process this message: ${message}` + } + } + ] + })); + }); + + /*** + * Test: Resources and Prompts with Completion Handlers + */ + test('should allow registering both resources and prompts with completion handlers', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // Register a resource with completion + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: () => ['books', 'movies', 'music'] + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + // Register a prompt with completion + mcpServer.prompt('echo', { message: completable(z.string(), () => ['hello', 'world']) }, ({ message }) => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please process this message: ${message}` + } + } + ] + })); + }); + + /*** + * Test: McpError for Invalid Prompt Name + */ + test('should throw McpError for invalid prompt name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt('test-prompt', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); - mcpServer.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - expect(() => { - mcpServer.resource("test2", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content 2", - }, - ], - })); - }).toThrow(/already registered/); - }); - - /*** - * Test: Multiple Resource Registration - */ - test("should allow registering multiple resources", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // This should succeed - mcpServer.resource("resource1", "test://resource1", async () => ({ - contents: [ - { - uri: "test://resource1", - text: "Test content 1", - }, - ], - })); - - // This should also succeed and not throw about request handlers - mcpServer.resource("resource2", "test://resource2", async () => ({ - contents: [ - { - uri: "test://resource2", - text: "Test content 2", - }, - ], - })); - }); - - /*** - * Test: Preventing Duplicate Resource Template Registration - */ - test("should prevent duplicate resource template registration", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async () => ({ - contents: [ - { - uri: "test://resource/123", - text: "Test content", - }, - ], - }), - ); - - expect(() => { - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async () => ({ - contents: [ + await expect( + client.request( + { + method: 'prompts/get', + params: { + name: 'nonexistent-prompt' + } + }, + GetPromptResultSchema + ) + ).rejects.toThrow(/Prompt nonexistent-prompt not found/); + }); + + /*** + * Test: Registering a prompt with a completable argument should update server capabilities to advertise support for completion + */ + test('should advertise support for completion when a prompt with a completable argument is defined', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test-prompt', { - uri: "test://resource/123", - text: "Test content 2", + name: completable(z.string(), () => ['Alice', 'Bob', 'Charlie']) }, - ], - }), - ); - }).toThrow(/already registered/); - }); - - /*** - * Test: Graceful Resource Read Error Handling - */ - test("should handle resource read errors gracefully", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + expect(client.getServerCapabilities()).toMatchObject({ completions: {} }); + }); + + /*** + * Test: Prompt Argument Completion + */ + test('should support completion of prompt arguments', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test-prompt', + { + name: completable(z.string(), () => ['Alice', 'Bob', 'Charlie']) + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: '' + } + } + }, + CompleteResultSchema + ); - mcpServer.resource("error-test", "test://error", async () => { - throw new Error("Resource read failed"); + expect(result.completion.values).toEqual(['Alice', 'Bob', 'Charlie']); + expect(result.completion.total).toBe(3); }); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "resources/read", - params: { - uri: "test://error", - }, - }, - ReadResourceResultSchema, - ), - ).rejects.toThrow(/Resource read failed/); - }); - - /*** - * Test: McpError for Invalid Resource URI - */ - test("should throw McpError for invalid resource URI", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + /*** + * Test: Filtered Prompt Argument Completion + */ + test('should support filtered completion of prompt arguments', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); - mcpServer.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "resources/read", - params: { - uri: "test://nonexistent", - }, - }, - ReadResourceResultSchema, - ), - ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); - }); - - /*** - * Test: Registering a resource template with a complete callback should update server capabilities to advertise support for completion - */ - test("should advertise support for completion when a resource template with a complete callback is defined", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{category}", { - list: undefined, - complete: { - category: () => ["books", "movies", "music"], - }, - }), - async () => ({ - contents: [ - { - uri: "test://resource/test", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - expect(client.getServerCapabilities()).toMatchObject({ completions: {} }) - }) - - /*** - * Test: Resource Template Parameter Completion - */ - test("should support completion of resource template parameters", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + mcpServer.prompt( + 'test-prompt', + { + name: completable(z.string(), test => ['Alice', 'Bob', 'Charlie'].filter(value => value.startsWith(test))) + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'A' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['Alice']); + expect(result.completion.total).toBe(1); + }); + + /*** + * Test: Pass Request ID to Prompt Callback + */ + test('should pass requestId to prompt callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedRequestId: string | number | undefined; + mcpServer.prompt('request-id-test', async extra => { + receivedRequestId = extra.requestId; + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Received request ID: ${extra.requestId}` + } + } + ] + }; + }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{category}", { - list: undefined, - complete: { - category: () => ["books", "movies", "music"], - }, - }), - async () => ({ - contents: [ - { - uri: "test://resource/test", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "test://resource/{category}", - }, - argument: { - name: "category", - value: "", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["books", "movies", "music"]); - expect(result.completion.total).toBe(3); - }); - - /*** - * Test: Filtered Resource Template Parameter Completion - */ - test("should support filtered completion of resource template parameters", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const client = new Client({ - name: "test client", - version: "1.0", - }); + const result = await client.request( + { + method: 'prompts/get', + params: { + name: 'request-id-test' + } + }, + GetPromptResultSchema + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.messages[0].content.text).toContain('Received request ID:'); + }); + + /*** + * Test: Resource Template Metadata Priority + */ + test('should prioritize individual resource metadata over template metadata', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Resource 1', + uri: 'test://resource/1', + description: 'Individual resource description', + mimeType: 'text/plain' + }, + { + name: 'Resource 2', + uri: 'test://resource/2' + // This resource has no description or mimeType + } + ] + }) + }), + { + description: 'Template description', + mimeType: 'application/json' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{category}", { - list: undefined, - complete: { - category: (test: string) => - ["books", "movies", "music"].filter((value) => - value.startsWith(test), - ), - }, - }), - async () => ({ - contents: [ - { - uri: "test://resource/test", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "test://resource/{category}", - }, - argument: { - name: "category", - value: "m", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["movies", "music"]); - expect(result.completion.total).toBe(2); - }); - - /*** - * Test: Pass Request ID to Resource Callback - */ - test("should pass requestId to resource callback via RequestHandlerExtra", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const client = new Client({ - name: "test client", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - let receivedRequestId: string | number | undefined; - mcpServer.resource("request-id-test", "test://resource", async (_uri, extra) => { - receivedRequestId = extra.requestId; - return { - contents: [ - { - uri: "test://resource", - text: `Received request ID: ${extra.requestId}`, - }, - ], - }; - }); + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(2); + + // Resource 1 should have its own metadata + expect(result.resources[0].name).toBe('Resource 1'); + expect(result.resources[0].description).toBe('Individual resource description'); + expect(result.resources[0].mimeType).toBe('text/plain'); + + // Resource 2 should inherit template metadata + expect(result.resources[1].name).toBe('Resource 2'); + expect(result.resources[1].description).toBe('Template description'); + expect(result.resources[1].mimeType).toBe('application/json'); + }); + + /*** + * Test: Resource Template Metadata Overrides All Fields + */ + test('should allow resource to override all template metadata fields', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Overridden Name', + uri: 'test://resource/1', + description: 'Overridden description', + mimeType: 'text/markdown' + // Add any other metadata fields if they exist + } + ] + }) + }), + { + name: 'Template Name', + description: 'Template description', + mimeType: 'application/json' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/read", - params: { - uri: "test://resource", - }, - }, - ReadResourceResultSchema, - ); - - expect(receivedRequestId).toBeDefined(); - expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); - expect(result.contents[0].text).toContain("Received request ID:"); - }); -}); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); -describe("prompt()", () => { - /*** - * Test: Zero-Argument Prompt Registration - */ - test("should register zero-argument prompt", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - mcpServer.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("test"); - expect(result.prompts[0].arguments).toBeUndefined(); - }); - /*** - * Test: Updating Existing Prompt - */ - test("should update existing prompt", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial prompt - const prompt = mcpServer.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Initial response", - }, - }, - ], - })); - - // Update the prompt - prompt.update({ - callback: async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Updated response", + const result = await client.request( + { + method: 'resources/list' }, - }, - ], - }) - }); + ListResourcesResultSchema + ); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Call the prompt and verify we get the updated response - const result = await client.request( - { - method: "prompts/get", - params: { - name: "test", - }, - }, - GetPromptResultSchema, - ); - - expect(result.messages).toHaveLength(1); - expect(result.messages[0].content.text).toBe("Updated response"); - - // Update happened before transport was connected, so no notifications should be expected - expect(notifications).toHaveLength(0); - }); - - /*** - * Test: Updating Prompt with Schema - */ - test("should update prompt with schema", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", + expect(result.resources).toHaveLength(1); + + // All fields should be from the individual resource, not the template + expect(result.resources[0].name).toBe('Overridden Name'); + expect(result.resources[0].description).toBe('Overridden description'); + expect(result.resources[0].mimeType).toBe('text/markdown'); }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial prompt - const prompt = mcpServer.prompt( - "test", - { - name: z.string(), - }, - async ({ name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Initial: ${name}`, +}); + +describe('Tool title precedence', () => { + test('should follow correct title precedence: title → annotations.title → name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Tool 1: Only name + mcpServer.tool('tool_name_only', async () => ({ + content: [{ type: 'text', text: 'Response' }] + })); + + // Tool 2: Name and annotations.title + mcpServer.tool( + 'tool_with_annotations_title', + 'Tool with annotations title', + { + title: 'Annotations Title' }, - }, - ], - }), - ); - - // Update the prompt with a different schema - prompt.update({ - argsSchema: { - name: z.string(), - value: z.string(), - }, - callback: async ({ name, value }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Updated: ${name}, ${value}`, + async () => ({ + content: [{ type: 'text', text: 'Response' }] + }) + ); + + // Tool 3: Name and title (using registerTool) + mcpServer.registerTool( + 'tool_with_title', + { + title: 'Regular Title', + description: 'Tool with regular title' }, - }, - ], - }) - }); + async () => ({ + content: [{ type: 'text', text: 'Response' }] + }) + ); + + // Tool 4: All three - title should win + mcpServer.registerTool( + 'tool_with_all_titles', + { + title: 'Regular Title Wins', + description: 'Tool with all titles', + annotations: { + title: 'Annotations Title Should Not Show' + } + }, + async () => ({ + content: [{ type: 'text', text: 'Response' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(4); + + // Tool 1: Only name - should display name + const tool1 = result.tools.find(t => t.name === 'tool_name_only'); + expect(tool1).toBeDefined(); + expect(getDisplayName(tool1!)).toBe('tool_name_only'); + + // Tool 2: Name and annotations.title - should display annotations.title + const tool2 = result.tools.find(t => t.name === 'tool_with_annotations_title'); + expect(tool2).toBeDefined(); + expect(tool2!.annotations?.title).toBe('Annotations Title'); + expect(getDisplayName(tool2!)).toBe('Annotations Title'); + + // Tool 3: Name and title - should display title + const tool3 = result.tools.find(t => t.name === 'tool_with_title'); + expect(tool3).toBeDefined(); + expect(tool3!.title).toBe('Regular Title'); + expect(getDisplayName(tool3!)).toBe('Regular Title'); + + // Tool 4: All three - title should take precedence + const tool4 = result.tools.find(t => t.name === 'tool_with_all_titles'); + expect(tool4).toBeDefined(); + expect(tool4!.title).toBe('Regular Title Wins'); + expect(tool4!.annotations?.title).toBe('Annotations Title Should Not Show'); + expect(getDisplayName(tool4!)).toBe('Regular Title Wins'); + }); + + test('getDisplayName unit tests for title precedence', () => { + // Test 1: Only name + expect(getDisplayName({ name: 'tool_name' })).toBe('tool_name'); + + // Test 2: Name and title - title wins + expect( + getDisplayName({ + name: 'tool_name', + title: 'Tool Title' + }) + ).toBe('Tool Title'); + + // Test 3: Name and annotations.title - annotations.title wins + expect( + getDisplayName({ + name: 'tool_name', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + + // Test 4: All three - title wins (correct precedence) + expect( + getDisplayName({ + name: 'tool_name', + title: 'Regular Title', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Regular Title'); + + // Test 5: Empty title should not be used + expect( + getDisplayName({ + name: 'tool_name', + title: '', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + + // Test 6: Undefined vs null handling + expect( + getDisplayName({ + name: 'tool_name', + title: undefined, + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + }); + + test('should support resource template completion with resolved context', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerResource( + 'test', + new ResourceTemplate('github://repos/{owner}/{repo}', { + list: undefined, + complete: { + repo: (value, context) => { + if (context?.arguments?.['owner'] === 'org1') { + return ['project1', 'project2', 'project3'].filter(r => r.startsWith(value)); + } else if (context?.arguments?.['owner'] === 'org2') { + return ['repo1', 'repo2', 'repo3'].filter(r => r.startsWith(value)); + } + return []; + } + } + }), + { + title: 'GitHub Repository', + description: 'Repository information' + }, + async () => ({ + contents: [ + { + uri: 'github://repos/test/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Test with microsoft owner + const result1 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 'p' + }, + context: { + arguments: { + owner: 'org1' + } + } + } + }, + CompleteResultSchema + ); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Verify the schema was updated - const listResult = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(listResult.prompts[0].arguments).toHaveLength(2); - expect(listResult.prompts[0].arguments?.map(a => a.name).sort()).toEqual(["name", "value"]); - - // Call the prompt with the new schema - const getResult = await client.request( - { - method: "prompts/get", - params: { - name: "test", - arguments: { - name: "test", - value: "value", - }, - }, - }, - GetPromptResultSchema, - ); - - expect(getResult.messages).toHaveLength(1); - expect(getResult.messages[0].content.text).toBe("Updated: test, value"); - - // Update happened before transport was connected, so no notifications should be expected - expect(notifications).toHaveLength(0); - }); - - /*** - * Test: Prompt List Changed Notification - */ - test("should send prompt list changed notification when connected", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial prompt - const prompt = mcpServer.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - expect(notifications).toHaveLength(0); - - // Now update the prompt while connected - prompt.update({ - callback: async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Updated response", + expect(result1.completion.values).toEqual(['project1', 'project2', 'project3']); + expect(result1.completion.total).toBe(3); + + // Test with facebook owner + const result2 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 'r' + }, + context: { + arguments: { + owner: 'org2' + } + } + } }, - }, - ], - }) - }); + CompleteResultSchema + ); - // Yield event loop to let the notification fly - await new Promise(process.nextTick); - - expect(notifications).toMatchObject([ - { method: "notifications/prompts/list_changed" } - ]); - }); - - /*** - * Test: Remove Prompt and Send Notification - */ - test("should remove prompt and send notification when connected", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const notifications: Notification[] = []; - const client = new Client({ - name: "test client", - version: "1.0", - }); - client.fallbackNotificationHandler = async (notification) => { - notifications.push(notification); - }; - - // Register initial prompts - const prompt1 = mcpServer.prompt("prompt1", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Prompt 1 response", - }, - }, - ], - })); - - mcpServer.prompt("prompt2", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Prompt 2 response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - // Verify both prompts are registered - let result = await client.request( - { method: "prompts/list" }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(2); - expect(result.prompts.map(p => p.name).sort()).toEqual(["prompt1", "prompt2"]); - - expect(notifications).toHaveLength(0); - - // Remove a prompt - prompt1.remove() - - // Yield event loop to let the notification fly - await new Promise(process.nextTick); - - // Should have sent notification - expect(notifications).toMatchObject([ - { method: "notifications/prompts/list_changed" } - ]); - - // Verify the prompt was removed - result = await client.request( - { method: "prompts/list" }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("prompt2"); - }); - - /*** - * Test: Prompt Registration with Arguments Schema - */ - test("should register prompt with args schema", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + expect(result2.completion.values).toEqual(['repo1', 'repo2', 'repo3']); + expect(result2.completion.total).toBe(3); - mcpServer.prompt( - "test", - { - name: z.string(), - value: z.string(), - }, - async ({ name, value }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `${name}: ${value}`, + // Test with no resolved context + const result3 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 't' + } + } }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("test"); - expect(result.prompts[0].arguments).toEqual([ - { name: "name", required: true }, - { name: "value", required: true }, - ]); - }); - - /*** - * Test: Prompt Registration with Description - */ - test("should register prompt with description", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + CompleteResultSchema + ); - mcpServer.prompt("test", "Test description", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("test"); - expect(result.prompts[0].description).toBe("Test description"); - }); - - /*** - * Test: Prompt Argument Validation - */ - test("should validate prompt args", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", + expect(result3.completion.values).toEqual([]); + expect(result3.completion.total).toBe(0); }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + test('should support prompt argument completion with resolved context', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); - mcpServer.prompt( - "test", - { - name: z.string(), - value: z.string().min(3), - }, - async ({ name, value }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `${name}: ${value}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "prompts/get", - params: { - name: "test", - arguments: { - name: "test", - value: "ab", // Too short - }, - }, - }, - GetPromptResultSchema, - ), - ).rejects.toThrow(/Invalid arguments/); - }); - - /*** - * Test: Preventing Duplicate Prompt Registration - */ - test("should prevent duplicate prompt registration", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); - mcpServer.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - expect(() => { - mcpServer.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response 2", + mcpServer.registerPrompt( + 'test-prompt', + { + title: 'Team Greeting', + description: 'Generate a greeting for team members', + argsSchema: { + department: completable(z.string(), value => { + return ['engineering', 'sales', 'marketing', 'support'].filter(d => d.startsWith(value)); + }), + name: completable(z.string(), (value, context) => { + const department = context?.arguments?.['department']; + if (department === 'engineering') { + return ['Alice', 'Bob', 'Charlie'].filter(n => n.startsWith(value)); + } else if (department === 'sales') { + return ['David', 'Eve', 'Frank'].filter(n => n.startsWith(value)); + } else if (department === 'marketing') { + return ['Grace', 'Henry', 'Iris'].filter(n => n.startsWith(value)); + } + return ['Guest'].filter(n => n.startsWith(value)); + }) + } }, - }, - ], - })); - }).toThrow(/already registered/); - }); - - /*** - * Test: Multiple Prompt Registration - */ - test("should allow registering multiple prompts", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + async ({ department, name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Test with engineering department + const result1 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'A' + }, + context: { + arguments: { + department: 'engineering' + } + } + } + }, + CompleteResultSchema + ); - // This should succeed - mcpServer.prompt("prompt1", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response 1", - }, - }, - ], - })); - - // This should also succeed and not throw about request handlers - mcpServer.prompt("prompt2", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response 2", - }, - }, - ], - })); - }); - - /*** - * Test: Prompt Registration with Arguments - */ - test("should allow registering prompts with arguments", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + expect(result1.completion.values).toEqual(['Alice']); - // This should succeed - mcpServer.prompt( - "echo", - { message: z.string() }, - ({ message }) => ({ - messages: [{ - role: "user", - content: { - type: "text", - text: `Please process this message: ${message}` - } - }] - }) - ); - }); - - /*** - * Test: Resources and Prompts with Completion Handlers - */ - test("should allow registering both resources and prompts with completion handlers", () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + // Test with sales department + const result2 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'D' + }, + context: { + arguments: { + department: 'sales' + } + } + } + }, + CompleteResultSchema + ); - // Register a resource with completion - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{category}", { - list: undefined, - complete: { - category: () => ["books", "movies", "music"], - }, - }), - async () => ({ - contents: [ - { - uri: "test://resource/test", - text: "Test content", - }, - ], - }), - ); - - // Register a prompt with completion - mcpServer.prompt( - "echo", - { message: completable(z.string(), () => ["hello", "world"]) }, - ({ message }) => ({ - messages: [{ - role: "user", - content: { - type: "text", - text: `Please process this message: ${message}` - } - }] - }) - ); - }); - - /*** - * Test: McpError for Invalid Prompt Name - */ - test("should throw McpError for invalid prompt name", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + expect(result2.completion.values).toEqual(['David']); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Test with marketing department + const result3 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'G' + }, + context: { + arguments: { + department: 'marketing' + } + } + } + }, + CompleteResultSchema + ); - mcpServer.prompt("test-prompt", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "prompts/get", - params: { - name: "nonexistent-prompt", - }, - }, - GetPromptResultSchema, - ), - ).rejects.toThrow(/Prompt nonexistent-prompt not found/); - }); - - - /*** - * Test: Registering a prompt with a completable argument should update server capabilities to advertise support for completion - */ - test("should advertise support for completion when a prompt with a completable argument is defined", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + expect(result3.completion.values).toEqual(['Grace']); - mcpServer.prompt( - "test-prompt", - { - name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), - }, - async ({ name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Hello ${name}`, + // Test with no resolved context + const result4 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'G' + } + } }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - expect(client.getServerCapabilities()).toMatchObject({ completions: {} }) - }) - - /*** - * Test: Prompt Argument Completion - */ - test("should support completion of prompt arguments", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + CompleteResultSchema + ); - const client = new Client({ - name: "test client", - version: "1.0", + expect(result4.completion.values).toEqual(['Guest']); }); +}); - mcpServer.prompt( - "test-prompt", - { - name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), - }, - async ({ name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Hello ${name}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["Alice", "Bob", "Charlie"]); - expect(result.completion.total).toBe(3); - }); - - /*** - * Test: Filtered Prompt Argument Completion - */ - test("should support filtered completion of prompt arguments", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); +describe('elicitInput()', () => { + const checkAvailability = jest.fn().mockResolvedValue(false); + const findAlternatives = jest.fn().mockResolvedValue([]); + const makeBooking = jest.fn().mockResolvedValue('BOOKING-123'); - const client = new Client({ - name: "test client", - version: "1.0", - }); + let mcpServer: McpServer; + let client: Client; - mcpServer.prompt( - "test-prompt", - { - name: completable(z.string(), (test) => - ["Alice", "Bob", "Charlie"].filter((value) => value.startsWith(test)), - ), - }, - async ({ name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Hello ${name}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "A", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["Alice"]); - expect(result.completion.total).toBe(1); - }); - - /*** - * Test: Pass Request ID to Prompt Callback - */ - test("should pass requestId to prompt callback via RequestHandlerExtra", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + beforeEach(() => { + jest.clearAllMocks(); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Create server with restaurant booking tool + mcpServer = new McpServer({ + name: 'restaurant-booking-server', + version: '1.0.0' + }); - let receivedRequestId: string | number | undefined; - mcpServer.prompt("request-id-test", async (extra) => { - receivedRequestId = extra.requestId; - return { - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Received request ID: ${extra.requestId}`, + // Register the restaurant booking tool from README example + mcpServer.tool( + 'book-restaurant', + { + restaurant: z.string(), + date: z.string(), + partySize: z.number() }, - }, - ], - }; - }); + async ({ restaurant, date, partySize }) => { + // Check availability + const available = await checkAvailability(restaurant, date, partySize); + + if (!available) { + // Ask user if they want to try alternative dates + const result = await mcpServer.server.elicitInput({ + message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, + requestedSchema: { + type: 'object', + properties: { + checkAlternatives: { + type: 'boolean', + title: 'Check alternative dates', + description: 'Would you like me to check other dates?' + }, + flexibleDates: { + type: 'string', + title: 'Date flexibility', + description: 'How flexible are your dates?', + enum: ['next_day', 'same_week', 'next_week'], + enumNames: ['Next day', 'Same week', 'Next week'] + } + }, + required: ['checkAlternatives'] + } + }); + + if (result.action === 'accept' && result.content?.checkAlternatives) { + const alternatives = await findAlternatives(restaurant, date, partySize, result.content.flexibleDates as string); + return { + content: [ + { + type: 'text', + text: `Found these alternatives: ${alternatives.join(', ')}` + } + ] + }; + } + + return { + content: [ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ] + }; + } - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/get", - params: { - name: "request-id-test", - }, - }, - GetPromptResultSchema, - ); - - expect(receivedRequestId).toBeDefined(); - expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); - expect(result.messages[0].content.text).toContain("Received request ID:"); - }); - - /*** - * Test: Resource Template Metadata Priority - */ - test("should prioritize individual resource metadata over template metadata", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + await makeBooking(restaurant, date, partySize); + return { + content: [ + { + type: 'text', + text: `Booked table for ${partySize} at ${restaurant} on ${date}` + } + ] + }; + } + ); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { - list: async () => ({ - resources: [ + // Create client with elicitation capability + client = new Client( { - name: "Resource 1", - uri: "test://resource/1", - description: "Individual resource description", - mimeType: "text/plain", + name: 'test-client', + version: '1.0.0' }, { - name: "Resource 2", - uri: "test://resource/2", - // This resource has no description or mimeType - }, - ], - }), - }), - { - description: "Template description", - mimeType: "application/json", - }, - async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(2); - - // Resource 1 should have its own metadata - expect(result.resources[0].name).toBe("Resource 1"); - expect(result.resources[0].description).toBe("Individual resource description"); - expect(result.resources[0].mimeType).toBe("text/plain"); - - // Resource 2 should inherit template metadata - expect(result.resources[1].name).toBe("Resource 2"); - expect(result.resources[1].description).toBe("Template description"); - expect(result.resources[1].mimeType).toBe("application/json"); - }); - - /*** - * Test: Resource Template Metadata Overrides All Fields - */ - test("should allow resource to override all template metadata fields", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", + capabilities: { + elicitation: {} + } + } + ); }); - mcpServer.resource( - "test", - new ResourceTemplate("test://resource/{id}", { - list: async () => ({ - resources: [ - { - name: "Overridden Name", - uri: "test://resource/1", - description: "Overridden description", - mimeType: "text/markdown", - // Add any other metadata fields if they exist - }, - ], - }), - }), - { - name: "Template Name", - description: "Template description", - mimeType: "application/json", - }, - async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(1); - - // All fields should be from the individual resource, not the template - expect(result.resources[0].name).toBe("Overridden Name"); - expect(result.resources[0].description).toBe("Overridden description"); - expect(result.resources[0].mimeType).toBe("text/markdown"); - }); -}); + test('should successfully elicit additional information', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + findAlternatives.mockResolvedValue(['2024-12-26', '2024-12-27', '2024-12-28']); -describe("Tool title precedence", () => { - test("should follow correct title precedence: title → annotations.title → name", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); + // Set up client to accept alternative date checking + client.setRequestHandler(ElicitRequestSchema, async request => { + expect(request.params.message).toContain('No tables available at ABC Restaurant on 2024-12-25'); + return { + action: 'accept', + content: { + checkAlternatives: true, + flexibleDates: 'same_week' + } + }; + }); - // Tool 1: Only name - mcpServer.tool( - "tool_name_only", - async () => ({ - content: [{ type: "text", text: "Response" }], - }) - ); - - // Tool 2: Name and annotations.title - mcpServer.tool( - "tool_with_annotations_title", - "Tool with annotations title", - { - title: "Annotations Title" - }, - async () => ({ - content: [{ type: "text", text: "Response" }], - }) - ); - - // Tool 3: Name and title (using registerTool) - mcpServer.registerTool( - "tool_with_title", - { - title: "Regular Title", - description: "Tool with regular title" - }, - async () => ({ - content: [{ type: "text", text: "Response" }], - }) - ); - - // Tool 4: All three - title should win - mcpServer.registerTool( - "tool_with_all_titles", - { - title: "Regular Title Wins", - description: "Tool with all titles", - annotations: { - title: "Annotations Title Should Not Show" - } - }, - async () => ({ - content: [{ type: "text", text: "Response" }], - }) - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - mcpServer.connect(serverTransport), - ]); - - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - - - expect(result.tools).toHaveLength(4); - - // Tool 1: Only name - should display name - const tool1 = result.tools.find(t => t.name === "tool_name_only"); - expect(tool1).toBeDefined(); - expect(getDisplayName(tool1!)).toBe("tool_name_only"); - - // Tool 2: Name and annotations.title - should display annotations.title - const tool2 = result.tools.find(t => t.name === "tool_with_annotations_title"); - expect(tool2).toBeDefined(); - expect(tool2!.annotations?.title).toBe("Annotations Title"); - expect(getDisplayName(tool2!)).toBe("Annotations Title"); - - // Tool 3: Name and title - should display title - const tool3 = result.tools.find(t => t.name === "tool_with_title"); - expect(tool3).toBeDefined(); - expect(tool3!.title).toBe("Regular Title"); - expect(getDisplayName(tool3!)).toBe("Regular Title"); - - // Tool 4: All three - title should take precedence - const tool4 = result.tools.find(t => t.name === "tool_with_all_titles"); - expect(tool4).toBeDefined(); - expect(tool4!.title).toBe("Regular Title Wins"); - expect(tool4!.annotations?.title).toBe("Annotations Title Should Not Show"); - expect(getDisplayName(tool4!)).toBe("Regular Title Wins"); - }); - - test("getDisplayName unit tests for title precedence", () => { - - // Test 1: Only name - expect(getDisplayName({ name: "tool_name" })).toBe("tool_name"); - - // Test 2: Name and title - title wins - expect(getDisplayName({ - name: "tool_name", - title: "Tool Title" - })).toBe("Tool Title"); - - // Test 3: Name and annotations.title - annotations.title wins - expect(getDisplayName({ - name: "tool_name", - annotations: { title: "Annotations Title" } - })).toBe("Annotations Title"); - - // Test 4: All three - title wins (correct precedence) - expect(getDisplayName({ - name: "tool_name", - title: "Regular Title", - annotations: { title: "Annotations Title" } - })).toBe("Regular Title"); - - // Test 5: Empty title should not be used - expect(getDisplayName({ - name: "tool_name", - title: "", - annotations: { title: "Annotations Title" } - })).toBe("Annotations Title"); - - // Test 6: Undefined vs null handling - expect(getDisplayName({ - name: "tool_name", - title: undefined, - annotations: { title: "Annotations Title" } - })).toBe("Annotations Title"); - }); - - test("should support resource template completion with resolved context", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const client = new Client({ - name: "test client", - version: "1.0", - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - mcpServer.registerResource( - "test", - new ResourceTemplate("github://repos/{owner}/{repo}", { - list: undefined, - complete: { - repo: (value, context) => { - if (context?.arguments?.["owner"] === "org1") { - return ["project1", "project2", "project3"].filter(r => r.startsWith(value)); - } else if (context?.arguments?.["owner"] === "org2") { - return ["repo1", "repo2", "repo3"].filter(r => r.startsWith(value)); - } - return []; - }, - }, - }), - { - title: "GitHub Repository", - description: "Repository information" - }, - async () => ({ - contents: [ - { - uri: "github://repos/test/test", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Test with microsoft owner - const result1 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "github://repos/{owner}/{repo}", - }, - argument: { - name: "repo", - value: "p", - }, - context: { - arguments: { - owner: "org1", - }, - }, - }, - }, - CompleteResultSchema, - ); - - expect(result1.completion.values).toEqual(["project1", "project2", "project3"]); - expect(result1.completion.total).toBe(3); - - // Test with facebook owner - const result2 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "github://repos/{owner}/{repo}", - }, - argument: { - name: "repo", - value: "r", - }, - context: { + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', arguments: { - owner: "org2", - }, - }, - }, - }, - CompleteResultSchema, - ); - - expect(result2.completion.values).toEqual(["repo1", "repo2", "repo3"]); - expect(result2.completion.total).toBe(3); - - // Test with no resolved context - const result3 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "github://repos/{owner}/{repo}", - }, - argument: { - name: "repo", - value: "t", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result3.completion.values).toEqual([]); - expect(result3.completion.total).toBe(0); - }); - - test("should support prompt argument completion with resolved context", async () => { - const mcpServer = new McpServer({ - name: "test server", - version: "1.0", - }); + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); - const client = new Client({ - name: "test client", - version: "1.0", + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2, 'same_week'); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Found these alternatives: 2024-12-26, 2024-12-27, 2024-12-28' + } + ]); }); - mcpServer.registerPrompt( - "test-prompt", - { - title: "Team Greeting", - description: "Generate a greeting for team members", - argsSchema: { - department: completable(z.string(), (value) => { - return ["engineering", "sales", "marketing", "support"].filter(d => d.startsWith(value)); - }), - name: completable(z.string(), (value, context) => { - const department = context?.arguments?.["department"]; - if (department === "engineering") { - return ["Alice", "Bob", "Charlie"].filter(n => n.startsWith(value)); - } else if (department === "sales") { - return ["David", "Eve", "Frank"].filter(n => n.startsWith(value)); - } else if (department === "marketing") { - return ["Grace", "Henry", "Iris"].filter(n => n.startsWith(value)); - } - return ["Guest"].filter(n => n.startsWith(value)); - }), - } - }, - async ({ department, name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Hello ${name}, welcome to the ${department} team!`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Test with engineering department - const result1 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "A", - }, - context: { - arguments: { - department: "engineering", - }, - }, - }, - }, - CompleteResultSchema, - ); - - expect(result1.completion.values).toEqual(["Alice"]); - - // Test with sales department - const result2 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "D", - }, - context: { - arguments: { - department: "sales", - }, - }, - }, - }, - CompleteResultSchema, - ); - - expect(result2.completion.values).toEqual(["David"]); - - // Test with marketing department - const result3 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "G", - }, - context: { - arguments: { - department: "marketing", - }, - }, - }, - }, - CompleteResultSchema, - ); - - expect(result3.completion.values).toEqual(["Grace"]); - - // Test with no resolved context - const result4 = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "G", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result4.completion.values).toEqual(["Guest"]); - }); -}); + test('should handle user declining to elicitation request', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); -describe("elicitInput()", () => { + // Set up client to reject alternative date checking + client.setRequestHandler(ElicitRequestSchema, async () => { + return { + action: 'accept', + content: { + checkAlternatives: false + } + }; + }); - const checkAvailability = jest.fn().mockResolvedValue(false); - const findAlternatives = jest.fn().mockResolvedValue([]); - const makeBooking = jest.fn().mockResolvedValue("BOOKING-123"); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - let mcpServer: McpServer; - let client: Client; + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - beforeEach(() => { - jest.clearAllMocks(); + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); - // Create server with restaurant booking tool - mcpServer = new McpServer({ - name: "restaurant-booking-server", - version: "1.0.0", + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).not.toHaveBeenCalled(); + expect(result.content).toEqual([ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ]); }); - // Register the restaurant booking tool from README example - mcpServer.tool( - "book-restaurant", - { - restaurant: z.string(), - date: z.string(), - partySize: z.number() - }, - async ({ restaurant, date, partySize }) => { - // Check availability - const available = await checkAvailability(restaurant, date, partySize); - - if (!available) { - // Ask user if they want to try alternative dates - const result = await mcpServer.server.elicitInput({ - message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, - requestedSchema: { - type: "object", - properties: { - checkAlternatives: { - type: "boolean", - title: "Check alternative dates", - description: "Would you like me to check other dates?" - }, - flexibleDates: { - type: "string", - title: "Date flexibility", - description: "How flexible are your dates?", - enum: ["next_day", "same_week", "next_week"], - enumNames: ["Next day", "Same week", "Next week"] - } - }, - required: ["checkAlternatives"] - } - }); - - if (result.action === "accept" && result.content?.checkAlternatives) { - const alternatives = await findAlternatives( - restaurant, - date, - partySize, - result.content.flexibleDates as string - ); + test('should handle user cancelling the elicitation', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + + // Set up client to cancel the elicitation + client.setRequestHandler(ElicitRequestSchema, async () => { return { - content: [{ - type: "text", - text: `Found these alternatives: ${alternatives.join(", ")}` - }] + action: 'cancel' }; - } - - return { - content: [{ - type: "text", - text: "No booking made. Original date not available." - }] - }; - } - - await makeBooking(restaurant, date, partySize); - return { - content: [{ - type: "text", - text: `Booked table for ${partySize} at ${restaurant} on ${date}` - }] - }; - } - ); - - // Create client with elicitation capability - client = new Client( - { - name: "test-client", - version: "1.0.0", - }, - { - capabilities: { - elicitation: {}, - }, - } - ); - }); - - test("should successfully elicit additional information", async () => { - // Mock availability check to return false - checkAvailability.mockResolvedValue(false); - findAlternatives.mockResolvedValue(["2024-12-26", "2024-12-27", "2024-12-28"]); - - // Set up client to accept alternative date checking - client.setRequestHandler(ElicitRequestSchema, async (request) => { - expect(request.params.message).toContain("No tables available at ABC Restaurant on 2024-12-25"); - return { - action: "accept", - content: { - checkAlternatives: true, - flexibleDates: "same_week" - } - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Call the tool - const result = await client.callTool({ - name: "book-restaurant", - arguments: { - restaurant: "ABC Restaurant", - date: "2024-12-25", - partySize: 2 - } - }); + }); - expect(checkAvailability).toHaveBeenCalledWith("ABC Restaurant", "2024-12-25", 2); - expect(findAlternatives).toHaveBeenCalledWith("ABC Restaurant", "2024-12-25", 2, "same_week"); - expect(result.content).toEqual([{ - type: "text", - text: "Found these alternatives: 2024-12-26, 2024-12-27, 2024-12-28" - }]); - }); - - test("should handle user declining to elicitation request", async () => { - // Mock availability check to return false - checkAvailability.mockResolvedValue(false); - - // Set up client to reject alternative date checking - client.setRequestHandler(ElicitRequestSchema, async () => { - return { - action: "accept", - content: { - checkAlternatives: false - } - }; - }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Call the tool - const result = await client.callTool({ - name: "book-restaurant", - arguments: { - restaurant: "ABC Restaurant", - date: "2024-12-25", - partySize: 2 - } - }); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - expect(checkAvailability).toHaveBeenCalledWith("ABC Restaurant", "2024-12-25", 2); - expect(findAlternatives).not.toHaveBeenCalled(); - expect(result.content).toEqual([{ - type: "text", - text: "No booking made. Original date not available." - }]); - }); - - test("should handle user cancelling the elicitation", async () => { - // Mock availability check to return false - checkAvailability.mockResolvedValue(false); - - // Set up client to cancel the elicitation - client.setRequestHandler(ElicitRequestSchema, async () => { - return { - action: "cancel" - }; - }); + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Call the tool - const result = await client.callTool({ - name: "book-restaurant", - arguments: { - restaurant: "ABC Restaurant", - date: "2024-12-25", - partySize: 2 - } + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).not.toHaveBeenCalled(); + expect(result.content).toEqual([ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ]); }); - - expect(checkAvailability).toHaveBeenCalledWith("ABC Restaurant", "2024-12-25", 2); - expect(findAlternatives).not.toHaveBeenCalled(); - expect(result.content).toEqual([{ - type: "text", - text: "No booking made. Original date not available." - }]); - }); }); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index ac4880c99..cef1722d6 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -1,52 +1,42 @@ -import { Server, ServerOptions } from "./index.js"; -import { zodToJsonSchema } from "zod-to-json-schema"; +import { Server, ServerOptions } from './index.js'; +import { zodToJsonSchema } from 'zod-to-json-schema'; +import { z, ZodRawShape, ZodObject, ZodString, AnyZodObject, ZodTypeAny, ZodType, ZodTypeDef, ZodOptional } from 'zod'; import { - z, - ZodRawShape, - ZodObject, - ZodString, - AnyZodObject, - ZodTypeAny, - ZodType, - ZodTypeDef, - ZodOptional, -} from "zod"; -import { - Implementation, - Tool, - ListToolsResult, - CallToolResult, - McpError, - ErrorCode, - CompleteRequest, - CompleteResult, - PromptReference, - ResourceTemplateReference, - BaseMetadata, - Resource, - ListResourcesResult, - ListResourceTemplatesRequestSchema, - ReadResourceRequestSchema, - ListToolsRequestSchema, - CallToolRequestSchema, - ListResourcesRequestSchema, - ListPromptsRequestSchema, - GetPromptRequestSchema, - CompleteRequestSchema, - ListPromptsResult, - Prompt, - PromptArgument, - GetPromptResult, - ReadResourceResult, - ServerRequest, - ServerNotification, - ToolAnnotations, - LoggingMessageNotification, -} from "../types.js"; -import { Completable, CompletableDef } from "./completable.js"; -import { UriTemplate, Variables } from "../shared/uriTemplate.js"; -import { RequestHandlerExtra } from "../shared/protocol.js"; -import { Transport } from "../shared/transport.js"; + Implementation, + Tool, + ListToolsResult, + CallToolResult, + McpError, + ErrorCode, + CompleteRequest, + CompleteResult, + PromptReference, + ResourceTemplateReference, + BaseMetadata, + Resource, + ListResourcesResult, + ListResourceTemplatesRequestSchema, + ReadResourceRequestSchema, + ListToolsRequestSchema, + CallToolRequestSchema, + ListResourcesRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema, + CompleteRequestSchema, + ListPromptsResult, + Prompt, + PromptArgument, + GetPromptResult, + ReadResourceResult, + ServerRequest, + ServerNotification, + ToolAnnotations, + LoggingMessageNotification +} from '../types.js'; +import { Completable, CompletableDef } from './completable.js'; +import { UriTemplate, Variables } from '../shared/uriTemplate.js'; +import { RequestHandlerExtra } from '../shared/protocol.js'; +import { Transport } from '../shared/transport.js'; /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. @@ -54,1052 +44,913 @@ import { Transport } from "../shared/transport.js"; * Server instance available via the `server` property. */ export class McpServer { - /** - * The underlying Server instance, useful for advanced operations like sending notifications. - */ - public readonly server: Server; - - private _registeredResources: { [uri: string]: RegisteredResource } = {}; - private _registeredResourceTemplates: { - [name: string]: RegisteredResourceTemplate; - } = {}; - private _registeredTools: { [name: string]: RegisteredTool } = {}; - private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; - - constructor(serverInfo: Implementation, options?: ServerOptions) { - this.server = new Server(serverInfo, options); - } - - /** - * Attaches to the given transport, starts it, and starts listening for messages. - * - * The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. - */ - async connect(transport: Transport): Promise { - return await this.server.connect(transport); - } - - /** - * Closes the connection. - */ - async close(): Promise { - await this.server.close(); - } - - private _toolHandlersInitialized = false; - - private setToolRequestHandlers() { - if (this._toolHandlersInitialized) { - return; + /** + * The underlying Server instance, useful for advanced operations like sending notifications. + */ + public readonly server: Server; + + private _registeredResources: { [uri: string]: RegisteredResource } = {}; + private _registeredResourceTemplates: { + [name: string]: RegisteredResourceTemplate; + } = {}; + private _registeredTools: { [name: string]: RegisteredTool } = {}; + private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + + constructor(serverInfo: Implementation, options?: ServerOptions) { + this.server = new Server(serverInfo, options); } - this.server.assertCanSetRequestHandler( - ListToolsRequestSchema.shape.method.value, - ); - this.server.assertCanSetRequestHandler( - CallToolRequestSchema.shape.method.value, - ); + /** + * Attaches to the given transport, starts it, and starts listening for messages. + * + * The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + */ + async connect(transport: Transport): Promise { + return await this.server.connect(transport); + } - this.server.registerCapabilities({ - tools: { - listChanged: true - } - }) - - this.server.setRequestHandler( - ListToolsRequestSchema, - (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools).filter( - ([, tool]) => tool.enabled, - ).map( - ([name, tool]): Tool => { - const toolDefinition: Tool = { - name, - title: tool.title, - description: tool.description, - inputSchema: tool.inputSchema - ? (zodToJsonSchema(tool.inputSchema, { - strictUnions: true, - }) as Tool["inputSchema"]) - : EMPTY_OBJECT_JSON_SCHEMA, - annotations: tool.annotations, - _meta: tool._meta, - }; + /** + * Closes the connection. + */ + async close(): Promise { + await this.server.close(); + } + + private _toolHandlersInitialized = false; + + private setToolRequestHandlers() { + if (this._toolHandlersInitialized) { + return; + } + + this.server.assertCanSetRequestHandler(ListToolsRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(CallToolRequestSchema.shape.method.value); - if (tool.outputSchema) { - toolDefinition.outputSchema = zodToJsonSchema( - tool.outputSchema, - { strictUnions: true } - ) as Tool["outputSchema"]; + this.server.registerCapabilities({ + tools: { + listChanged: true } + }); + + this.server.setRequestHandler( + ListToolsRequestSchema, + (): ListToolsResult => ({ + tools: Object.entries(this._registeredTools) + .filter(([, tool]) => tool.enabled) + .map(([name, tool]): Tool => { + const toolDefinition: Tool = { + name, + title: tool.title, + description: tool.description, + inputSchema: tool.inputSchema + ? (zodToJsonSchema(tool.inputSchema, { + strictUnions: true + }) as Tool['inputSchema']) + : EMPTY_OBJECT_JSON_SCHEMA, + annotations: tool.annotations, + _meta: tool._meta + }; + + if (tool.outputSchema) { + toolDefinition.outputSchema = zodToJsonSchema(tool.outputSchema, { + strictUnions: true + }) as Tool['outputSchema']; + } + + return toolDefinition; + }) + }) + ); - return toolDefinition; - }, - ), - }), - ); + this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); + } - this.server.setRequestHandler( - CallToolRequestSchema, - async (request, extra): Promise => { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new McpError( - ErrorCode.InvalidParams, - `Tool ${request.params.name} not found`, - ); - } + if (!tool.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); + } - if (!tool.enabled) { - throw new McpError( - ErrorCode.InvalidParams, - `Tool ${request.params.name} disabled`, - ); - } + let result: CallToolResult; + + if (tool.inputSchema) { + const parseResult = await tool.inputSchema.safeParseAsync(request.params.arguments); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}` + ); + } + + const args = parseResult.data; + const cb = tool.callback as ToolCallback; + try { + result = await Promise.resolve(cb(args, extra)); + } catch (error) { + result = { + content: [ + { + type: 'text', + text: error instanceof Error ? error.message : String(error) + } + ], + isError: true + }; + } + } else { + const cb = tool.callback as ToolCallback; + try { + result = await Promise.resolve(cb(extra)); + } catch (error) { + result = { + content: [ + { + type: 'text', + text: error instanceof Error ? error.message : String(error) + } + ], + isError: true + }; + } + } - let result: CallToolResult; + if (tool.outputSchema && !result.isError) { + if (!result.structuredContent) { + throw new McpError( + ErrorCode.InvalidParams, + `Tool ${request.params.name} has an output schema but no structured content was provided` + ); + } + + // if the tool has an output schema, validate structured content + const parseResult = await tool.outputSchema.safeParseAsync(result.structuredContent); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid structured content for tool ${request.params.name}: ${parseResult.error.message}` + ); + } + } - if (tool.inputSchema) { - const parseResult = await tool.inputSchema.safeParseAsync( - request.params.arguments, - ); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, - ); - } - - const args = parseResult.data; - const cb = tool.callback as ToolCallback; - try { - result = await Promise.resolve(cb(args, extra)); - } catch (error) { - result = { - content: [ - { - type: "text", - text: error instanceof Error ? error.message : String(error), - }, - ], - isError: true, - }; - } - } else { - const cb = tool.callback as ToolCallback; - try { - result = await Promise.resolve(cb(extra)); - } catch (error) { - result = { - content: [ - { - type: "text", - text: error instanceof Error ? error.message : String(error), - }, - ], - isError: true, - }; - } - } + return result; + }); - if (tool.outputSchema && !result.isError) { - if (!result.structuredContent) { - throw new McpError( - ErrorCode.InvalidParams, - `Tool ${request.params.name} has an output schema but no structured content was provided`, - ); - } - - // if the tool has an output schema, validate structured content - const parseResult = await tool.outputSchema.safeParseAsync( - result.structuredContent, - ); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Invalid structured content for tool ${request.params.name}: ${parseResult.error.message}`, - ); - } + this._toolHandlersInitialized = true; + } + + private _completionHandlerInitialized = false; + + private setCompletionRequestHandler() { + if (this._completionHandlerInitialized) { + return; } - return result; - }, - ); + this.server.assertCanSetRequestHandler(CompleteRequestSchema.shape.method.value); - this._toolHandlersInitialized = true; - } + this.server.registerCapabilities({ + completions: {} + }); - private _completionHandlerInitialized = false; + this.server.setRequestHandler(CompleteRequestSchema, async (request): Promise => { + switch (request.params.ref.type) { + case 'ref/prompt': + return this.handlePromptCompletion(request, request.params.ref); - private setCompletionRequestHandler() { - if (this._completionHandlerInitialized) { - return; - } + case 'ref/resource': + return this.handleResourceCompletion(request, request.params.ref); - this.server.assertCanSetRequestHandler( - CompleteRequestSchema.shape.method.value, - ); + default: + throw new McpError(ErrorCode.InvalidParams, `Invalid completion reference: ${request.params.ref}`); + } + }); - this.server.registerCapabilities({ - completions: {}, - }); + this._completionHandlerInitialized = true; + } - this.server.setRequestHandler( - CompleteRequestSchema, - async (request): Promise => { - switch (request.params.ref.type) { - case "ref/prompt": - return this.handlePromptCompletion(request, request.params.ref); + private async handlePromptCompletion(request: CompleteRequest, ref: PromptReference): Promise { + const prompt = this._registeredPrompts[ref.name]; + if (!prompt) { + throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} not found`); + } - case "ref/resource": - return this.handleResourceCompletion(request, request.params.ref); + if (!prompt.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Prompt ${ref.name} disabled`); + } - default: - throw new McpError( - ErrorCode.InvalidParams, - `Invalid completion reference: ${request.params.ref}`, - ); + if (!prompt.argsSchema) { + return EMPTY_COMPLETION_RESULT; } - }, - ); - this._completionHandlerInitialized = true; - } - - private async handlePromptCompletion( - request: CompleteRequest, - ref: PromptReference, - ): Promise { - const prompt = this._registeredPrompts[ref.name]; - if (!prompt) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${ref.name} not found`, - ); - } + const field = prompt.argsSchema.shape[request.params.argument.name]; + if (!(field instanceof Completable)) { + return EMPTY_COMPLETION_RESULT; + } - if (!prompt.enabled) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${ref.name} disabled`, - ); + const def: CompletableDef = field._def; + const suggestions = await def.complete(request.params.argument.value, request.params.context); + return createCompletionResult(suggestions); } - if (!prompt.argsSchema) { - return EMPTY_COMPLETION_RESULT; - } + private async handleResourceCompletion(request: CompleteRequest, ref: ResourceTemplateReference): Promise { + const template = Object.values(this._registeredResourceTemplates).find(t => t.resourceTemplate.uriTemplate.toString() === ref.uri); - const field = prompt.argsSchema.shape[request.params.argument.name]; - if (!(field instanceof Completable)) { - return EMPTY_COMPLETION_RESULT; - } + if (!template) { + if (this._registeredResources[ref.uri]) { + // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } - const def: CompletableDef = field._def; - const suggestions = await def.complete(request.params.argument.value, request.params.context); - return createCompletionResult(suggestions); - } - - private async handleResourceCompletion( - request: CompleteRequest, - ref: ResourceTemplateReference, - ): Promise { - const template = Object.values(this._registeredResourceTemplates).find( - (t) => t.resourceTemplate.uriTemplate.toString() === ref.uri, - ); + throw new McpError(ErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); + } - if (!template) { - if (this._registeredResources[ref.uri]) { - // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). - return EMPTY_COMPLETION_RESULT; - } + const completer = template.resourceTemplate.completeCallback(request.params.argument.name); + if (!completer) { + return EMPTY_COMPLETION_RESULT; + } - throw new McpError( - ErrorCode.InvalidParams, - `Resource template ${request.params.ref.uri} not found`, - ); + const suggestions = await completer(request.params.argument.value, request.params.context); + return createCompletionResult(suggestions); } - const completer = template.resourceTemplate.completeCallback( - request.params.argument.name, - ); - if (!completer) { - return EMPTY_COMPLETION_RESULT; - } + private _resourceHandlersInitialized = false; - const suggestions = await completer(request.params.argument.value, request.params.context); - return createCompletionResult(suggestions); - } + private setResourceRequestHandlers() { + if (this._resourceHandlersInitialized) { + return; + } - private _resourceHandlersInitialized = false; + this.server.assertCanSetRequestHandler(ListResourcesRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(ListResourceTemplatesRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(ReadResourceRequestSchema.shape.method.value); - private setResourceRequestHandlers() { - if (this._resourceHandlersInitialized) { - return; - } + this.server.registerCapabilities({ + resources: { + listChanged: true + } + }); + + this.server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { + const resources = Object.entries(this._registeredResources) + .filter(([_, resource]) => resource.enabled) + .map(([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata + })); + + const templateResources: Resource[] = []; + for (const template of Object.values(this._registeredResourceTemplates)) { + if (!template.resourceTemplate.listCallback) { + continue; + } + + const result = await template.resourceTemplate.listCallback(extra); + for (const resource of result.resources) { + templateResources.push({ + ...template.metadata, + // the defined resource metadata should override the template metadata if present + ...resource + }); + } + } - this.server.assertCanSetRequestHandler( - ListResourcesRequestSchema.shape.method.value, - ); - this.server.assertCanSetRequestHandler( - ListResourceTemplatesRequestSchema.shape.method.value, - ); - this.server.assertCanSetRequestHandler( - ReadResourceRequestSchema.shape.method.value, - ); + return { resources: [...resources, ...templateResources] }; + }); + + this.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { + const resourceTemplates = Object.entries(this._registeredResourceTemplates).map(([name, template]) => ({ + name, + uriTemplate: template.resourceTemplate.uriTemplate.toString(), + ...template.metadata + })); + + return { resourceTemplates }; + }); + + this.server.setRequestHandler(ReadResourceRequestSchema, async (request, extra) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + if (!resource.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} disabled`); + } + return resource.readCallback(uri, extra); + } - this.server.registerCapabilities({ - resources: { - listChanged: true - } - }) - - this.server.setRequestHandler( - ListResourcesRequestSchema, - async (request, extra) => { - const resources = Object.entries(this._registeredResources).filter( - ([_, resource]) => resource.enabled, - ).map( - ([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata, - }), - ); + // Then check templates + for (const template of Object.values(this._registeredResourceTemplates)) { + const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); + if (variables) { + return template.readCallback(uri, variables, extra); + } + } - const templateResources: Resource[] = []; - for (const template of Object.values( - this._registeredResourceTemplates, - )) { - if (!template.resourceTemplate.listCallback) { - continue; - } - - const result = await template.resourceTemplate.listCallback(extra); - for (const resource of result.resources) { - templateResources.push({ - ...template.metadata, - // the defined resource metadata should override the template metadata if present - ...resource, - }); - } - } + throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); + }); - return { resources: [...resources, ...templateResources] }; - }, - ); + this.setCompletionRequestHandler(); - this.server.setRequestHandler( - ListResourceTemplatesRequestSchema, - async () => { - const resourceTemplates = Object.entries( - this._registeredResourceTemplates, - ).map(([name, template]) => ({ - name, - uriTemplate: template.resourceTemplate.uriTemplate.toString(), - ...template.metadata, - })); - - return { resourceTemplates }; - }, - ); + this._resourceHandlersInitialized = true; + } - this.server.setRequestHandler( - ReadResourceRequestSchema, - async (request, extra) => { - const uri = new URL(request.params.uri); - - // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; - if (resource) { - if (!resource.enabled) { - throw new McpError( - ErrorCode.InvalidParams, - `Resource ${uri} disabled`, - ); - } - return resource.readCallback(uri, extra); - } + private _promptHandlersInitialized = false; - // Then check templates - for (const template of Object.values( - this._registeredResourceTemplates, - )) { - const variables = template.resourceTemplate.uriTemplate.match( - uri.toString(), - ); - if (variables) { - return template.readCallback(uri, variables, extra); - } + private setPromptRequestHandlers() { + if (this._promptHandlersInitialized) { + return; } - throw new McpError( - ErrorCode.InvalidParams, - `Resource ${uri} not found`, + this.server.assertCanSetRequestHandler(ListPromptsRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(GetPromptRequestSchema.shape.method.value); + + this.server.registerCapabilities({ + prompts: { + listChanged: true + } + }); + + this.server.setRequestHandler( + ListPromptsRequestSchema, + (): ListPromptsResult => ({ + prompts: Object.entries(this._registeredPrompts) + .filter(([, prompt]) => prompt.enabled) + .map(([name, prompt]): Prompt => { + return { + name, + title: prompt.title, + description: prompt.description, + arguments: prompt.argsSchema ? promptArgumentsFromSchema(prompt.argsSchema) : undefined + }; + }) + }) ); - }, - ); - this.setCompletionRequestHandler(); + this.server.setRequestHandler(GetPromptRequestSchema, async (request, extra): Promise => { + const prompt = this._registeredPrompts[request.params.name]; + if (!prompt) { + throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); + } + + if (!prompt.enabled) { + throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); + } - this._resourceHandlersInitialized = true; - } + if (prompt.argsSchema) { + const parseResult = await prompt.argsSchema.safeParseAsync(request.params.arguments); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}` + ); + } + + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(args, extra)); + } else { + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(extra)); + } + }); - private _promptHandlersInitialized = false; + this.setCompletionRequestHandler(); - private setPromptRequestHandlers() { - if (this._promptHandlersInitialized) { - return; + this._promptHandlersInitialized = true; } - this.server.assertCanSetRequestHandler( - ListPromptsRequestSchema.shape.method.value, - ); - this.server.assertCanSetRequestHandler( - GetPromptRequestSchema.shape.method.value, - ); + /** + * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. + */ + resource(name: string, uri: string, readCallback: ReadResourceCallback): RegisteredResource; + + /** + * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. + */ + resource(name: string, uri: string, metadata: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + + /** + * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. + */ + resource(name: string, template: ResourceTemplate, readCallback: ReadResourceTemplateCallback): RegisteredResourceTemplate; + + /** + * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. + */ + resource( + name: string, + template: ResourceTemplate, + metadata: ResourceMetadata, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate; + + resource(name: string, uriOrTemplate: string | ResourceTemplate, ...rest: unknown[]): RegisteredResource | RegisteredResourceTemplate { + let metadata: ResourceMetadata | undefined; + if (typeof rest[0] === 'object') { + metadata = rest.shift() as ResourceMetadata; + } - this.server.registerCapabilities({ - prompts: { - listChanged: true - } - }) - - this.server.setRequestHandler( - ListPromptsRequestSchema, - (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts).filter( - ([, prompt]) => prompt.enabled, - ).map( - ([name, prompt]): Prompt => { - return { - name, - title: prompt.title, - description: prompt.description, - arguments: prompt.argsSchema - ? promptArgumentsFromSchema(prompt.argsSchema) - : undefined, - }; - }, - ), - }), - ); + const readCallback = rest[0] as ReadResourceCallback | ReadResourceTemplateCallback; - this.server.setRequestHandler( - GetPromptRequestSchema, - async (request, extra): Promise => { - const prompt = this._registeredPrompts[request.params.name]; - if (!prompt) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${request.params.name} not found`, - ); - } + if (typeof uriOrTemplate === 'string') { + if (this._registeredResources[uriOrTemplate]) { + throw new Error(`Resource ${uriOrTemplate} is already registered`); + } - if (!prompt.enabled) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${request.params.name} disabled`, - ); + const registeredResource = this._createRegisteredResource( + name, + undefined, + uriOrTemplate, + metadata, + readCallback as ReadResourceCallback + ); + + this.setResourceRequestHandlers(); + this.sendResourceListChanged(); + return registeredResource; + } else { + if (this._registeredResourceTemplates[name]) { + throw new Error(`Resource template ${name} is already registered`); + } + + const registeredResourceTemplate = this._createRegisteredResourceTemplate( + name, + undefined, + uriOrTemplate, + metadata, + readCallback as ReadResourceTemplateCallback + ); + + this.setResourceRequestHandlers(); + this.sendResourceListChanged(); + return registeredResourceTemplate; } + } + + /** + * Registers a resource with a config object and callback. + * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. + */ + registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate; + registerResource( + name: string, + uriOrTemplate: string | ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceCallback | ReadResourceTemplateCallback + ): RegisteredResource | RegisteredResourceTemplate { + if (typeof uriOrTemplate === 'string') { + if (this._registeredResources[uriOrTemplate]) { + throw new Error(`Resource ${uriOrTemplate} is already registered`); + } - if (prompt.argsSchema) { - const parseResult = await prompt.argsSchema.safeParseAsync( - request.params.arguments, - ); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}`, + const registeredResource = this._createRegisteredResource( + name, + (config as BaseMetadata).title, + uriOrTemplate, + config, + readCallback as ReadResourceCallback ); - } - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); + this.setResourceRequestHandlers(); + this.sendResourceListChanged(); + return registeredResource; } else { - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(extra)); + if (this._registeredResourceTemplates[name]) { + throw new Error(`Resource template ${name} is already registered`); + } + + const registeredResourceTemplate = this._createRegisteredResourceTemplate( + name, + (config as BaseMetadata).title, + uriOrTemplate, + config, + readCallback as ReadResourceTemplateCallback + ); + + this.setResourceRequestHandlers(); + this.sendResourceListChanged(); + return registeredResourceTemplate; } - }, - ); + } + + private _createRegisteredResource( + name: string, + title: string | undefined, + uri: string, + metadata: ResourceMetadata | undefined, + readCallback: ReadResourceCallback + ): RegisteredResource { + const registeredResource: RegisteredResource = { + name, + title, + metadata, + readCallback, + enabled: true, + disable: () => registeredResource.update({ enabled: false }), + enable: () => registeredResource.update({ enabled: true }), + remove: () => registeredResource.update({ uri: null }), + update: updates => { + if (typeof updates.uri !== 'undefined' && updates.uri !== uri) { + delete this._registeredResources[uri]; + if (updates.uri) this._registeredResources[updates.uri] = registeredResource; + } + if (typeof updates.name !== 'undefined') registeredResource.name = updates.name; + if (typeof updates.title !== 'undefined') registeredResource.title = updates.title; + if (typeof updates.metadata !== 'undefined') registeredResource.metadata = updates.metadata; + if (typeof updates.callback !== 'undefined') registeredResource.readCallback = updates.callback; + if (typeof updates.enabled !== 'undefined') registeredResource.enabled = updates.enabled; + this.sendResourceListChanged(); + } + }; + this._registeredResources[uri] = registeredResource; + return registeredResource; + } - this.setCompletionRequestHandler(); - - this._promptHandlersInitialized = true; - } - - /** - * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. - */ - resource(name: string, uri: string, readCallback: ReadResourceCallback): RegisteredResource; - - /** - * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. - */ - resource( - name: string, - uri: string, - metadata: ResourceMetadata, - readCallback: ReadResourceCallback, - ): RegisteredResource; - - /** - * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. - */ - resource( - name: string, - template: ResourceTemplate, - readCallback: ReadResourceTemplateCallback, - ): RegisteredResourceTemplate; - - /** - * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. - */ - resource( - name: string, - template: ResourceTemplate, - metadata: ResourceMetadata, - readCallback: ReadResourceTemplateCallback, - ): RegisteredResourceTemplate; - - resource( - name: string, - uriOrTemplate: string | ResourceTemplate, - ...rest: unknown[] - ): RegisteredResource | RegisteredResourceTemplate { - let metadata: ResourceMetadata | undefined; - if (typeof rest[0] === "object") { - metadata = rest.shift() as ResourceMetadata; + private _createRegisteredResourceTemplate( + name: string, + title: string | undefined, + template: ResourceTemplate, + metadata: ResourceMetadata | undefined, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate { + const registeredResourceTemplate: RegisteredResourceTemplate = { + resourceTemplate: template, + title, + metadata, + readCallback, + enabled: true, + disable: () => registeredResourceTemplate.update({ enabled: false }), + enable: () => registeredResourceTemplate.update({ enabled: true }), + remove: () => registeredResourceTemplate.update({ name: null }), + update: updates => { + if (typeof updates.name !== 'undefined' && updates.name !== name) { + delete this._registeredResourceTemplates[name]; + if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; + } + if (typeof updates.title !== 'undefined') registeredResourceTemplate.title = updates.title; + if (typeof updates.template !== 'undefined') registeredResourceTemplate.resourceTemplate = updates.template; + if (typeof updates.metadata !== 'undefined') registeredResourceTemplate.metadata = updates.metadata; + if (typeof updates.callback !== 'undefined') registeredResourceTemplate.readCallback = updates.callback; + if (typeof updates.enabled !== 'undefined') registeredResourceTemplate.enabled = updates.enabled; + this.sendResourceListChanged(); + } + }; + this._registeredResourceTemplates[name] = registeredResourceTemplate; + return registeredResourceTemplate; } - const readCallback = rest[0] as - | ReadResourceCallback - | ReadResourceTemplateCallback; - - if (typeof uriOrTemplate === "string") { - if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); - } - - const registeredResource = this._createRegisteredResource( - name, - undefined, - uriOrTemplate, - metadata, - readCallback as ReadResourceCallback - ); - - this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResource; - } else { - if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); - } - - const registeredResourceTemplate = this._createRegisteredResourceTemplate( - name, - undefined, - uriOrTemplate, - metadata, - readCallback as ReadResourceTemplateCallback - ); - - this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResourceTemplate; + private _createRegisteredPrompt( + name: string, + title: string | undefined, + description: string | undefined, + argsSchema: PromptArgsRawShape | undefined, + callback: PromptCallback + ): RegisteredPrompt { + const registeredPrompt: RegisteredPrompt = { + title, + description, + argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), + callback, + enabled: true, + disable: () => registeredPrompt.update({ enabled: false }), + enable: () => registeredPrompt.update({ enabled: true }), + remove: () => registeredPrompt.update({ name: null }), + update: updates => { + if (typeof updates.name !== 'undefined' && updates.name !== name) { + delete this._registeredPrompts[name]; + if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt; + } + if (typeof updates.title !== 'undefined') registeredPrompt.title = updates.title; + if (typeof updates.description !== 'undefined') registeredPrompt.description = updates.description; + if (typeof updates.argsSchema !== 'undefined') registeredPrompt.argsSchema = z.object(updates.argsSchema); + if (typeof updates.callback !== 'undefined') registeredPrompt.callback = updates.callback; + if (typeof updates.enabled !== 'undefined') registeredPrompt.enabled = updates.enabled; + this.sendPromptListChanged(); + } + }; + this._registeredPrompts[name] = registeredPrompt; + return registeredPrompt; } - } - - /** - * Registers a resource with a config object and callback. - * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. - */ - registerResource( - name: string, - uriOrTemplate: string, - config: ResourceMetadata, - readCallback: ReadResourceCallback - ): RegisteredResource; - registerResource( - name: string, - uriOrTemplate: ResourceTemplate, - config: ResourceMetadata, - readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate; - registerResource( - name: string, - uriOrTemplate: string | ResourceTemplate, - config: ResourceMetadata, - readCallback: ReadResourceCallback | ReadResourceTemplateCallback - ): RegisteredResource | RegisteredResourceTemplate { - if (typeof uriOrTemplate === "string") { - if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); - } - - const registeredResource = this._createRegisteredResource( - name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceCallback - ); - - this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResource; - } else { - if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); - } - - const registeredResourceTemplate = this._createRegisteredResourceTemplate( - name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceTemplateCallback - ); - - this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResourceTemplate; + + private _createRegisteredTool( + name: string, + title: string | undefined, + description: string | undefined, + inputSchema: ZodRawShape | undefined, + outputSchema: ZodRawShape | undefined, + annotations: ToolAnnotations | undefined, + _meta: Record | undefined, + callback: ToolCallback + ): RegisteredTool { + const registeredTool: RegisteredTool = { + title, + description, + inputSchema: inputSchema === undefined ? undefined : z.object(inputSchema), + outputSchema: outputSchema === undefined ? undefined : z.object(outputSchema), + annotations, + _meta, + callback, + enabled: true, + disable: () => registeredTool.update({ enabled: false }), + enable: () => registeredTool.update({ enabled: true }), + remove: () => registeredTool.update({ name: null }), + update: updates => { + if (typeof updates.name !== 'undefined' && updates.name !== name) { + delete this._registeredTools[name]; + if (updates.name) this._registeredTools[updates.name] = registeredTool; + } + if (typeof updates.title !== 'undefined') registeredTool.title = updates.title; + if (typeof updates.description !== 'undefined') registeredTool.description = updates.description; + if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = z.object(updates.paramsSchema); + if (typeof updates.callback !== 'undefined') registeredTool.callback = updates.callback; + if (typeof updates.annotations !== 'undefined') registeredTool.annotations = updates.annotations; + if (typeof updates._meta !== 'undefined') registeredTool._meta = updates._meta; + if (typeof updates.enabled !== 'undefined') registeredTool.enabled = updates.enabled; + this.sendToolListChanged(); + } + }; + this._registeredTools[name] = registeredTool; + + this.setToolRequestHandlers(); + this.sendToolListChanged(); + + return registeredTool; } - } - - private _createRegisteredResource( - name: string, - title: string | undefined, - uri: string, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceCallback - ): RegisteredResource { - const registeredResource: RegisteredResource = { - name, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResource.update({ enabled: false }), - enable: () => registeredResource.update({ enabled: true }), - remove: () => registeredResource.update({ uri: null }), - update: (updates) => { - if (typeof updates.uri !== "undefined" && updates.uri !== uri) { - delete this._registeredResources[uri] - if (updates.uri) this._registeredResources[updates.uri] = registeredResource - } - if (typeof updates.name !== "undefined") registeredResource.name = updates.name - if (typeof updates.title !== "undefined") registeredResource.title = updates.title - if (typeof updates.metadata !== "undefined") registeredResource.metadata = updates.metadata - if (typeof updates.callback !== "undefined") registeredResource.readCallback = updates.callback - if (typeof updates.enabled !== "undefined") registeredResource.enabled = updates.enabled - this.sendResourceListChanged() - }, - }; - this._registeredResources[uri] = registeredResource; - return registeredResource; - } - - private _createRegisteredResourceTemplate( - name: string, - title: string | undefined, - template: ResourceTemplate, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate { - const registeredResourceTemplate: RegisteredResourceTemplate = { - resourceTemplate: template, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResourceTemplate.update({ enabled: false }), - enable: () => registeredResourceTemplate.update({ enabled: true }), - remove: () => registeredResourceTemplate.update({ name: null }), - update: (updates) => { - if (typeof updates.name !== "undefined" && updates.name !== name) { - delete this._registeredResourceTemplates[name] - if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate + + /** + * Registers a zero-argument tool `name`, which will run the given function when the client calls it. + */ + tool(name: string, cb: ToolCallback): RegisteredTool; + + /** + * Registers a zero-argument tool `name` (with a description) which will run the given function when the client calls it. + */ + tool(name: string, description: string, cb: ToolCallback): RegisteredTool; + + /** + * Registers a tool taking either a parameter schema for validation or annotations for additional metadata. + * This unified overload handles both `tool(name, paramsSchema, cb)` and `tool(name, annotations, cb)` cases. + * + * Note: We use a union type for the second parameter because TypeScript cannot reliably disambiguate + * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. + */ + tool(name: string, paramsSchemaOrAnnotations: Args | ToolAnnotations, cb: ToolCallback): RegisteredTool; + + /** + * Registers a tool `name` (with a description) taking either parameter schema or annotations. + * This unified overload handles both `tool(name, description, paramsSchema, cb)` and + * `tool(name, description, annotations, cb)` cases. + * + * Note: We use a union type for the third parameter because TypeScript cannot reliably disambiguate + * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. + */ + tool( + name: string, + description: string, + paramsSchemaOrAnnotations: Args | ToolAnnotations, + cb: ToolCallback + ): RegisteredTool; + + /** + * Registers a tool with both parameter schema and annotations. + */ + tool(name: string, paramsSchema: Args, annotations: ToolAnnotations, cb: ToolCallback): RegisteredTool; + + /** + * Registers a tool with description, parameter schema, and annotations. + */ + tool( + name: string, + description: string, + paramsSchema: Args, + annotations: ToolAnnotations, + cb: ToolCallback + ): RegisteredTool; + + /** + * tool() implementation. Parses arguments passed to overrides defined above. + */ + tool(name: string, ...rest: unknown[]): RegisteredTool { + if (this._registeredTools[name]) { + throw new Error(`Tool ${name} is already registered`); } - if (typeof updates.title !== "undefined") registeredResourceTemplate.title = updates.title - if (typeof updates.template !== "undefined") registeredResourceTemplate.resourceTemplate = updates.template - if (typeof updates.metadata !== "undefined") registeredResourceTemplate.metadata = updates.metadata - if (typeof updates.callback !== "undefined") registeredResourceTemplate.readCallback = updates.callback - if (typeof updates.enabled !== "undefined") registeredResourceTemplate.enabled = updates.enabled - this.sendResourceListChanged() - }, - }; - this._registeredResourceTemplates[name] = registeredResourceTemplate; - return registeredResourceTemplate; - } - - private _createRegisteredPrompt( - name: string, - title: string | undefined, - description: string | undefined, - argsSchema: PromptArgsRawShape | undefined, - callback: PromptCallback - ): RegisteredPrompt { - const registeredPrompt: RegisteredPrompt = { - title, - description, - argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), - callback, - enabled: true, - disable: () => registeredPrompt.update({ enabled: false }), - enable: () => registeredPrompt.update({ enabled: true }), - remove: () => registeredPrompt.update({ name: null }), - update: (updates) => { - if (typeof updates.name !== "undefined" && updates.name !== name) { - delete this._registeredPrompts[name] - if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt + + let description: string | undefined; + let inputSchema: ZodRawShape | undefined; + let outputSchema: ZodRawShape | undefined; + let annotations: ToolAnnotations | undefined; + + // Tool properties are passed as separate arguments, with omissions allowed. + // Support for this style is frozen as of protocol version 2025-03-26. Future additions + // to tool definition should *NOT* be added. + + if (typeof rest[0] === 'string') { + description = rest.shift() as string; } - if (typeof updates.title !== "undefined") registeredPrompt.title = updates.title - if (typeof updates.description !== "undefined") registeredPrompt.description = updates.description - if (typeof updates.argsSchema !== "undefined") registeredPrompt.argsSchema = z.object(updates.argsSchema) - if (typeof updates.callback !== "undefined") registeredPrompt.callback = updates.callback - if (typeof updates.enabled !== "undefined") registeredPrompt.enabled = updates.enabled - this.sendPromptListChanged() - }, - }; - this._registeredPrompts[name] = registeredPrompt; - return registeredPrompt; - } - - private _createRegisteredTool( - name: string, - title: string | undefined, - description: string | undefined, - inputSchema: ZodRawShape | undefined, - outputSchema: ZodRawShape | undefined, - annotations: ToolAnnotations | undefined, - _meta: Record | undefined, - callback: ToolCallback - ): RegisteredTool { - const registeredTool: RegisteredTool = { - title, - description, - inputSchema: - inputSchema === undefined ? undefined : z.object(inputSchema), - outputSchema: - outputSchema === undefined ? undefined : z.object(outputSchema), - annotations, - _meta, - callback, - enabled: true, - disable: () => registeredTool.update({ enabled: false }), - enable: () => registeredTool.update({ enabled: true }), - remove: () => registeredTool.update({ name: null }), - update: (updates) => { - if (typeof updates.name !== "undefined" && updates.name !== name) { - delete this._registeredTools[name] - if (updates.name) this._registeredTools[updates.name] = registeredTool + + // Handle the different overload combinations + if (rest.length > 1) { + // We have at least one more arg before the callback + const firstArg = rest[0]; + + if (isZodRawShape(firstArg)) { + // We have a params schema as the first arg + inputSchema = rest.shift() as ZodRawShape; + + // Check if the next arg is potentially annotations + if (rest.length > 1 && typeof rest[0] === 'object' && rest[0] !== null && !isZodRawShape(rest[0])) { + // Case: tool(name, paramsSchema, annotations, cb) + // Or: tool(name, description, paramsSchema, annotations, cb) + annotations = rest.shift() as ToolAnnotations; + } + } else if (typeof firstArg === 'object' && firstArg !== null) { + // Not a ZodRawShape, so must be annotations in this position + // Case: tool(name, annotations, cb) + // Or: tool(name, description, annotations, cb) + annotations = rest.shift() as ToolAnnotations; + } } - if (typeof updates.title !== "undefined") registeredTool.title = updates.title - if (typeof updates.description !== "undefined") registeredTool.description = updates.description - if (typeof updates.paramsSchema !== "undefined") registeredTool.inputSchema = z.object(updates.paramsSchema) - if (typeof updates.callback !== "undefined") registeredTool.callback = updates.callback - if (typeof updates.annotations !== "undefined") registeredTool.annotations = updates.annotations - if (typeof updates._meta !== "undefined") registeredTool._meta = updates._meta - if (typeof updates.enabled !== "undefined") registeredTool.enabled = updates.enabled - this.sendToolListChanged() - }, - }; - this._registeredTools[name] = registeredTool; - - this.setToolRequestHandlers(); - this.sendToolListChanged() - - return registeredTool - } - - /** - * Registers a zero-argument tool `name`, which will run the given function when the client calls it. - */ - tool(name: string, cb: ToolCallback): RegisteredTool; - - /** - * Registers a zero-argument tool `name` (with a description) which will run the given function when the client calls it. - */ - tool(name: string, description: string, cb: ToolCallback): RegisteredTool; - - /** - * Registers a tool taking either a parameter schema for validation or annotations for additional metadata. - * This unified overload handles both `tool(name, paramsSchema, cb)` and `tool(name, annotations, cb)` cases. - * - * Note: We use a union type for the second parameter because TypeScript cannot reliably disambiguate - * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. - */ - tool( - name: string, - paramsSchemaOrAnnotations: Args | ToolAnnotations, - cb: ToolCallback, - ): RegisteredTool; - - /** - * Registers a tool `name` (with a description) taking either parameter schema or annotations. - * This unified overload handles both `tool(name, description, paramsSchema, cb)` and - * `tool(name, description, annotations, cb)` cases. - * - * Note: We use a union type for the third parameter because TypeScript cannot reliably disambiguate - * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. - */ - tool( - name: string, - description: string, - paramsSchemaOrAnnotations: Args | ToolAnnotations, - cb: ToolCallback, - ): RegisteredTool; - - /** - * Registers a tool with both parameter schema and annotations. - */ - tool( - name: string, - paramsSchema: Args, - annotations: ToolAnnotations, - cb: ToolCallback, - ): RegisteredTool; - - /** - * Registers a tool with description, parameter schema, and annotations. - */ - tool( - name: string, - description: string, - paramsSchema: Args, - annotations: ToolAnnotations, - cb: ToolCallback, - ): RegisteredTool; - - - /** - * tool() implementation. Parses arguments passed to overrides defined above. - */ - tool(name: string, ...rest: unknown[]): RegisteredTool { - if (this._registeredTools[name]) { - throw new Error(`Tool ${name} is already registered`); - } + const callback = rest[0] as ToolCallback; - let description: string | undefined; - let inputSchema: ZodRawShape | undefined; - let outputSchema: ZodRawShape | undefined; - let annotations: ToolAnnotations | undefined; + return this._createRegisteredTool(name, undefined, description, inputSchema, outputSchema, annotations, undefined, callback); + } - // Tool properties are passed as separate arguments, with omissions allowed. - // Support for this style is frozen as of protocol version 2025-03-26. Future additions - // to tool definition should *NOT* be added. + /** + * Registers a tool with a config object and callback. + */ + registerTool( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + }, + cb: ToolCallback + ): RegisteredTool { + if (this._registeredTools[name]) { + throw new Error(`Tool ${name} is already registered`); + } - if (typeof rest[0] === "string") { - description = rest.shift() as string; + const { title, description, inputSchema, outputSchema, annotations, _meta } = config; + + return this._createRegisteredTool( + name, + title, + description, + inputSchema, + outputSchema, + annotations, + _meta, + cb as ToolCallback + ); } - // Handle the different overload combinations - if (rest.length > 1) { - // We have at least one more arg before the callback - const firstArg = rest[0]; + /** + * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. + */ + prompt(name: string, cb: PromptCallback): RegisteredPrompt; + + /** + * Registers a zero-argument prompt `name` (with a description) which will run the given function when the client calls it. + */ + prompt(name: string, description: string, cb: PromptCallback): RegisteredPrompt; + + /** + * Registers a prompt `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + prompt(name: string, argsSchema: Args, cb: PromptCallback): RegisteredPrompt; + + /** + * Registers a prompt `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + prompt( + name: string, + description: string, + argsSchema: Args, + cb: PromptCallback + ): RegisteredPrompt; + + prompt(name: string, ...rest: unknown[]): RegisteredPrompt { + if (this._registeredPrompts[name]) { + throw new Error(`Prompt ${name} is already registered`); + } - if (isZodRawShape(firstArg)) { - // We have a params schema as the first arg - inputSchema = rest.shift() as ZodRawShape; + let description: string | undefined; + if (typeof rest[0] === 'string') { + description = rest.shift() as string; + } - // Check if the next arg is potentially annotations - if (rest.length > 1 && typeof rest[0] === "object" && rest[0] !== null && !(isZodRawShape(rest[0]))) { - // Case: tool(name, paramsSchema, annotations, cb) - // Or: tool(name, description, paramsSchema, annotations, cb) - annotations = rest.shift() as ToolAnnotations; + let argsSchema: PromptArgsRawShape | undefined; + if (rest.length > 1) { + argsSchema = rest.shift() as PromptArgsRawShape; } - } else if (typeof firstArg === "object" && firstArg !== null) { - // Not a ZodRawShape, so must be annotations in this position - // Case: tool(name, annotations, cb) - // Or: tool(name, description, annotations, cb) - annotations = rest.shift() as ToolAnnotations; - } - } - const callback = rest[0] as ToolCallback; - - return this._createRegisteredTool(name, undefined, description, inputSchema, outputSchema, annotations, undefined, callback) - } - - /** - * Registers a tool with a config object and callback. - */ - registerTool( - name: string, - config: { - title?: string; - description?: string; - inputSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - _meta?: Record; - }, - cb: ToolCallback - ): RegisteredTool { - if (this._registeredTools[name]) { - throw new Error(`Tool ${name} is already registered`); - } - const { title, description, inputSchema, outputSchema, annotations, _meta } = config; - - return this._createRegisteredTool( - name, - title, - description, - inputSchema, - outputSchema, - annotations, - _meta, - cb as ToolCallback - ); - } - - /** - * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. - */ - prompt(name: string, cb: PromptCallback): RegisteredPrompt; - - /** - * Registers a zero-argument prompt `name` (with a description) which will run the given function when the client calls it. - */ - prompt(name: string, description: string, cb: PromptCallback): RegisteredPrompt; - - /** - * Registers a prompt `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. - */ - prompt( - name: string, - argsSchema: Args, - cb: PromptCallback, - ): RegisteredPrompt; - - /** - * Registers a prompt `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. - */ - prompt( - name: string, - description: string, - argsSchema: Args, - cb: PromptCallback, - ): RegisteredPrompt; - - prompt(name: string, ...rest: unknown[]): RegisteredPrompt { - if (this._registeredPrompts[name]) { - throw new Error(`Prompt ${name} is already registered`); - } + const cb = rest[0] as PromptCallback; + const registeredPrompt = this._createRegisteredPrompt(name, undefined, description, argsSchema, cb); - let description: string | undefined; - if (typeof rest[0] === "string") { - description = rest.shift() as string; - } + this.setPromptRequestHandlers(); + this.sendPromptListChanged(); - let argsSchema: PromptArgsRawShape | undefined; - if (rest.length > 1) { - argsSchema = rest.shift() as PromptArgsRawShape; + return registeredPrompt; } - const cb = rest[0] as PromptCallback; - const registeredPrompt = this._createRegisteredPrompt( - name, - undefined, - description, - argsSchema, - cb - ); + /** + * Registers a prompt with a config object and callback. + */ + registerPrompt( + name: string, + config: { + title?: string; + description?: string; + argsSchema?: Args; + }, + cb: PromptCallback + ): RegisteredPrompt { + if (this._registeredPrompts[name]) { + throw new Error(`Prompt ${name} is already registered`); + } - this.setPromptRequestHandlers(); - this.sendPromptListChanged() - - return registeredPrompt - } - - /** - * Registers a prompt with a config object and callback. - */ - registerPrompt( - name: string, - config: { - title?: string; - description?: string; - argsSchema?: Args; - }, - cb: PromptCallback - ): RegisteredPrompt { - if (this._registeredPrompts[name]) { - throw new Error(`Prompt ${name} is already registered`); - } + const { title, description, argsSchema } = config; - const { title, description, argsSchema } = config; + const registeredPrompt = this._createRegisteredPrompt( + name, + title, + description, + argsSchema, + cb as PromptCallback + ); - const registeredPrompt = this._createRegisteredPrompt( - name, - title, - description, - argsSchema, - cb as PromptCallback - ); + this.setPromptRequestHandlers(); + this.sendPromptListChanged(); + + return registeredPrompt; + } - this.setPromptRequestHandlers(); - this.sendPromptListChanged() - - return registeredPrompt; - } - - /** - * Checks if the server is connected to a transport. - * @returns True if the server is connected - */ - isConnected() { - return this.server.transport !== undefined - } - - /** - * Sends a logging message to the client, if connected. - * Note: You only need to send the parameters object, not the entire JSON RPC message - * @see LoggingMessageNotification - * @param params - * @param sessionId optional for stateless and backward compatibility - */ - async sendLoggingMessage(params: LoggingMessageNotification["params"], sessionId?: string) { - return this.server.sendLoggingMessage(params, sessionId); - } - /** - * Sends a resource list changed event to the client, if connected. - */ - sendResourceListChanged() { - if (this.isConnected()) { - this.server.sendResourceListChanged(); + /** + * Checks if the server is connected to a transport. + * @returns True if the server is connected + */ + isConnected() { + return this.server.transport !== undefined; } - } - - /** - * Sends a tool list changed event to the client, if connected. - */ - sendToolListChanged() { - if (this.isConnected()) { - this.server.sendToolListChanged(); + + /** + * Sends a logging message to the client, if connected. + * Note: You only need to send the parameters object, not the entire JSON RPC message + * @see LoggingMessageNotification + * @param params + * @param sessionId optional for stateless and backward compatibility + */ + async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { + return this.server.sendLoggingMessage(params, sessionId); + } + /** + * Sends a resource list changed event to the client, if connected. + */ + sendResourceListChanged() { + if (this.isConnected()) { + this.server.sendResourceListChanged(); + } + } + + /** + * Sends a tool list changed event to the client, if connected. + */ + sendToolListChanged() { + if (this.isConnected()) { + this.server.sendToolListChanged(); + } } - } - - /** - * Sends a prompt list changed event to the client, if connected. - */ - sendPromptListChanged() { - if (this.isConnected()) { - this.server.sendPromptListChanged(); + + /** + * Sends a prompt list changed event to the client, if connected. + */ + sendPromptListChanged() { + if (this.isConnected()) { + this.server.sendPromptListChanged(); + } } - } } /** * A callback to complete one variable within a resource template's URI template. */ export type CompleteResourceTemplateCallback = ( - value: string, - context?: { - arguments?: Record; - }, + value: string, + context?: { + arguments?: Record; + } ) => string[] | Promise; /** @@ -1107,52 +958,47 @@ export type CompleteResourceTemplateCallback = ( * all resources matching that pattern. */ export class ResourceTemplate { - private _uriTemplate: UriTemplate; - - constructor( - uriTemplate: string | UriTemplate, - private _callbacks: { - /** - * A callback to list all resources matching this template. This is required to specified, even if `undefined`, to avoid accidentally forgetting resource listing. - */ - list: ListResourcesCallback | undefined; - - /** - * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. - */ - complete?: { - [variable: string]: CompleteResourceTemplateCallback; - }; - }, - ) { - this._uriTemplate = - typeof uriTemplate === "string" - ? new UriTemplate(uriTemplate) - : uriTemplate; - } - - /** - * Gets the URI template pattern. - */ - get uriTemplate(): UriTemplate { - return this._uriTemplate; - } - - /** - * Gets the list callback, if one was provided. - */ - get listCallback(): ListResourcesCallback | undefined { - return this._callbacks.list; - } - - /** - * Gets the callback for completing a specific URI template variable, if one was provided. - */ - completeCallback( - variable: string, - ): CompleteResourceTemplateCallback | undefined { - return this._callbacks.complete?.[variable]; - } + private _uriTemplate: UriTemplate; + + constructor( + uriTemplate: string | UriTemplate, + private _callbacks: { + /** + * A callback to list all resources matching this template. This is required to specified, even if `undefined`, to avoid accidentally forgetting resource listing. + */ + list: ListResourcesCallback | undefined; + + /** + * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. + */ + complete?: { + [variable: string]: CompleteResourceTemplateCallback; + }; + } + ) { + this._uriTemplate = typeof uriTemplate === 'string' ? new UriTemplate(uriTemplate) : uriTemplate; + } + + /** + * Gets the URI template pattern. + */ + get uriTemplate(): UriTemplate { + return this._uriTemplate; + } + + /** + * Gets the list callback, if one was provided. + */ + get listCallback(): ListResourcesCallback | undefined { + return this._callbacks.list; + } + + /** + * Gets the callback for completing a specific URI template variable, if one was provided. + */ + completeCallback(variable: string): CompleteResourceTemplateCallback | undefined { + return this._callbacks.complete?.[variable]; + } } /** @@ -1165,168 +1011,185 @@ export class ResourceTemplate { * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = - Args extends ZodRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; +export type ToolCallback = Args extends ZodRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise; export type RegisteredTool = { - title?: string; - description?: string; - inputSchema?: AnyZodObject; - outputSchema?: AnyZodObject; - annotations?: ToolAnnotations; - _meta?: Record; - callback: ToolCallback; - enabled: boolean; - enable(): void; - disable(): void; - update( - updates: { - name?: string | null, - title?: string, - description?: string, - paramsSchema?: InputArgs, - outputSchema?: OutputArgs, - annotations?: ToolAnnotations, - _meta?: Record, - callback?: ToolCallback, - enabled?: boolean - }): void - remove(): void + title?: string; + description?: string; + inputSchema?: AnyZodObject; + outputSchema?: AnyZodObject; + annotations?: ToolAnnotations; + _meta?: Record; + callback: ToolCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + paramsSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + _meta?: Record; + callback?: ToolCallback; + enabled?: boolean; + }): void; + remove(): void; }; const EMPTY_OBJECT_JSON_SCHEMA = { - type: "object" as const, - properties: {}, + type: 'object' as const, + properties: {} }; // Helper to check if an object is a Zod schema (ZodRawShape) function isZodRawShape(obj: unknown): obj is ZodRawShape { - if (typeof obj !== "object" || obj === null) return false; + if (typeof obj !== 'object' || obj === null) return false; - const isEmptyObject = Object.keys(obj).length === 0; + const isEmptyObject = Object.keys(obj).length === 0; - // Check if object is empty or at least one property is a ZodType instance - // Note: use heuristic check to avoid instanceof failure across different Zod versions - return isEmptyObject || Object.values(obj as object).some(isZodTypeLike); + // Check if object is empty or at least one property is a ZodType instance + // Note: use heuristic check to avoid instanceof failure across different Zod versions + return isEmptyObject || Object.values(obj as object).some(isZodTypeLike); } function isZodTypeLike(value: unknown): value is ZodType { - return value !== null && - typeof value === 'object' && - 'parse' in value && typeof value.parse === 'function' && - 'safeParse' in value && typeof value.safeParse === 'function'; + return ( + value !== null && + typeof value === 'object' && + 'parse' in value && + typeof value.parse === 'function' && + 'safeParse' in value && + typeof value.safeParse === 'function' + ); } /** * Additional, optional information for annotating a resource. */ -export type ResourceMetadata = Omit; +export type ResourceMetadata = Omit; /** * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra, + extra: RequestHandlerExtra ) => ListResourcesResult | Promise; /** * Callback to read a resource at a given URI. */ export type ReadResourceCallback = ( - uri: URL, - extra: RequestHandlerExtra, + uri: URL, + extra: RequestHandlerExtra ) => ReadResourceResult | Promise; export type RegisteredResource = { - name: string; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { name?: string, title?: string, uri?: string | null, metadata?: ResourceMetadata, callback?: ReadResourceCallback, enabled?: boolean }): void - remove(): void + name: string; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string; + title?: string; + uri?: string | null; + metadata?: ResourceMetadata; + callback?: ReadResourceCallback; + enabled?: boolean; + }): void; + remove(): void; }; /** * Callback to read a resource at a given URI, following a filled-in URI template. */ export type ReadResourceTemplateCallback = ( - uri: URL, - variables: Variables, - extra: RequestHandlerExtra, + uri: URL, + variables: Variables, + extra: RequestHandlerExtra ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { - resourceTemplate: ResourceTemplate; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { name?: string | null, title?: string, template?: ResourceTemplate, metadata?: ResourceMetadata, callback?: ReadResourceTemplateCallback, enabled?: boolean }): void - remove(): void + resourceTemplate: ResourceTemplate; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + template?: ResourceTemplate; + metadata?: ResourceMetadata; + callback?: ReadResourceTemplateCallback; + enabled?: boolean; + }): void; + remove(): void; }; type PromptArgsRawShape = { - [k: string]: - | ZodType - | ZodOptional>; + [k: string]: ZodType | ZodOptional>; }; -export type PromptCallback< - Args extends undefined | PromptArgsRawShape = undefined, -> = Args extends PromptArgsRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; +export type PromptCallback = Args extends PromptArgsRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra + ) => GetPromptResult | Promise + : (extra: RequestHandlerExtra) => GetPromptResult | Promise; export type RegisteredPrompt = { - title?: string; - description?: string; - argsSchema?: ZodObject; - callback: PromptCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { name?: string | null, title?: string, description?: string, argsSchema?: Args, callback?: PromptCallback, enabled?: boolean }): void - remove(): void + title?: string; + description?: string; + argsSchema?: ZodObject; + callback: PromptCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + argsSchema?: Args; + callback?: PromptCallback; + enabled?: boolean; + }): void; + remove(): void; }; -function promptArgumentsFromSchema( - schema: ZodObject, -): PromptArgument[] { - return Object.entries(schema.shape).map( - ([name, field]): PromptArgument => ({ - name, - description: field.description, - required: !field.isOptional(), - }), - ); +function promptArgumentsFromSchema(schema: ZodObject): PromptArgument[] { + return Object.entries(schema.shape).map( + ([name, field]): PromptArgument => ({ + name, + description: field.description, + required: !field.isOptional() + }) + ); } function createCompletionResult(suggestions: string[]): CompleteResult { - return { - completion: { - values: suggestions.slice(0, 100), - total: suggestions.length, - hasMore: suggestions.length > 100, - }, - }; + return { + completion: { + values: suggestions.slice(0, 100), + total: suggestions.length, + hasMore: suggestions.length > 100 + } + }; } const EMPTY_COMPLETION_RESULT: CompleteResult = { - completion: { - values: [], - hasMore: false, - }, + completion: { + values: [], + hasMore: false + } }; diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index a7f180961..372cb689f 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -1,716 +1,710 @@ -import http from 'http'; +import http from 'http'; import { jest } from '@jest/globals'; -import { SSEServerTransport } from './sse.js'; +import { SSEServerTransport } from './sse.js'; import { McpServer } from './mcp.js'; -import { createServer, type Server } from "node:http"; -import { AddressInfo } from "node:net"; +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; import { z } from 'zod'; import { CallToolResult, JSONRPCMessage } from 'src/types.js'; const createMockResponse = () => { - const res = { - writeHead: jest.fn().mockReturnThis(), - write: jest.fn().mockReturnThis(), - on: jest.fn().mockReturnThis(), - end: jest.fn().mockReturnThis(), - }; - - return res as unknown as jest.Mocked; + const res = { + writeHead: jest.fn().mockReturnThis(), + write: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + end: jest.fn().mockReturnThis() + }; + + return res as unknown as jest.Mocked; }; -const createMockRequest = ({ headers = {}, body }: { headers?: Record, body?: string } = {}) => { - const mockReq = { - headers, - body: body ? body : undefined, - auth: { - token: 'test-token', - }, - on: jest.fn().mockImplementation((event, listener) => { - const mockListener = listener as unknown as (...args: unknown[]) => void; - if (event === 'data') { - mockListener(Buffer.from(body || '') as unknown as Error); - } - if (event === 'error') { - mockListener(new Error('test')); - } - if (event === 'end') { - mockListener(); - } - if (event === 'close') { - setTimeout(listener, 100); - } - return mockReq; - }), - listeners: jest.fn(), - removeListener: jest.fn(), - } as unknown as http.IncomingMessage; - - return mockReq; +const createMockRequest = ({ headers = {}, body }: { headers?: Record; body?: string } = {}) => { + const mockReq = { + headers, + body: body ? body : undefined, + auth: { + token: 'test-token' + }, + on: jest.fn().mockImplementation((event, listener) => { + const mockListener = listener as unknown as (...args: unknown[]) => void; + if (event === 'data') { + mockListener(Buffer.from(body || '') as unknown as Error); + } + if (event === 'error') { + mockListener(new Error('test')); + } + if (event === 'end') { + mockListener(); + } + if (event === 'close') { + setTimeout(listener, 100); + } + return mockReq; + }), + listeners: jest.fn(), + removeListener: jest.fn() + } as unknown as http.IncomingMessage; + + return mockReq; }; /** * Helper to create and start test HTTP server with MCP setup */ -async function createTestServerWithSse(args: { - mockRes: http.ServerResponse; -}): Promise<{ - server: Server; - transport: SSEServerTransport; - mcpServer: McpServer; - baseUrl: URL; - sessionId: string - serverPort: number; +async function createTestServerWithSse(args: { mockRes: http.ServerResponse }): Promise<{ + server: Server; + transport: SSEServerTransport; + mcpServer: McpServer; + baseUrl: URL; + sessionId: string; + serverPort: number; }> { - const mcpServer = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: { logging: {} } } - ); - - mcpServer.tool( - "greet", - "A simple greeting tool", - { name: z.string().describe("Name to greet") }, - async ({ name }): Promise => { - return { content: [{ type: "text", text: `Hello, ${name}!` }] }; - } - ); + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { name: z.string().describe('Name to greet') }, + async ({ name }): Promise => { + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + } + ); - const endpoint = '/messages'; + const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, args.mockRes); - const sessionId = transport.sessionId; + const transport = new SSEServerTransport(endpoint, args.mockRes); + const sessionId = transport.sessionId; - await mcpServer.connect(transport); + await mcpServer.connect(transport); - const server = createServer(async (req, res) => { - try { - await transport.handlePostMessage(req, res); - } catch (error) { - console.error("Error handling request:", error); - if (!res.headersSent) res.writeHead(500).end(); - } - }); + const server = createServer(async (req, res) => { + try { + await transport.handlePostMessage(req, res); + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); - const baseUrl = await new Promise((resolve) => { - server.listen(0, "127.0.0.1", () => { - const addr = server.address() as AddressInfo; - resolve(new URL(`http://127.0.0.1:${addr.port}`)); + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); }); - }); - const port = (server.address() as AddressInfo).port; + const port = (server.address() as AddressInfo).port; - return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port }; + return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port }; } async function readAllSSEEvents(response: Response): Promise { - const reader = response.body?.getReader(); - if (!reader) throw new Error('No readable stream'); - - const events: string[] = []; - const decoder = new TextDecoder(); - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - if (value) { - events.push(decoder.decode(value)); - } + const reader = response.body?.getReader(); + if (!reader) throw new Error('No readable stream'); + + const events: string[] = []; + const decoder = new TextDecoder(); + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + if (value) { + events.push(decoder.decode(value)); + } + } + } finally { + reader.releaseLock(); } - } finally { - reader.releaseLock(); - } - - return events; + + return events; } /** * Helper to send JSON-RPC request */ -async function sendSsePostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record): Promise { - const headers: Record = { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - ...extraHeaders - }; - - if (sessionId) { - baseUrl.searchParams.set('sessionId', sessionId); - } - - return fetch(baseUrl, { - method: "POST", - headers, - body: JSON.stringify(message), - }); -} - -describe('SSEServerTransport', () => { - - async function initializeServer(baseUrl: URL): Promise { - const response = await sendSsePostRequest(baseUrl, { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client", version: "1.0" }, - protocolVersion: "2025-03-26", - capabilities: { - }, - }, - - id: "init-1", - } as JSONRPCMessage); - - expect(response.status).toBe(202); - - const text = await readAllSSEEvents(response); - - expect(text).toHaveLength(1); - expect(text[0]).toBe('Accepted'); - } - - describe('start method', () => { - it('should correctly append sessionId to a simple relative endpoint', async () => { - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - const expectedSessionId = transport.sessionId; - - await transport.start(); - - expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - `event: endpoint\ndata: /messages?sessionId=${expectedSessionId}\n\n` - ); - }); - - it('should correctly append sessionId to an endpoint with existing query parameters', async () => { - const mockRes = createMockResponse(); - const endpoint = '/messages?foo=bar&baz=qux'; - const transport = new SSEServerTransport(endpoint, mockRes); - const expectedSessionId = transport.sessionId; - - await transport.start(); - - expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - `event: endpoint\ndata: /messages?foo=bar&baz=qux&sessionId=${expectedSessionId}\n\n` - ); - }); - - it('should correctly append sessionId to an endpoint with a hash fragment', async () => { - const mockRes = createMockResponse(); - const endpoint = '/messages#section1'; - const transport = new SSEServerTransport(endpoint, mockRes); - const expectedSessionId = transport.sessionId; +async function sendSsePostRequest( + baseUrl: URL, + message: JSONRPCMessage | JSONRPCMessage[], + sessionId?: string, + extraHeaders?: Record +): Promise { + const headers: Record = { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + ...extraHeaders + }; - await transport.start(); + if (sessionId) { + baseUrl.searchParams.set('sessionId', sessionId); + } - expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - `event: endpoint\ndata: /messages?sessionId=${expectedSessionId}#section1\n\n` - ); + return fetch(baseUrl, { + method: 'POST', + headers, + body: JSON.stringify(message) }); +} - it('should correctly append sessionId to an endpoint with query parameters and a hash fragment', async () => { - const mockRes = createMockResponse(); - const endpoint = '/messages?key=value#section2'; - const transport = new SSEServerTransport(endpoint, mockRes); - const expectedSessionId = transport.sessionId; - - await transport.start(); +describe('SSEServerTransport', () => { + async function initializeServer(baseUrl: URL): Promise { + const response = await sendSsePostRequest(baseUrl, { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26', + capabilities: {} + }, - expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - `event: endpoint\ndata: /messages?key=value&sessionId=${expectedSessionId}#section2\n\n` - ); - }); + id: 'init-1' + } as JSONRPCMessage); - it('should correctly handle the root path endpoint "/"', async () => { - const mockRes = createMockResponse(); - const endpoint = '/'; - const transport = new SSEServerTransport(endpoint, mockRes); - const expectedSessionId = transport.sessionId; + expect(response.status).toBe(202); - await transport.start(); + const text = await readAllSSEEvents(response); - expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - `event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n` - ); - }); + expect(text).toHaveLength(1); + expect(text[0]).toBe('Accepted'); + } - it('should correctly handle an empty string endpoint ""', async () => { - const mockRes = createMockResponse(); - const endpoint = ''; - const transport = new SSEServerTransport(endpoint, mockRes); - const expectedSessionId = transport.sessionId; + describe('start method', () => { + it('should correctly append sessionId to a simple relative endpoint', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; - await transport.start(); + await transport.start(); - expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - `event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n` - ); - }); + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${expectedSessionId}\n\n`); + }); - /** - * Test: Tool With Request Info - */ - it("should pass request info to tool callback", async () => { - const mockRes = createMockResponse(); - const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes }); - await initializeServer(baseUrl); + it('should correctly append sessionId to an endpoint with existing query parameters', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages?foo=bar&baz=qux'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; - mcpServer.tool( - "test-request-info", - "A simple test tool with request info", - { name: z.string().describe("Name to greet") }, - async ({ name }, { requestInfo }): Promise => { - return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] }; - } - ); - - const toolCallMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/call", - params: { - name: "test-request-info", - arguments: { - name: "Test User", - }, - }, - id: "call-1", - }; + await transport.start(); - const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId); - - expect(response.status).toBe(202); - - expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`); - - const expectedMessage = { - result: { - content: [ - { - type: "text", - text: "Hello, Test User!", - }, - { - type: "text", - text: JSON.stringify({ - headers: { - host: `127.0.0.1:${serverPort}`, - connection: 'keep-alive', - 'content-type': 'application/json', - accept: 'application/json, text/event-stream', - 'accept-language': '*', - 'sec-fetch-mode': 'cors', - 'user-agent': 'node', - 'accept-encoding': 'gzip, deflate', - 'content-length': '124' - }, - }) - }, - ], - }, - jsonrpc: "2.0", - id: "call-1", - }; - expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`); - }); - }); - - describe('handlePostMessage method', () => { - it('should return 500 if server has not started', async () => { - const mockReq = createMockRequest(); - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - - const error = 'SSE connection not established'; - await expect(transport.handlePostMessage(mockReq, mockRes)) - .rejects.toThrow(error); - expect(mockRes.writeHead).toHaveBeenCalledWith(500); - expect(mockRes.end).toHaveBeenCalledWith(error); - }); + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?foo=bar&baz=qux&sessionId=${expectedSessionId}\n\n` + ); + }); - it('should return 400 if content-type is not application/json', async () => { - const mockReq = createMockRequest({ headers: { 'content-type': 'text/plain' } }); - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - await transport.start(); - - transport.onerror = jest.fn(); - const error = 'Unsupported content-type: text/plain'; - await expect(transport.handlePostMessage(mockReq, mockRes)) - .resolves.toBe(undefined); - expect(mockRes.writeHead).toHaveBeenCalledWith(400); - expect(mockRes.end).toHaveBeenCalledWith(expect.stringContaining(error)); - expect(transport.onerror).toHaveBeenCalledWith(new Error(error)); - }); + it('should correctly append sessionId to an endpoint with a hash fragment', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages#section1'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; - it('should return 400 if message has not a valid schema', async () => { - const invalidMessage = JSON.stringify({ - // missing jsonrpc field - method: 'call', - params: [1, 2, 3], - id: 1, - }) - const mockReq = createMockRequest({ - headers: { 'content-type': 'application/json' }, - body: invalidMessage, - }); - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - await transport.start(); - - transport.onmessage = jest.fn(); - await transport.handlePostMessage(mockReq, mockRes); - expect(mockRes.writeHead).toHaveBeenCalledWith(400); - expect(transport.onmessage).not.toHaveBeenCalled(); - expect(mockRes.end).toHaveBeenCalledWith(`Invalid message: ${invalidMessage}`); - }); + await transport.start(); - it('should return 202 if message has a valid schema', async () => { - const validMessage = JSON.stringify({ - jsonrpc: "2.0", - method: 'call', - params: { - a: 1, - b: 2, - c: 3, - }, - id: 1 - }) - const mockReq = createMockRequest({ - headers: { 'content-type': 'application/json' }, - body: validMessage, - }); - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - await transport.start(); - - transport.onmessage = jest.fn(); - await transport.handlePostMessage(mockReq, mockRes); - expect(mockRes.writeHead).toHaveBeenCalledWith(202); - expect(mockRes.end).toHaveBeenCalledWith('Accepted'); - expect(transport.onmessage).toHaveBeenCalledWith({ - jsonrpc: "2.0", - method: 'call', - params: { - a: 1, - b: 2, - c: 3, - }, - id: 1 - }, { - authInfo: { - token: 'test-token', - }, - requestInfo: { - headers: { - 'content-type': 'application/json', - }, - }, - }); - }); - }); - - describe('close method', () => { - it('should call onclose', async () => { - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - await transport.start(); - transport.onclose = jest.fn(); - await transport.close(); - expect(transport.onclose).toHaveBeenCalled(); - }); - }); - - describe('send method', () => { - it('should call onsend', async () => { - const mockRes = createMockResponse(); - const endpoint = '/messages'; - const transport = new SSEServerTransport(endpoint, mockRes); - await transport.start(); - expect(mockRes.write).toHaveBeenCalledTimes(1); - expect(mockRes.write).toHaveBeenCalledWith( - expect.stringContaining('event: endpoint')); - expect(mockRes.write).toHaveBeenCalledWith( - expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`)); - }); - }); + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${expectedSessionId}#section1\n\n`); + }); - describe('DNS rebinding protection', () => { - beforeEach(() => { - jest.clearAllMocks(); - }); + it('should correctly append sessionId to an endpoint with query parameters and a hash fragment', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages?key=value#section2'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; - describe('Host header validation', () => { - it('should accept requests with allowed host headers', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedHosts: ['localhost:3000', 'example.com'], - enableDnsRebindingProtection: true, - }); - await transport.start(); + await transport.start(); - const mockReq = createMockRequest({ - headers: { - host: 'localhost:3000', - 'content-type': 'application/json', - } + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?key=value&sessionId=${expectedSessionId}#section2\n\n` + ); }); - const mockHandleRes = createMockResponse(); - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + it('should correctly handle the root path endpoint "/"', async () => { + const mockRes = createMockResponse(); + const endpoint = '/'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); - expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); - }); + await transport.start(); - it('should reject requests with disallowed host headers', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedHosts: ['localhost:3000'], - enableDnsRebindingProtection: true, + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`); }); - await transport.start(); - const mockReq = createMockRequest({ - headers: { - host: 'evil.com', - 'content-type': 'application/json', - } - }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + it('should correctly handle an empty string endpoint ""', async () => { + const mockRes = createMockResponse(); + const endpoint = ''; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); - expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); - }); + await transport.start(); - it('should reject requests without host header when allowedHosts is configured', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedHosts: ['localhost:3000'], - enableDnsRebindingProtection: true, + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`); }); - await transport.start(); - const mockReq = createMockRequest({ - headers: { - 'content-type': 'application/json', - } + /** + * Test: Tool With Request Info + */ + it('should pass request info to tool callback', async () => { + const mockRes = createMockResponse(); + const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes }); + await initializeServer(baseUrl); + + mcpServer.tool( + 'test-request-info', + 'A simple test tool with request info', + { name: z.string().describe('Name to greet') }, + async ({ name }, { requestInfo }): Promise => { + return { + content: [ + { type: 'text', text: `Hello, ${name}!` }, + { type: 'text', text: `${JSON.stringify(requestInfo)}` } + ] + }; + } + ); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'test-request-info', + arguments: { + name: 'Test User' + } + }, + id: 'call-1' + }; + + const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId); + + expect(response.status).toBe(202); + + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`); + + const expectedMessage = { + result: { + content: [ + { + type: 'text', + text: 'Hello, Test User!' + }, + { + type: 'text', + text: JSON.stringify({ + headers: { + host: `127.0.0.1:${serverPort}`, + connection: 'keep-alive', + 'content-type': 'application/json', + accept: 'application/json, text/event-stream', + 'accept-language': '*', + 'sec-fetch-mode': 'cors', + 'user-agent': 'node', + 'accept-encoding': 'gzip, deflate', + 'content-length': '124' + } + }) + } + ] + }, + jsonrpc: '2.0', + id: 'call-1' + }; + expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`); }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); - expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined'); - }); }); - describe('Origin header validation', () => { - it('should accept requests with allowed origin headers', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedOrigins: ['http://localhost:3000', 'https://example.com'], - enableDnsRebindingProtection: true, + describe('handlePostMessage method', () => { + it('should return 500 if server has not started', async () => { + const mockReq = createMockRequest(); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + + const error = 'SSE connection not established'; + await expect(transport.handlePostMessage(mockReq, mockRes)).rejects.toThrow(error); + expect(mockRes.writeHead).toHaveBeenCalledWith(500); + expect(mockRes.end).toHaveBeenCalledWith(error); }); - await transport.start(); - const mockReq = createMockRequest({ - headers: { - origin: 'http://localhost:3000', - 'content-type': 'application/json', - } + it('should return 400 if content-type is not application/json', async () => { + const mockReq = createMockRequest({ headers: { 'content-type': 'text/plain' } }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onerror = jest.fn(); + const error = 'Unsupported content-type: text/plain'; + await expect(transport.handlePostMessage(mockReq, mockRes)).resolves.toBe(undefined); + expect(mockRes.writeHead).toHaveBeenCalledWith(400); + expect(mockRes.end).toHaveBeenCalledWith(expect.stringContaining(error)); + expect(transport.onerror).toHaveBeenCalledWith(new Error(error)); }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); - expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); - }); - it('should reject requests with disallowed origin headers', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedOrigins: ['http://localhost:3000'], - enableDnsRebindingProtection: true, + it('should return 400 if message has not a valid schema', async () => { + const invalidMessage = JSON.stringify({ + // missing jsonrpc field + method: 'call', + params: [1, 2, 3], + id: 1 + }); + const mockReq = createMockRequest({ + headers: { 'content-type': 'application/json' }, + body: invalidMessage + }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onmessage = jest.fn(); + await transport.handlePostMessage(mockReq, mockRes); + expect(mockRes.writeHead).toHaveBeenCalledWith(400); + expect(transport.onmessage).not.toHaveBeenCalled(); + expect(mockRes.end).toHaveBeenCalledWith(`Invalid message: ${invalidMessage}`); }); - await transport.start(); - const mockReq = createMockRequest({ - headers: { - origin: 'http://evil.com', - 'content-type': 'application/json', - } + it('should return 202 if message has a valid schema', async () => { + const validMessage = JSON.stringify({ + jsonrpc: '2.0', + method: 'call', + params: { + a: 1, + b: 2, + c: 3 + }, + id: 1 + }); + const mockReq = createMockRequest({ + headers: { 'content-type': 'application/json' }, + body: validMessage + }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onmessage = jest.fn(); + await transport.handlePostMessage(mockReq, mockRes); + expect(mockRes.writeHead).toHaveBeenCalledWith(202); + expect(mockRes.end).toHaveBeenCalledWith('Accepted'); + expect(transport.onmessage).toHaveBeenCalledWith( + { + jsonrpc: '2.0', + method: 'call', + params: { + a: 1, + b: 2, + c: 3 + }, + id: 1 + }, + { + authInfo: { + token: 'test-token' + }, + requestInfo: { + headers: { + 'content-type': 'application/json' + } + } + } + ); }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); - expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); - }); }); - describe('Content-Type validation', () => { - it('should accept requests with application/json content-type', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes); - await transport.start(); - - const mockReq = createMockRequest({ - headers: { - 'content-type': 'application/json', - } - }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); - expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); - }); - - it('should accept requests with application/json with charset', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes); - await transport.start(); - - const mockReq = createMockRequest({ - headers: { - 'content-type': 'application/json; charset=utf-8', - } + describe('close method', () => { + it('should call onclose', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + transport.onclose = jest.fn(); + await transport.close(); + expect(transport.onclose).toHaveBeenCalled(); }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); - expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); - }); - - it('should reject requests with non-application/json content-type when protection is enabled', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes); - await transport.start(); - - const mockReq = createMockRequest({ - headers: { - 'content-type': 'text/plain', - } - }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); - expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); - }); }); - describe('enableDnsRebindingProtection option', () => { - it('should skip all validations when enableDnsRebindingProtection is false', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedHosts: ['localhost:3000'], - allowedOrigins: ['http://localhost:3000'], - enableDnsRebindingProtection: false, - }); - await transport.start(); - - const mockReq = createMockRequest({ - headers: { - host: 'evil.com', - origin: 'http://evil.com', - 'content-type': 'text/plain', - } + describe('send method', () => { + it('should call onsend', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(expect.stringContaining('event: endpoint')); + expect(mockRes.write).toHaveBeenCalledWith(expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`)); }); - const mockHandleRes = createMockResponse(); - - await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); - - // Should pass even with invalid headers because protection is disabled - expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); - // The error should be from content-type parsing, not DNS rebinding protection - expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); - }); }); - describe('Combined validations', () => { - it('should validate both host and origin when both are configured', async () => { - const mockRes = createMockResponse(); - const transport = new SSEServerTransport('/messages', mockRes, { - allowedHosts: ['localhost:3000'], - allowedOrigins: ['http://localhost:3000'], - enableDnsRebindingProtection: true, - }); - await transport.start(); - - // Valid host, invalid origin - const mockReq1 = createMockRequest({ - headers: { - host: 'localhost:3000', - origin: 'http://evil.com', - 'content-type': 'application/json', - } + describe('DNS rebinding protection', () => { + beforeEach(() => { + jest.clearAllMocks(); }); - const mockHandleRes1 = createMockResponse(); - - await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' }); - expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403); - expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); - - // Invalid host, valid origin - const mockReq2 = createMockRequest({ - headers: { - host: 'evil.com', - origin: 'http://localhost:3000', - 'content-type': 'application/json', - } + describe('Host header validation', () => { + it('should accept requests with allowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000', 'example.com'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'evil.com', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + }); + + it('should reject requests without host header when allowedHosts is configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined'); + }); }); - const mockHandleRes2 = createMockResponse(); - - await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' }); - expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403); - expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + describe('Origin header validation', () => { + it('should accept requests with allowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + origin: 'http://localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + origin: 'http://evil.com', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + }); + }); - // Both valid - const mockReq3 = createMockRequest({ - headers: { - host: 'localhost:3000', - origin: 'http://localhost:3000', - 'content-type': 'application/json', - } + describe('Content-Type validation', () => { + it('should accept requests with application/json content-type', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should accept requests with application/json with charset', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json; charset=utf-8' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with non-application/json content-type when protection is enabled', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'text/plain' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); }); - const mockHandleRes3 = createMockResponse(); - await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' }); + describe('enableDnsRebindingProtection option', () => { + it('should skip all validations when enableDnsRebindingProtection is false', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: false + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'evil.com', + origin: 'http://evil.com', + 'content-type': 'text/plain' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + // Should pass even with invalid headers because protection is disabled + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + // The error should be from content-type parsing, not DNS rebinding protection + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); + }); - expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202); - expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted'); - }); + describe('Combined validations', () => { + it('should validate both host and origin when both are configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + // Valid host, invalid origin + const mockReq1 = createMockRequest({ + headers: { + host: 'localhost:3000', + origin: 'http://evil.com', + 'content-type': 'application/json' + } + }); + const mockHandleRes1 = createMockResponse(); + + await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + + // Invalid host, valid origin + const mockReq2 = createMockRequest({ + headers: { + host: 'evil.com', + origin: 'http://localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes2 = createMockResponse(); + + await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + + // Both valid + const mockReq3 = createMockRequest({ + headers: { + host: 'localhost:3000', + origin: 'http://localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes3 = createMockResponse(); + + await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted'); + }); + }); }); - }); }); diff --git a/src/server/sse.ts b/src/server/sse.ts index e07256867..3405af705 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -1,35 +1,35 @@ -import { randomUUID } from "node:crypto"; -import { IncomingMessage, ServerResponse } from "node:http"; -import { Transport } from "../shared/transport.js"; -import { JSONRPCMessage, JSONRPCMessageSchema, MessageExtraInfo, RequestInfo } from "../types.js"; -import getRawBody from "raw-body"; -import contentType from "content-type"; -import { AuthInfo } from "./auth/types.js"; +import { randomUUID } from 'node:crypto'; +import { IncomingMessage, ServerResponse } from 'node:http'; +import { Transport } from '../shared/transport.js'; +import { JSONRPCMessage, JSONRPCMessageSchema, MessageExtraInfo, RequestInfo } from '../types.js'; +import getRawBody from 'raw-body'; +import contentType from 'content-type'; +import { AuthInfo } from './auth/types.js'; import { URL } from 'url'; -const MAXIMUM_MESSAGE_SIZE = "4mb"; +const MAXIMUM_MESSAGE_SIZE = '4mb'; /** * Configuration options for SSEServerTransport. */ export interface SSEServerTransportOptions { - /** - * List of allowed host header values for DNS rebinding protection. - * If not specified, host validation is disabled. - */ - allowedHosts?: string[]; - - /** - * List of allowed origin header values for DNS rebinding protection. - * If not specified, origin validation is disabled. - */ - allowedOrigins?: string[]; - - /** - * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). - * Default is false for backwards compatibility. - */ - enableDnsRebindingProtection?: boolean; + /** + * List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + */ + allowedHosts?: string[]; + + /** + * List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + */ + allowedOrigins?: string[]; + + /** + * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + */ + enableDnsRebindingProtection?: boolean; } /** @@ -38,184 +38,176 @@ export interface SSEServerTransportOptions { * This transport is only available in Node.js environments. */ export class SSEServerTransport implements Transport { - private _sseResponse?: ServerResponse; - private _sessionId: string; - private _options: SSEServerTransportOptions; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; - - /** - * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. - */ - constructor( - private _endpoint: string, - private res: ServerResponse, - options?: SSEServerTransportOptions, - ) { - this._sessionId = randomUUID(); - this._options = options || {enableDnsRebindingProtection: false}; - } - - /** - * Validates request headers for DNS rebinding protection. - * @returns Error message if validation fails, undefined if validation passes. - */ - private validateRequestHeaders(req: IncomingMessage): string | undefined { - // Skip validation if protection is not enabled - if (!this._options.enableDnsRebindingProtection) { - return undefined; + private _sseResponse?: ServerResponse; + private _sessionId: string; + private _options: SSEServerTransportOptions; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + + /** + * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. + */ + constructor( + private _endpoint: string, + private res: ServerResponse, + options?: SSEServerTransportOptions + ) { + this._sessionId = randomUUID(); + this._options = options || { enableDnsRebindingProtection: false }; } - // Validate Host header if allowedHosts is configured - if (this._options.allowedHosts && this._options.allowedHosts.length > 0) { - const hostHeader = req.headers.host; - if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) { - return `Invalid Host header: ${hostHeader}`; - } + /** + * Validates request headers for DNS rebinding protection. + * @returns Error message if validation fails, undefined if validation passes. + */ + private validateRequestHeaders(req: IncomingMessage): string | undefined { + // Skip validation if protection is not enabled + if (!this._options.enableDnsRebindingProtection) { + return undefined; + } + + // Validate Host header if allowedHosts is configured + if (this._options.allowedHosts && this._options.allowedHosts.length > 0) { + const hostHeader = req.headers.host; + if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) { + return `Invalid Host header: ${hostHeader}`; + } + } + + // Validate Origin header if allowedOrigins is configured + if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) { + const originHeader = req.headers.origin; + if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) { + return `Invalid Origin header: ${originHeader}`; + } + } + + return undefined; } - // Validate Origin header if allowedOrigins is configured - if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) { - const originHeader = req.headers.origin; - if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) { - return `Invalid Origin header: ${originHeader}`; - } + /** + * Handles the initial SSE connection request. + * + * This should be called when a GET request is made to establish the SSE stream. + */ + async start(): Promise { + if (this._sseResponse) { + throw new Error('SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.'); + } + + this.res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + + // Send the endpoint event + // Use a dummy base URL because this._endpoint is relative. + // This allows using URL/URLSearchParams for robust parameter handling. + const dummyBase = 'http://localhost'; // Any valid base works + const endpointUrl = new URL(this._endpoint, dummyBase); + endpointUrl.searchParams.set('sessionId', this._sessionId); + + // Reconstruct the relative URL string (pathname + search + hash) + const relativeUrlWithSession = endpointUrl.pathname + endpointUrl.search + endpointUrl.hash; + + this.res.write(`event: endpoint\ndata: ${relativeUrlWithSession}\n\n`); + + this._sseResponse = this.res; + this.res.on('close', () => { + this._sseResponse = undefined; + this.onclose?.(); + }); } - return undefined; - } - - /** - * Handles the initial SSE connection request. - * - * This should be called when a GET request is made to establish the SSE stream. - */ - async start(): Promise { - if (this._sseResponse) { - throw new Error( - "SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.", - ); + /** + * Handles incoming POST messages. + * + * This should be called when a POST request is made to send a message to the server. + */ + async handlePostMessage(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { + if (!this._sseResponse) { + const message = 'SSE connection not established'; + res.writeHead(500).end(message); + throw new Error(message); + } + + // Validate request headers for DNS rebinding protection + const validationError = this.validateRequestHeaders(req); + if (validationError) { + res.writeHead(403).end(validationError); + this.onerror?.(new Error(validationError)); + return; + } + + const authInfo: AuthInfo | undefined = req.auth; + const requestInfo: RequestInfo = { headers: req.headers }; + + let body: string | unknown; + try { + const ct = contentType.parse(req.headers['content-type'] ?? ''); + if (ct.type !== 'application/json') { + throw new Error(`Unsupported content-type: ${ct.type}`); + } + + body = + parsedBody ?? + (await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: ct.parameters.charset ?? 'utf-8' + })); + } catch (error) { + res.writeHead(400).end(String(error)); + this.onerror?.(error as Error); + return; + } + + try { + await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { requestInfo, authInfo }); + } catch { + res.writeHead(400).end(`Invalid message: ${body}`); + return; + } + + res.writeHead(202).end('Accepted'); } - this.res.writeHead(200, { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }); - - // Send the endpoint event - // Use a dummy base URL because this._endpoint is relative. - // This allows using URL/URLSearchParams for robust parameter handling. - const dummyBase = 'http://localhost'; // Any valid base works - const endpointUrl = new URL(this._endpoint, dummyBase); - endpointUrl.searchParams.set('sessionId', this._sessionId); - - // Reconstruct the relative URL string (pathname + search + hash) - const relativeUrlWithSession = endpointUrl.pathname + endpointUrl.search + endpointUrl.hash; - - this.res.write( - `event: endpoint\ndata: ${relativeUrlWithSession}\n\n`, - ); - - this._sseResponse = this.res; - this.res.on("close", () => { - this._sseResponse = undefined; - this.onclose?.(); - }); - } - - /** - * Handles incoming POST messages. - * - * This should be called when a POST request is made to send a message to the server. - */ - async handlePostMessage( - req: IncomingMessage & { auth?: AuthInfo }, - res: ServerResponse, - parsedBody?: unknown, - ): Promise { - if (!this._sseResponse) { - const message = "SSE connection not established"; - res.writeHead(500).end(message); - throw new Error(message); + /** + * Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST. + */ + async handleMessage(message: unknown, extra?: MessageExtraInfo): Promise { + let parsedMessage: JSONRPCMessage; + try { + parsedMessage = JSONRPCMessageSchema.parse(message); + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + + this.onmessage?.(parsedMessage, extra); } - // Validate request headers for DNS rebinding protection - const validationError = this.validateRequestHeaders(req); - if (validationError) { - res.writeHead(403).end(validationError); - this.onerror?.(new Error(validationError)); - return; + async close(): Promise { + this._sseResponse?.end(); + this._sseResponse = undefined; + this.onclose?.(); } - const authInfo: AuthInfo | undefined = req.auth; - const requestInfo: RequestInfo = { headers: req.headers }; - - let body: string | unknown; - try { - const ct = contentType.parse(req.headers["content-type"] ?? ""); - if (ct.type !== "application/json") { - throw new Error(`Unsupported content-type: ${ct.type}`); - } - - body = parsedBody ?? await getRawBody(req, { - limit: MAXIMUM_MESSAGE_SIZE, - encoding: ct.parameters.charset ?? "utf-8", - }); - } catch (error) { - res.writeHead(400).end(String(error)); - this.onerror?.(error as Error); - return; - } - - try { - await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { requestInfo, authInfo }); - } catch { - res.writeHead(400).end(`Invalid message: ${body}`); - return; - } + async send(message: JSONRPCMessage): Promise { + if (!this._sseResponse) { + throw new Error('Not connected'); + } - res.writeHead(202).end("Accepted"); - } - - /** - * Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST. - */ - async handleMessage(message: unknown, extra?: MessageExtraInfo): Promise { - let parsedMessage: JSONRPCMessage; - try { - parsedMessage = JSONRPCMessageSchema.parse(message); - } catch (error) { - this.onerror?.(error as Error); - throw error; + this._sseResponse.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`); } - this.onmessage?.(parsedMessage, extra); - } - - async close(): Promise { - this._sseResponse?.end(); - this._sseResponse = undefined; - this.onclose?.(); - } - - async send(message: JSONRPCMessage): Promise { - if (!this._sseResponse) { - throw new Error("Not connected"); + /** + * Returns the session ID for this transport. + * + * This can be used to route incoming POST requests. + */ + get sessionId(): string { + return this._sessionId; } - - this._sseResponse.write( - `event: message\ndata: ${JSON.stringify(message)}\n\n`, - ); - } - - /** - * Returns the session ID for this transport. - * - * This can be used to route incoming POST requests. - */ - get sessionId(): string { - return this._sessionId; - } } diff --git a/src/server/stdio.test.ts b/src/server/stdio.test.ts index 5243268d8..7d5d5c11b 100644 --- a/src/server/stdio.test.ts +++ b/src/server/stdio.test.ts @@ -1,102 +1,102 @@ -import { Readable, Writable } from "node:stream"; -import { ReadBuffer, serializeMessage } from "../shared/stdio.js"; -import { JSONRPCMessage } from "../types.js"; -import { StdioServerTransport } from "./stdio.js"; +import { Readable, Writable } from 'node:stream'; +import { ReadBuffer, serializeMessage } from '../shared/stdio.js'; +import { JSONRPCMessage } from '../types.js'; +import { StdioServerTransport } from './stdio.js'; let input: Readable; let outputBuffer: ReadBuffer; let output: Writable; beforeEach(() => { - input = new Readable({ - // We'll use input.push() instead. - read: () => {}, - }); + input = new Readable({ + // We'll use input.push() instead. + read: () => {} + }); - outputBuffer = new ReadBuffer(); - output = new Writable({ - write(chunk, encoding, callback) { - outputBuffer.append(chunk); - callback(); - }, - }); + outputBuffer = new ReadBuffer(); + output = new Writable({ + write(chunk, encoding, callback) { + outputBuffer.append(chunk); + callback(); + } + }); }); -test("should start then close cleanly", async () => { - const server = new StdioServerTransport(input, output); - server.onerror = (error) => { - throw error; - }; +test('should start then close cleanly', async () => { + const server = new StdioServerTransport(input, output); + server.onerror = error => { + throw error; + }; - let didClose = false; - server.onclose = () => { - didClose = true; - }; + let didClose = false; + server.onclose = () => { + didClose = true; + }; - await server.start(); - expect(didClose).toBeFalsy(); - await server.close(); - expect(didClose).toBeTruthy(); + await server.start(); + expect(didClose).toBeFalsy(); + await server.close(); + expect(didClose).toBeTruthy(); }); -test("should not read until started", async () => { - const server = new StdioServerTransport(input, output); - server.onerror = (error) => { - throw error; - }; - - let didRead = false; - const readMessage = new Promise((resolve) => { - server.onmessage = (message) => { - didRead = true; - resolve(message); +test('should not read until started', async () => { + const server = new StdioServerTransport(input, output); + server.onerror = error => { + throw error; }; - }); - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: 1, - method: "ping", - }; - input.push(serializeMessage(message)); + let didRead = false; + const readMessage = new Promise(resolve => { + server.onmessage = message => { + didRead = true; + resolve(message); + }; + }); - expect(didRead).toBeFalsy(); - await server.start(); - expect(await readMessage).toEqual(message); + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: 1, + method: 'ping' + }; + input.push(serializeMessage(message)); + + expect(didRead).toBeFalsy(); + await server.start(); + expect(await readMessage).toEqual(message); }); -test("should read multiple messages", async () => { - const server = new StdioServerTransport(input, output); - server.onerror = (error) => { - throw error; - }; +test('should read multiple messages', async () => { + const server = new StdioServerTransport(input, output); + server.onerror = error => { + throw error; + }; - const messages: JSONRPCMessage[] = [ - { - jsonrpc: "2.0", - id: 1, - method: "ping", - }, - { - jsonrpc: "2.0", - method: "notifications/initialized", - }, - ]; + const messages: JSONRPCMessage[] = [ + { + jsonrpc: '2.0', + id: 1, + method: 'ping' + }, + { + jsonrpc: '2.0', + method: 'notifications/initialized' + } + ]; - const readMessages: JSONRPCMessage[] = []; - const finished = new Promise((resolve) => { - server.onmessage = (message) => { - readMessages.push(message); - if (JSON.stringify(message) === JSON.stringify(messages[1])) { - resolve(); - } - }; - }); + const readMessages: JSONRPCMessage[] = []; + const finished = new Promise(resolve => { + server.onmessage = message => { + readMessages.push(message); + if (JSON.stringify(message) === JSON.stringify(messages[1])) { + resolve(); + } + }; + }); - input.push(serializeMessage(messages[0])); - input.push(serializeMessage(messages[1])); + input.push(serializeMessage(messages[0])); + input.push(serializeMessage(messages[1])); - await server.start(); - await finished; - expect(readMessages).toEqual(messages); + await server.start(); + await finished; + expect(readMessages).toEqual(messages); }); diff --git a/src/server/stdio.ts b/src/server/stdio.ts index 30c80012e..ece9d0ae4 100644 --- a/src/server/stdio.ts +++ b/src/server/stdio.ts @@ -1,8 +1,8 @@ -import process from "node:process"; -import { Readable, Writable } from "node:stream"; -import { ReadBuffer, serializeMessage } from "../shared/stdio.js"; -import { JSONRPCMessage } from "../types.js"; -import { Transport } from "../shared/transport.js"; +import process from 'node:process'; +import { Readable, Writable } from 'node:stream'; +import { ReadBuffer, serializeMessage } from '../shared/stdio.js'; +import { JSONRPCMessage } from '../types.js'; +import { Transport } from '../shared/transport.js'; /** * Server transport for stdio: this communicates with a MCP client by reading from the current process' stdin and writing to stdout. @@ -10,83 +10,83 @@ import { Transport } from "../shared/transport.js"; * This transport is only available in Node.js environments. */ export class StdioServerTransport implements Transport { - private _readBuffer: ReadBuffer = new ReadBuffer(); - private _started = false; + private _readBuffer: ReadBuffer = new ReadBuffer(); + private _started = false; - constructor( - private _stdin: Readable = process.stdin, - private _stdout: Writable = process.stdout, - ) {} + constructor( + private _stdin: Readable = process.stdin, + private _stdout: Writable = process.stdout + ) {} - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; - // Arrow functions to bind `this` properly, while maintaining function identity. - _ondata = (chunk: Buffer) => { - this._readBuffer.append(chunk); - this.processReadBuffer(); - }; - _onerror = (error: Error) => { - this.onerror?.(error); - }; + // Arrow functions to bind `this` properly, while maintaining function identity. + _ondata = (chunk: Buffer) => { + this._readBuffer.append(chunk); + this.processReadBuffer(); + }; + _onerror = (error: Error) => { + this.onerror?.(error); + }; - /** - * Starts listening for messages on stdin. - */ - async start(): Promise { - if (this._started) { - throw new Error( - "StdioServerTransport already started! If using Server class, note that connect() calls start() automatically.", - ); + /** + * Starts listening for messages on stdin. + */ + async start(): Promise { + if (this._started) { + throw new Error( + 'StdioServerTransport already started! If using Server class, note that connect() calls start() automatically.' + ); + } + + this._started = true; + this._stdin.on('data', this._ondata); + this._stdin.on('error', this._onerror); } - this._started = true; - this._stdin.on("data", this._ondata); - this._stdin.on("error", this._onerror); - } + private processReadBuffer() { + while (true) { + try { + const message = this._readBuffer.readMessage(); + if (message === null) { + break; + } - private processReadBuffer() { - while (true) { - try { - const message = this._readBuffer.readMessage(); - if (message === null) { - break; + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); + } } - - this.onmessage?.(message); - } catch (error) { - this.onerror?.(error as Error); - } } - } - async close(): Promise { - // Remove our event listeners first - this._stdin.off("data", this._ondata); - this._stdin.off("error", this._onerror); + async close(): Promise { + // Remove our event listeners first + this._stdin.off('data', this._ondata); + this._stdin.off('error', this._onerror); - // Check if we were the only data listener - const remainingDataListeners = this._stdin.listenerCount('data'); - if (remainingDataListeners === 0) { - // Only pause stdin if we were the only listener - // This prevents interfering with other parts of the application that might be using stdin - this._stdin.pause(); + // Check if we were the only data listener + const remainingDataListeners = this._stdin.listenerCount('data'); + if (remainingDataListeners === 0) { + // Only pause stdin if we were the only listener + // This prevents interfering with other parts of the application that might be using stdin + this._stdin.pause(); + } + + // Clear the buffer and notify closure + this._readBuffer.clear(); + this.onclose?.(); } - - // Clear the buffer and notify closure - this._readBuffer.clear(); - this.onclose?.(); - } - send(message: JSONRPCMessage): Promise { - return new Promise((resolve) => { - const json = serializeMessage(message); - if (this._stdout.write(json)) { - resolve(); - } else { - this._stdout.once("drain", resolve); - } - }); - } + send(message: JSONRPCMessage): Promise { + return new Promise(resolve => { + const json = serializeMessage(message); + if (this._stdout.write(json)) { + resolve(); + } else { + this._stdout.once('drain', resolve); + } + }); + } } diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 3a0a5c066..824d0f423 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1,185 +1,178 @@ -import { createServer, type Server, IncomingMessage, ServerResponse } from "node:http"; -import { createServer as netCreateServer, AddressInfo } from "node:net"; -import { randomUUID } from "node:crypto"; -import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from "./streamableHttp.js"; -import { McpServer } from "./mcp.js"; -import { CallToolResult, JSONRPCMessage } from "../types.js"; -import { z } from "zod"; -import { AuthInfo } from "./auth/types.js"; +import { createServer, type Server, IncomingMessage, ServerResponse } from 'node:http'; +import { createServer as netCreateServer, AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from './streamableHttp.js'; +import { McpServer } from './mcp.js'; +import { CallToolResult, JSONRPCMessage } from '../types.js'; +import { z } from 'zod'; +import { AuthInfo } from './auth/types.js'; async function getFreePort() { - return new Promise(res => { - const srv = netCreateServer(); - srv.listen(0, () => { - const address = srv.address()! - if (typeof address === "string") { - throw new Error("Unexpected address type: " + typeof address); - } - const port = (address as AddressInfo).port; - srv.close((_err) => res(port)) - }); - }) + return new Promise(res => { + const srv = netCreateServer(); + srv.listen(0, () => { + const address = srv.address()!; + if (typeof address === 'string') { + throw new Error('Unexpected address type: ' + typeof address); + } + const port = (address as AddressInfo).port; + srv.close(_err => res(port)); + }); + }); } /** * Test server configuration for StreamableHTTPServerTransport tests */ interface TestServerConfig { - sessionIdGenerator: (() => string) | undefined; - enableJsonResponse?: boolean; - customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise; - eventStore?: EventStore; - onsessioninitialized?: (sessionId: string) => void | Promise; - onsessionclosed?: (sessionId: string) => void | Promise; + sessionIdGenerator: (() => string) | undefined; + enableJsonResponse?: boolean; + customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise; + eventStore?: EventStore; + onsessioninitialized?: (sessionId: string) => void | Promise; + onsessionclosed?: (sessionId: string) => void | Promise; } /** * Helper to create and start test HTTP server with MCP setup */ -async function createTestServer(config: TestServerConfig = { sessionIdGenerator: (() => randomUUID()) }): Promise<{ - server: Server; - transport: StreamableHTTPServerTransport; - mcpServer: McpServer; - baseUrl: URL; +async function createTestServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; }> { - const mcpServer = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: { logging: {} } } - ); - - mcpServer.tool( - "greet", - "A simple greeting tool", - { name: z.string().describe("Name to greet") }, - async ({ name }): Promise => { - return { content: [{ type: "text", text: `Hello, ${name}!` }] }; - } - ); - - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: config.sessionIdGenerator, - enableJsonResponse: config.enableJsonResponse ?? false, - eventStore: config.eventStore, - onsessioninitialized: config.onsessioninitialized, - onsessionclosed: config.onsessionclosed - }); - - await mcpServer.connect(transport); - - const server = createServer(async (req, res) => { - try { - if (config.customRequestHandler) { - await config.customRequestHandler(req, res); - } else { - await transport.handleRequest(req, res); - } - } catch (error) { - console.error("Error handling request:", error); - if (!res.headersSent) res.writeHead(500).end(); - } - }); + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); - const baseUrl = await new Promise((resolve) => { - server.listen(0, "127.0.0.1", () => { - const addr = server.address() as AddressInfo; - resolve(new URL(`http://127.0.0.1:${addr.port}`)); + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { name: z.string().describe('Name to greet') }, + async ({ name }): Promise => { + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + } + ); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + enableJsonResponse: config.enableJsonResponse ?? false, + eventStore: config.eventStore, + onsessioninitialized: config.onsessioninitialized, + onsessionclosed: config.onsessionclosed }); - }); - return { server, transport, mcpServer, baseUrl }; + await mcpServer.connect(transport); + + const server = createServer(async (req, res) => { + try { + if (config.customRequestHandler) { + await config.customRequestHandler(req, res); + } else { + await transport.handleRequest(req, res); + } + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, transport, mcpServer, baseUrl }; } /** * Helper to create and start authenticated test HTTP server with MCP setup */ -async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: (() => randomUUID()) }): Promise<{ - server: Server; - transport: StreamableHTTPServerTransport; - mcpServer: McpServer; - baseUrl: URL; +async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; }> { - const mcpServer = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: { logging: {} } } - ); - - mcpServer.tool( - "profile", - "A user profile data tool", - { active: z.boolean().describe("Profile status") }, - async ({ active }, { authInfo }): Promise => { - return { content: [{ type: "text", text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] }; - } - ); - - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: config.sessionIdGenerator, - enableJsonResponse: config.enableJsonResponse ?? false, - eventStore: config.eventStore, - onsessioninitialized: config.onsessioninitialized, - onsessionclosed: config.onsessionclosed - }); - - await mcpServer.connect(transport); - - const server = createServer(async (req: IncomingMessage & { auth?: AuthInfo }, res) => { - try { - if (config.customRequestHandler) { - await config.customRequestHandler(req, res); - } else { - req.auth = { token: req.headers["authorization"]?.split(" ")[1] } as AuthInfo; - await transport.handleRequest(req, res); - } - } catch (error) { - console.error("Error handling request:", error); - if (!res.headersSent) res.writeHead(500).end(); - } - }); + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + mcpServer.tool( + 'profile', + 'A user profile data tool', + { active: z.boolean().describe('Profile status') }, + async ({ active }, { authInfo }): Promise => { + return { content: [{ type: 'text', text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] }; + } + ); - const baseUrl = await new Promise((resolve) => { - server.listen(0, "127.0.0.1", () => { - const addr = server.address() as AddressInfo; - resolve(new URL(`http://127.0.0.1:${addr.port}`)); + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + enableJsonResponse: config.enableJsonResponse ?? false, + eventStore: config.eventStore, + onsessioninitialized: config.onsessioninitialized, + onsessionclosed: config.onsessionclosed }); - }); - return { server, transport, mcpServer, baseUrl }; + await mcpServer.connect(transport); + + const server = createServer(async (req: IncomingMessage & { auth?: AuthInfo }, res) => { + try { + if (config.customRequestHandler) { + await config.customRequestHandler(req, res); + } else { + req.auth = { token: req.headers['authorization']?.split(' ')[1] } as AuthInfo; + await transport.handleRequest(req, res); + } + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, transport, mcpServer, baseUrl }; } /** * Helper to stop test server */ async function stopTestServer({ server, transport }: { server: Server; transport: StreamableHTTPServerTransport }): Promise { - // First close the transport to ensure all SSE streams are closed - await transport.close(); + // First close the transport to ensure all SSE streams are closed + await transport.close(); - // Close the server without waiting indefinitely - server.close(); + // Close the server without waiting indefinitely + server.close(); } /** * Common test messages */ const TEST_MESSAGES = { - initialize: { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client", version: "1.0" }, - protocolVersion: "2025-03-26", - capabilities: { - }, - }, - - id: "init-1", - } as JSONRPCMessage, - - toolsList: { - jsonrpc: "2.0", - method: "tools/list", - params: {}, - id: "tools-1", - } as JSONRPCMessage + initialize: { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26', + capabilities: {} + }, + + id: 'init-1' + } as JSONRPCMessage, + + toolsList: { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'tools-1' + } as JSONRPCMessage }; /** @@ -188,1961 +181,1968 @@ const TEST_MESSAGES = { * get the reader manually and read multiple times. */ async function readSSEEvent(response: Response): Promise { - const reader = response.body?.getReader(); - const { value } = await reader!.read(); - return new TextDecoder().decode(value); + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + return new TextDecoder().decode(value); } /** * Helper to send JSON-RPC request */ -async function sendPostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record): Promise { - const headers: Record = { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - ...extraHeaders - }; - - if (sessionId) { - headers["mcp-session-id"] = sessionId; - // After initialization, include the protocol version header - headers["mcp-protocol-version"] = "2025-03-26"; - } - - return fetch(baseUrl, { - method: "POST", - headers, - body: JSON.stringify(message), - }); -} - -function expectErrorResponse(data: unknown, expectedCode: number, expectedMessagePattern: RegExp): void { - expect(data).toMatchObject({ - jsonrpc: "2.0", - error: expect.objectContaining({ - code: expectedCode, - message: expect.stringMatching(expectedMessagePattern), - }), - }); -} - -describe("StreamableHTTPServerTransport", () => { - let server: Server; - let mcpServer: McpServer; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - let sessionId: string; - - beforeEach(async () => { - const result = await createTestServer(); - server = result.server; - transport = result.transport; - mcpServer = result.mcpServer; - baseUrl = result.baseUrl; - }); - - afterEach(async () => { - await stopTestServer({ server, transport }); - }); - - async function initializeServer(): Promise { - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - expect(response.status).toBe(200); - const newSessionId = response.headers.get("mcp-session-id"); - expect(newSessionId).toBeDefined(); - return newSessionId as string; - } - - it("should initialize server and generate session ID", async () => { - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - expect(response.status).toBe(200); - expect(response.headers.get("content-type")).toBe("text/event-stream"); - expect(response.headers.get("mcp-session-id")).toBeDefined(); - }); - - it("should reject second initialization request", async () => { - // First initialize - const sessionId = await initializeServer(); - expect(sessionId).toBeDefined(); - - // Try second initialize - const secondInitMessage = { - ...TEST_MESSAGES.initialize, - id: "second-init" +async function sendPostRequest( + baseUrl: URL, + message: JSONRPCMessage | JSONRPCMessage[], + sessionId?: string, + extraHeaders?: Record +): Promise { + const headers: Record = { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + ...extraHeaders }; - const response = await sendPostRequest(baseUrl, secondInitMessage); - - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32600, /Server already initialized/); - }); - - it("should reject batch initialize request", async () => { - const batchInitMessages: JSONRPCMessage[] = [ - TEST_MESSAGES.initialize, - { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client-2", version: "1.0" }, - protocolVersion: "2025-03-26", - }, - id: "init-2", - } - ]; - - const response = await sendPostRequest(baseUrl, batchInitMessages); - - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32600, /Only one initialization request is allowed/); - }); - - it("should handle post requests via sse response correctly", async () => { - sessionId = await initializeServer(); - - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); - - expect(response.status).toBe(200); - - // Read the SSE stream for the response - const text = await readSSEEvent(response); - - // Parse the SSE event - const eventLines = text.split("\n"); - const dataLine = eventLines.find(line => line.startsWith("data:")); - expect(dataLine).toBeDefined(); - - const eventData = JSON.parse(dataLine!.substring(5)); - expect(eventData).toMatchObject({ - jsonrpc: "2.0", - result: expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - name: "greet", - description: "A simple greeting tool", - }), - ]), - }), - id: "tools-1", - }); - }); - - it("should call a tool and return the result", async () => { - sessionId = await initializeServer(); - - const toolCallMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/call", - params: { - name: "greet", - arguments: { - name: "Test User", - }, - }, - id: "call-1", - }; - - const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); - expect(response.status).toBe(200); - - const text = await readSSEEvent(response); - const eventLines = text.split("\n"); - const dataLine = eventLines.find(line => line.startsWith("data:")); - expect(dataLine).toBeDefined(); - - const eventData = JSON.parse(dataLine!.substring(5)); - expect(eventData).toMatchObject({ - jsonrpc: "2.0", - result: { - content: [ - { - type: "text", - text: "Hello, Test User!", - }, - ], - }, - id: "call-1", - }); - }); - - /*** - * Test: Tool With Request Info - */ - it("should pass request info to tool callback", async () => { - sessionId = await initializeServer(); - - mcpServer.tool( - "test-request-info", - "A simple test tool with request info", - { name: z.string().describe("Name to greet") }, - async ({ name }, { requestInfo }): Promise => { - return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] }; - } - ); - - const toolCallMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/call", - params: { - name: "test-request-info", - arguments: { - name: "Test User", - }, - }, - id: "call-1", - }; - - const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); - expect(response.status).toBe(200); - - const text = await readSSEEvent(response); - const eventLines = text.split("\n"); - const dataLine = eventLines.find(line => line.startsWith("data:")); - expect(dataLine).toBeDefined(); - - const eventData = JSON.parse(dataLine!.substring(5)); - - expect(eventData).toMatchObject({ - jsonrpc: "2.0", - result: { - content: [ - { type: "text", text: "Hello, Test User!" }, - { type: "text", text: expect.any(String) } - ], - }, - id: "call-1", - }); - - const requestInfo = JSON.parse(eventData.result.content[1].text); - expect(requestInfo).toMatchObject({ - headers: { - 'content-type': 'application/json', - accept: 'application/json, text/event-stream', - connection: 'keep-alive', - 'mcp-session-id': sessionId, - 'accept-language': '*', - 'user-agent': expect.any(String), - 'accept-encoding': expect.any(String), - 'content-length': expect.any(String), - }, - }); - }); - - it("should reject requests without a valid session ID", async () => { - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); - - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request/); - expect(errorData.id).toBeNull(); - }); - - it("should reject invalid session ID", async () => { - // First initialize to be in valid state - await initializeServer(); - - // Now try with invalid session ID - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, "invalid-session-id"); - - expect(response.status).toBe(404); - const errorData = await response.json(); - expectErrorResponse(errorData, -32001, /Session not found/); - }); - - it("should establish standalone SSE stream and receive server-initiated messages", async () => { - // First initialize to get a session ID - sessionId = await initializeServer(); - - // Open a standalone SSE stream - const sseResponse = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(sseResponse.status).toBe(200); - expect(sseResponse.headers.get("content-type")).toBe("text/event-stream"); - - - // Send a notification (server-initiated message) that should appear on SSE stream - const notification: JSONRPCMessage = { - jsonrpc: "2.0", - method: "notifications/message", - params: { level: "info", data: "Test notification" }, - }; - - // Send the notification via transport - await transport.send(notification); - - // Read from the stream and verify we got the notification - const text = await readSSEEvent(sseResponse); - - const eventLines = text.split("\n"); - const dataLine = eventLines.find(line => line.startsWith("data:")); - expect(dataLine).toBeDefined(); + if (sessionId) { + headers['mcp-session-id'] = sessionId; + // After initialization, include the protocol version header + headers['mcp-protocol-version'] = '2025-03-26'; + } - const eventData = JSON.parse(dataLine!.substring(5)); - expect(eventData).toMatchObject({ - jsonrpc: "2.0", - method: "notifications/message", - params: { level: "info", data: "Test notification" }, + return fetch(baseUrl, { + method: 'POST', + headers, + body: JSON.stringify(message) }); - }); - - it("should not close GET SSE stream after sending multiple server notifications", async () => { - sessionId = await initializeServer(); +} - // Open a standalone SSE stream - const sseResponse = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, +function expectErrorResponse(data: unknown, expectedCode: number, expectedMessagePattern: RegExp): void { + expect(data).toMatchObject({ + jsonrpc: '2.0', + error: expect.objectContaining({ + code: expectedCode, + message: expect.stringMatching(expectedMessagePattern) + }) }); +} - expect(sseResponse.status).toBe(200); - const reader = sseResponse.body?.getReader(); - - // Send multiple notifications - const notification1: JSONRPCMessage = { - jsonrpc: "2.0", - method: "notifications/message", - params: { level: "info", data: "First notification" } - }; - - // Just send one and verify it comes through - then the stream should stay open - await transport.send(notification1); - - const { value, done } = await reader!.read(); - const text = new TextDecoder().decode(value); - expect(text).toContain("First notification"); - expect(done).toBe(false); // Stream should still be open - }); - - it("should reject second SSE stream for the same session", async () => { - sessionId = await initializeServer(); - - // Open first SSE stream - const firstStream = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(firstStream.status).toBe(200); - - // Try to open a second SSE stream with the same session ID - const secondStream = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, - }); - - // Should be rejected - expect(secondStream.status).toBe(409); // Conflict - const errorData = await secondStream.json(); - expectErrorResponse(errorData, -32000, /Only one SSE stream is allowed per session/); - }); - - it("should reject GET requests without Accept: text/event-stream header", async () => { - sessionId = await initializeServer(); - - // Try GET without proper Accept header - const response = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "application/json", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(response.status).toBe(406); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Client must accept text\/event-stream/); - }); - - it("should reject POST requests without proper Accept header", async () => { - sessionId = await initializeServer(); - - // Try POST without Accept: text/event-stream - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json", // Missing text/event-stream - "mcp-session-id": sessionId, - }, - body: JSON.stringify(TEST_MESSAGES.toolsList), - }); - - expect(response.status).toBe(406); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Client must accept both application\/json and text\/event-stream/); - }); - - it("should reject unsupported Content-Type", async () => { - sessionId = await initializeServer(); - - // Try POST with text/plain Content-Type - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "text/plain", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - }, - body: "This is plain text", - }); - - expect(response.status).toBe(415); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Content-Type must be application\/json/); - }); - - it("should handle JSON-RPC batch notification messages with 202 response", async () => { - sessionId = await initializeServer(); - - // Send batch of notifications (no IDs) - const batchNotifications: JSONRPCMessage[] = [ - { jsonrpc: "2.0", method: "someNotification1", params: {} }, - { jsonrpc: "2.0", method: "someNotification2", params: {} }, - ]; - const response = await sendPostRequest(baseUrl, batchNotifications, sessionId); - - expect(response.status).toBe(202); - }); - - it("should handle batch request messages with SSE stream for responses", async () => { - sessionId = await initializeServer(); - - // Send batch of requests - const batchRequests: JSONRPCMessage[] = [ - { jsonrpc: "2.0", method: "tools/list", params: {}, id: "req-1" }, - { jsonrpc: "2.0", method: "tools/call", params: { name: "greet", arguments: { name: "BatchUser" } }, id: "req-2" }, - ]; - const response = await sendPostRequest(baseUrl, batchRequests, sessionId); - - expect(response.status).toBe(200); - expect(response.headers.get("content-type")).toBe("text/event-stream"); - - const reader = response.body?.getReader(); - - // The responses may come in any order or together in one chunk - const { value } = await reader!.read(); - const text = new TextDecoder().decode(value); - - // Check that both responses were sent on the same stream - expect(text).toContain('"id":"req-1"'); - expect(text).toContain('"tools"'); // tools/list result - expect(text).toContain('"id":"req-2"'); - expect(text).toContain('Hello, BatchUser'); // tools/call result - }); - - it("should properly handle invalid JSON data", async () => { - sessionId = await initializeServer(); - - // Send invalid JSON - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - }, - body: "This is not valid JSON", - }); - - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32700, /Parse error/); - }); - - it("should return 400 error for invalid JSON-RPC messages", async () => { - sessionId = await initializeServer(); - - // Invalid JSON-RPC (missing required jsonrpc version) - const invalidMessage = { method: "tools/list", params: {}, id: 1 }; // missing jsonrpc version - const response = await sendPostRequest(baseUrl, invalidMessage as JSONRPCMessage, sessionId); - - expect(response.status).toBe(400); - const errorData = await response.json(); - expect(errorData).toMatchObject({ - jsonrpc: "2.0", - error: expect.anything(), - }); - }); - - it("should reject requests to uninitialized server", async () => { - // Create a new HTTP server and transport without initializing - const { server: uninitializedServer, transport: uninitializedTransport, baseUrl: uninitializedUrl } = await createTestServer(); - // Transport not used in test but needed for cleanup - - // No initialization, just send a request directly - const uninitializedMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/list", - params: {}, - id: "uninitialized-test", - }; - - // Send a request to uninitialized server - const response = await sendPostRequest(uninitializedUrl, uninitializedMessage, "any-session-id"); - - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Server not initialized/); - - // Cleanup - await stopTestServer({ server: uninitializedServer, transport: uninitializedTransport }); - }); - - it("should send response messages to the connection that sent the request", async () => { - sessionId = await initializeServer(); - - const message1: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/list", - params: {}, - id: "req-1" - }; - - const message2: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/call", - params: { - name: "greet", - arguments: { name: "Connection2" } - }, - id: "req-2" - }; - - // Make two concurrent fetch connections for different requests - const req1 = sendPostRequest(baseUrl, message1, sessionId); - const req2 = sendPostRequest(baseUrl, message2, sessionId); - - // Get both responses - const [response1, response2] = await Promise.all([req1, req2]); - const reader1 = response1.body?.getReader(); - const reader2 = response2.body?.getReader(); +describe('StreamableHTTPServerTransport', () => { + let server: Server; + let mcpServer: McpServer; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; - // Read responses from each stream (requires each receives its specific response) - const { value: value1 } = await reader1!.read(); - const text1 = new TextDecoder().decode(value1); - expect(text1).toContain('"id":"req-1"'); - expect(text1).toContain('"tools"'); // tools/list result - - const { value: value2 } = await reader2!.read(); - const text2 = new TextDecoder().decode(value2); - expect(text2).toContain('"id":"req-2"'); - expect(text2).toContain('Hello, Connection2'); // tools/call result - }); - - it("should keep stream open after sending server notifications", async () => { - sessionId = await initializeServer(); - - // Open a standalone SSE stream - const sseResponse = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, - }); - - // Send several server-initiated notifications - await transport.send({ - jsonrpc: "2.0", - method: "notifications/message", - params: { level: "info", data: "First notification" }, - }); - - await transport.send({ - jsonrpc: "2.0", - method: "notifications/message", - params: { level: "info", data: "Second notification" }, - }); - - // Stream should still be open - it should not close after sending notifications - expect(sseResponse.bodyUsed).toBe(false); - }); - - // The current implementation will close the entire transport for DELETE - // Creating a temporary transport/server where we don't care if it gets closed - it("should properly handle DELETE requests and close session", async () => { - // Setup a temporary server for this test - const tempResult = await createTestServer(); - const tempServer = tempResult.server; - const tempUrl = tempResult.baseUrl; - - // Initialize to get a session ID - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - - // Now DELETE the session - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": tempSessionId || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(200); - - // Clean up - don't wait indefinitely for server close - tempServer.close(); - }); - - it("should reject DELETE requests with invalid session ID", async () => { - // Initialize the server first to activate it - sessionId = await initializeServer(); - - // Try to delete with invalid session ID - const response = await fetch(baseUrl, { - method: "DELETE", - headers: { - "mcp-session-id": "invalid-session-id", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(response.status).toBe(404); - const errorData = await response.json(); - expectErrorResponse(errorData, -32001, /Session not found/); - }); - - describe("protocol version header validation", () => { - it("should accept requests with matching protocol version", async () => { - sessionId = await initializeServer(); - - // Send request with matching protocol version - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); - - expect(response.status).toBe(200); - }); - - it("should accept requests without protocol version header", async () => { - sessionId = await initializeServer(); - - // Send request without protocol version header - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - // No mcp-protocol-version header - }, - body: JSON.stringify(TEST_MESSAGES.toolsList), - }); + beforeEach(async () => { + const result = await createTestServer(); + server = result.server; + transport = result.transport; + mcpServer = result.mcpServer; + baseUrl = result.baseUrl; + }); - expect(response.status).toBe(200); + afterEach(async () => { + await stopTestServer({ server, transport }); }); - it("should reject requests with unsupported protocol version", async () => { - sessionId = await initializeServer(); + async function initializeServer(): Promise { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - // Send request with unsupported protocol version - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "1999-01-01", // Unsupported version - }, - body: JSON.stringify(TEST_MESSAGES.toolsList), - }); + expect(response.status).toBe(200); + const newSessionId = response.headers.get('mcp-session-id'); + expect(newSessionId).toBeDefined(); + return newSessionId as string; + } - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + it('should initialize server and generate session ID', async () => { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('text/event-stream'); + expect(response.headers.get('mcp-session-id')).toBeDefined(); + }); + + it('should reject second initialization request', async () => { + // First initialize + const sessionId = await initializeServer(); + expect(sessionId).toBeDefined(); + + // Try second initialize + const secondInitMessage = { + ...TEST_MESSAGES.initialize, + id: 'second-init' + }; + + const response = await sendPostRequest(baseUrl, secondInitMessage); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32600, /Server already initialized/); + }); + + it('should reject batch initialize request', async () => { + const batchInitMessages: JSONRPCMessage[] = [ + TEST_MESSAGES.initialize, + { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client-2', version: '1.0' }, + protocolVersion: '2025-03-26' + }, + id: 'init-2' + } + ]; + + const response = await sendPostRequest(baseUrl, batchInitMessages); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32600, /Only one initialization request is allowed/); + }); + + it('should handle post requests via sse response correctly', async () => { + sessionId = await initializeServer(); + + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); + + expect(response.status).toBe(200); + + // Read the SSE stream for the response + const text = await readSSEEvent(response); + + // Parse the SSE event + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: expect.objectContaining({ + tools: expect.arrayContaining([ + expect.objectContaining({ + name: 'greet', + description: 'A simple greeting tool' + }) + ]) + }), + id: 'tools-1' + }); + }); + + it('should call a tool and return the result', async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Test User' + } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { + type: 'text', + text: 'Hello, Test User!' + } + ] + }, + id: 'call-1' + }); + }); + + /*** + * Test: Tool With Request Info + */ + it('should pass request info to tool callback', async () => { + sessionId = await initializeServer(); + + mcpServer.tool( + 'test-request-info', + 'A simple test tool with request info', + { name: z.string().describe('Name to greet') }, + async ({ name }, { requestInfo }): Promise => { + return { + content: [ + { type: 'text', text: `Hello, ${name}!` }, + { type: 'text', text: `${JSON.stringify(requestInfo)}` } + ] + }; + } + ); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'test-request-info', + arguments: { + name: 'Test User' + } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { type: 'text', text: 'Hello, Test User!' }, + { type: 'text', text: expect.any(String) } + ] + }, + id: 'call-1' + }); + + const requestInfo = JSON.parse(eventData.result.content[1].text); + expect(requestInfo).toMatchObject({ + headers: { + 'content-type': 'application/json', + accept: 'application/json, text/event-stream', + connection: 'keep-alive', + 'mcp-session-id': sessionId, + 'accept-language': '*', + 'user-agent': expect.any(String), + 'accept-encoding': expect.any(String), + 'content-length': expect.any(String) + } + }); + }); + + it('should reject requests without a valid session ID', async () => { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request/); + expect(errorData.id).toBeNull(); + }); + + it('should reject invalid session ID', async () => { + // First initialize to be in valid state + await initializeServer(); + + // Now try with invalid session ID + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, 'invalid-session-id'); + + expect(response.status).toBe(404); + const errorData = await response.json(); + expectErrorResponse(errorData, -32001, /Session not found/); + }); + + it('should establish standalone SSE stream and receive server-initiated messages', async () => { + // First initialize to get a session ID + sessionId = await initializeServer(); + + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(sseResponse.status).toBe(200); + expect(sseResponse.headers.get('content-type')).toBe('text/event-stream'); + + // Send a notification (server-initiated message) that should appear on SSE stream + const notification: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification' } + }; + + // Send the notification via transport + await transport.send(notification); + + // Read from the stream and verify we got the notification + const text = await readSSEEvent(sseResponse); + + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification' } + }); + }); + + it('should not close GET SSE stream after sending multiple server notifications', async () => { + sessionId = await initializeServer(); + + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(sseResponse.status).toBe(200); + const reader = sseResponse.body?.getReader(); + + // Send multiple notifications + const notification1: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'First notification' } + }; + + // Just send one and verify it comes through - then the stream should stay open + await transport.send(notification1); + + const { value, done } = await reader!.read(); + const text = new TextDecoder().decode(value); + expect(text).toContain('First notification'); + expect(done).toBe(false); // Stream should still be open + }); + + it('should reject second SSE stream for the same session', async () => { + sessionId = await initializeServer(); + + // Open first SSE stream + const firstStream = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(firstStream.status).toBe(200); + + // Try to open a second SSE stream with the same session ID + const secondStream = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + // Should be rejected + expect(secondStream.status).toBe(409); // Conflict + const errorData = await secondStream.json(); + expectErrorResponse(errorData, -32000, /Only one SSE stream is allowed per session/); + }); + + it('should reject GET requests without Accept: text/event-stream header', async () => { + sessionId = await initializeServer(); + + // Try GET without proper Accept header + const response = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'application/json', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(response.status).toBe(406); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Client must accept text\/event-stream/); + }); + + it('should reject POST requests without proper Accept header', async () => { + sessionId = await initializeServer(); + + // Try POST without Accept: text/event-stream + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json', // Missing text/event-stream + 'mcp-session-id': sessionId + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + expect(response.status).toBe(406); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Client must accept both application\/json and text\/event-stream/); + }); + + it('should reject unsupported Content-Type', async () => { + sessionId = await initializeServer(); + + // Try POST with text/plain Content-Type + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'text/plain', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: 'This is plain text' + }); + + expect(response.status).toBe(415); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Content-Type must be application\/json/); + }); + + it('should handle JSON-RPC batch notification messages with 202 response', async () => { + sessionId = await initializeServer(); + + // Send batch of notifications (no IDs) + const batchNotifications: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'someNotification1', params: {} }, + { jsonrpc: '2.0', method: 'someNotification2', params: {} } + ]; + const response = await sendPostRequest(baseUrl, batchNotifications, sessionId); + + expect(response.status).toBe(202); + }); + + it('should handle batch request messages with SSE stream for responses', async () => { + sessionId = await initializeServer(); + + // Send batch of requests + const batchRequests: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'tools/list', params: {}, id: 'req-1' }, + { jsonrpc: '2.0', method: 'tools/call', params: { name: 'greet', arguments: { name: 'BatchUser' } }, id: 'req-2' } + ]; + const response = await sendPostRequest(baseUrl, batchRequests, sessionId); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('text/event-stream'); + + const reader = response.body?.getReader(); + + // The responses may come in any order or together in one chunk + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Check that both responses were sent on the same stream + expect(text).toContain('"id":"req-1"'); + expect(text).toContain('"tools"'); // tools/list result + expect(text).toContain('"id":"req-2"'); + expect(text).toContain('Hello, BatchUser'); // tools/call result + }); + + it('should properly handle invalid JSON data', async () => { + sessionId = await initializeServer(); + + // Send invalid JSON + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: 'This is not valid JSON' + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32700, /Parse error/); + }); + + it('should return 400 error for invalid JSON-RPC messages', async () => { + sessionId = await initializeServer(); + + // Invalid JSON-RPC (missing required jsonrpc version) + const invalidMessage = { method: 'tools/list', params: {}, id: 1 }; // missing jsonrpc version + const response = await sendPostRequest(baseUrl, invalidMessage as JSONRPCMessage, sessionId); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expect(errorData).toMatchObject({ + jsonrpc: '2.0', + error: expect.anything() + }); + }); + + it('should reject requests to uninitialized server', async () => { + // Create a new HTTP server and transport without initializing + const { server: uninitializedServer, transport: uninitializedTransport, baseUrl: uninitializedUrl } = await createTestServer(); + // Transport not used in test but needed for cleanup + + // No initialization, just send a request directly + const uninitializedMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'uninitialized-test' + }; + + // Send a request to uninitialized server + const response = await sendPostRequest(uninitializedUrl, uninitializedMessage, 'any-session-id'); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Server not initialized/); + + // Cleanup + await stopTestServer({ server: uninitializedServer, transport: uninitializedTransport }); + }); + + it('should send response messages to the connection that sent the request', async () => { + sessionId = await initializeServer(); + + const message1: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'req-1' + }; + + const message2: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Connection2' } + }, + id: 'req-2' + }; + + // Make two concurrent fetch connections for different requests + const req1 = sendPostRequest(baseUrl, message1, sessionId); + const req2 = sendPostRequest(baseUrl, message2, sessionId); + + // Get both responses + const [response1, response2] = await Promise.all([req1, req2]); + const reader1 = response1.body?.getReader(); + const reader2 = response2.body?.getReader(); + + // Read responses from each stream (requires each receives its specific response) + const { value: value1 } = await reader1!.read(); + const text1 = new TextDecoder().decode(value1); + expect(text1).toContain('"id":"req-1"'); + expect(text1).toContain('"tools"'); // tools/list result + + const { value: value2 } = await reader2!.read(); + const text2 = new TextDecoder().decode(value2); + expect(text2).toContain('"id":"req-2"'); + expect(text2).toContain('Hello, Connection2'); // tools/call result + }); + + it('should keep stream open after sending server notifications', async () => { + sessionId = await initializeServer(); + + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + // Send several server-initiated notifications + await transport.send({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'First notification' } + }); + + await transport.send({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Second notification' } + }); + + // Stream should still be open - it should not close after sending notifications + expect(sseResponse.bodyUsed).toBe(false); + }); + + // The current implementation will close the entire transport for DELETE + // Creating a temporary transport/server where we don't care if it gets closed + it('should properly handle DELETE requests and close session', async () => { + // Setup a temporary server for this test + const tempResult = await createTestServer(); + const tempServer = tempResult.server; + const tempUrl = tempResult.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // Now DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Clean up - don't wait indefinitely for server close + tempServer.close(); + }); + + it('should reject DELETE requests with invalid session ID', async () => { + // Initialize the server first to activate it + sessionId = await initializeServer(); + + // Try to delete with invalid session ID + const response = await fetch(baseUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': 'invalid-session-id', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(response.status).toBe(404); + const errorData = await response.json(); + expectErrorResponse(errorData, -32001, /Session not found/); + }); + + describe('protocol version header validation', () => { + it('should accept requests with matching protocol version', async () => { + sessionId = await initializeServer(); + + // Send request with matching protocol version + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); + + expect(response.status).toBe(200); + }); + + it('should accept requests without protocol version header', async () => { + sessionId = await initializeServer(); + + // Send request without protocol version header + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + // No mcp-protocol-version header + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + expect(response.status).toBe(200); + }); + + it('should reject requests with unsupported protocol version', async () => { + sessionId = await initializeServer(); + + // Send request with unsupported protocol version + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '1999-01-01' // Unsupported version + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + + it('should accept when protocol version differs from negotiated version', async () => { + sessionId = await initializeServer(); + + // Spy on console.warn to verify warning is logged + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(); + + // Send request with different but supported protocol version + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2024-11-05' // Different but supported version + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + // Request should still succeed + expect(response.status).toBe(200); + + warnSpy.mockRestore(); + }); + + it('should handle protocol version validation for GET requests', async () => { + sessionId = await initializeServer(); + + // GET request with unsupported protocol version + const response = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': 'invalid-version' + } + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + + it('should handle protocol version validation for DELETE requests', async () => { + sessionId = await initializeServer(); + + // DELETE request with unsupported protocol version + const response = await fetch(baseUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': sessionId, + 'mcp-protocol-version': 'invalid-version' + } + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); }); +}); - it("should accept when protocol version differs from negotiated version", async () => { - sessionId = await initializeServer(); - - // Spy on console.warn to verify warning is logged - const warnSpy = jest.spyOn(console, 'warn').mockImplementation(); +describe('StreamableHTTPServerTransport with AuthInfo', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; - // Send request with different but supported protocol version - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2024-11-05", // Different but supported version - }, - body: JSON.stringify(TEST_MESSAGES.toolsList), - }); - - // Request should still succeed - expect(response.status).toBe(200); + beforeEach(async () => { + const result = await createTestAuthServer(); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + }); - warnSpy.mockRestore(); + afterEach(async () => { + await stopTestServer({ server, transport }); }); - it("should handle protocol version validation for GET requests", async () => { - sessionId = await initializeServer(); + async function initializeServer(): Promise { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - // GET request with unsupported protocol version - const response = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "invalid-version", - }, - }); + expect(response.status).toBe(200); + const newSessionId = response.headers.get('mcp-session-id'); + expect(newSessionId).toBeDefined(); + return newSessionId as string; + } - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + it('should call a tool with authInfo', async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'profile', + arguments: { active: true } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, { authorization: 'Bearer test-token' }); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { + type: 'text', + text: 'Active profile from token: test-token!' + } + ] + }, + id: 'call-1' + }); + }); + + it('should calls tool without authInfo when it is optional', async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'profile', + arguments: { active: false } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { + type: 'text', + text: 'Inactive profile from token: undefined!' + } + ] + }, + id: 'call-1' + }); }); +}); - it("should handle protocol version validation for DELETE requests", async () => { - sessionId = await initializeServer(); - - // DELETE request with unsupported protocol version - const response = await fetch(baseUrl, { - method: "DELETE", - headers: { - "mcp-session-id": sessionId, - "mcp-protocol-version": "invalid-version", - }, - }); +// Test JSON Response Mode +describe('StreamableHTTPServerTransport with JSON Response Mode', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + + beforeEach(async () => { + const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), enableJsonResponse: true }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Initialize and get session ID + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + sessionId = initResponse.headers.get('mcp-session-id') as string; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + it('should return JSON response for a single request', async () => { + const toolsListMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'json-req-1' + }; + + const response = await sendPostRequest(baseUrl, toolsListMessage, sessionId); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('application/json'); + + const result = await response.json(); + expect(result).toMatchObject({ + jsonrpc: '2.0', + result: expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'greet' })]) + }), + id: 'json-req-1' + }); + }); + + it('should return JSON response for batch requests', async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'tools/list', params: {}, id: 'batch-1' }, + { jsonrpc: '2.0', method: 'tools/call', params: { name: 'greet', arguments: { name: 'JSON' } }, id: 'batch-2' } + ]; + + const response = await sendPostRequest(baseUrl, batchMessages, sessionId); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('application/json'); + + const results = await response.json(); + expect(Array.isArray(results)).toBe(true); + expect(results).toHaveLength(2); + + // Batch responses can come in any order + const listResponse = results.find((r: { id?: string }) => r.id === 'batch-1'); + const callResponse = results.find((r: { id?: string }) => r.id === 'batch-2'); + + expect(listResponse).toEqual( + expect.objectContaining({ + jsonrpc: '2.0', + id: 'batch-1', + result: expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'greet' })]) + }) + }) + ); + + expect(callResponse).toEqual( + expect.objectContaining({ + jsonrpc: '2.0', + id: 'batch-2', + result: expect.objectContaining({ + content: expect.arrayContaining([expect.objectContaining({ type: 'text', text: 'Hello, JSON!' })]) + }) + }) + ); + }); +}); - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); +// Test pre-parsed body handling +describe('StreamableHTTPServerTransport with pre-parsed body', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + let parsedBody: unknown = null; + + beforeEach(async () => { + const result = await createTestServer({ + customRequestHandler: async (req, res) => { + try { + if (parsedBody !== null) { + await transport.handleRequest(req, res, parsedBody); + parsedBody = null; // Reset after use + } else { + await transport.handleRequest(req, res); + } + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }, + sessionIdGenerator: () => randomUUID() + }); + + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Initialize and get session ID + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + sessionId = initResponse.headers.get('mcp-session-id') as string; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + it('should accept pre-parsed request body', async () => { + // Set up the pre-parsed body + parsedBody = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'preparsed-1' + }; + + // Send an empty body since we'll use pre-parsed body + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + // Empty body - we're testing pre-parsed body + body: '' + }); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('text/event-stream'); + + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Verify the response used the pre-parsed body + expect(text).toContain('"id":"preparsed-1"'); + expect(text).toContain('"tools"'); + }); + + it('should handle pre-parsed batch messages', async () => { + parsedBody = [ + { jsonrpc: '2.0', method: 'tools/list', params: {}, id: 'batch-1' }, + { jsonrpc: '2.0', method: 'tools/call', params: { name: 'greet', arguments: { name: 'PreParsed' } }, id: 'batch-2' } + ]; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: '' // Empty as we're using pre-parsed + }); + + expect(response.status).toBe(200); + + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + expect(text).toContain('"id":"batch-1"'); + expect(text).toContain('"tools"'); + }); + + it('should prefer pre-parsed body over request body', async () => { + // Set pre-parsed to tools/list + parsedBody = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'preparsed-wins' + }; + + // Send actual body with tools/call - should be ignored + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'greet', arguments: { name: 'Ignored' } }, + id: 'ignored-id' + }) + }); + + expect(response.status).toBe(200); + + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Should have processed the pre-parsed body + expect(text).toContain('"id":"preparsed-wins"'); + expect(text).toContain('"tools"'); + expect(text).not.toContain('"ignored-id"'); }); - }); }); -describe("StreamableHTTPServerTransport with AuthInfo", () => { - let server: Server; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - let sessionId: string; - - beforeEach(async () => { - const result = await createTestAuthServer(); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - }); - - afterEach(async () => { - await stopTestServer({ server, transport }); - }); - - async function initializeServer(): Promise { - const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - expect(response.status).toBe(200); - const newSessionId = response.headers.get("mcp-session-id"); - expect(newSessionId).toBeDefined(); - return newSessionId as string; - } - - it("should call a tool with authInfo", async () => { - sessionId = await initializeServer(); - - const toolCallMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/call", - params: { - name: "profile", - arguments: { active: true }, - }, - id: "call-1", - }; +// Test resumability support +describe('StreamableHTTPServerTransport with resumability', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + let mcpServer: McpServer; + const storedEvents: Map = new Map(); + + // Simple implementation of EventStore + const eventStore: EventStore = { + async storeEvent(streamId: string, message: JSONRPCMessage): Promise { + const eventId = `${streamId}_${randomUUID()}`; + storedEvents.set(eventId, { eventId, message }); + return eventId; + }, - const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, { 'authorization': 'Bearer test-token' }); - expect(response.status).toBe(200); - - const text = await readSSEEvent(response); - const eventLines = text.split("\n"); - const dataLine = eventLines.find(line => line.startsWith("data:")); - expect(dataLine).toBeDefined(); - - const eventData = JSON.parse(dataLine!.substring(5)); - expect(eventData).toMatchObject({ - jsonrpc: "2.0", - result: { - content: [ - { - type: "text", - text: "Active profile from token: test-token!", - }, - ], - }, - id: "call-1", - }); - }); - - it("should calls tool without authInfo when it is optional", async () => { - sessionId = await initializeServer(); - - const toolCallMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/call", - params: { - name: "profile", - arguments: { active: false }, - }, - id: "call-1", + async replayEventsAfter( + lastEventId: EventId, + { + send + }: { + send: (eventId: EventId, message: JSONRPCMessage) => Promise; + } + ): Promise { + const streamId = lastEventId.split('_')[0]; + // Extract stream ID from the event ID + // For test simplicity, just return all events with matching streamId that aren't the lastEventId + for (const [eventId, { message }] of storedEvents.entries()) { + if (eventId.startsWith(streamId) && eventId !== lastEventId) { + await send(eventId, message); + } + } + return streamId; + } }; - const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); - expect(response.status).toBe(200); - - const text = await readSSEEvent(response); - const eventLines = text.split("\n"); - const dataLine = eventLines.find(line => line.startsWith("data:")); - expect(dataLine).toBeDefined(); - - const eventData = JSON.parse(dataLine!.substring(5)); - expect(eventData).toMatchObject({ - jsonrpc: "2.0", - result: { - content: [ - { - type: "text", - text: "Inactive profile from token: undefined!", - }, - ], - }, - id: "call-1", - }); - }); + beforeEach(async () => { + storedEvents.clear(); + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + eventStore + }); + + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + mcpServer = result.mcpServer; + + // Verify resumability is enabled on the transport + expect(transport['_eventStore']).toBeDefined(); + + // Initialize the server + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + sessionId = initResponse.headers.get('mcp-session-id') as string; + expect(sessionId).toBeDefined(); + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + storedEvents.clear(); + }); + + it('should store and include event IDs in server SSE messages', async () => { + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(sseResponse.status).toBe(200); + expect(sseResponse.headers.get('content-type')).toBe('text/event-stream'); + + // Send a notification that should be stored with an event ID + const notification: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification with event ID' } + }; + + // Send the notification via transport + await transport.send(notification); + + // Read from the stream and verify we got the notification with an event ID + const reader = sseResponse.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // The response should contain an event ID + expect(text).toContain('id: '); + expect(text).toContain('"method":"notifications/message"'); + + // Extract the event ID + const idMatch = text.match(/id: ([^\n]+)/); + expect(idMatch).toBeTruthy(); + + // Verify the event was stored + const eventId = idMatch![1]; + expect(storedEvents.has(eventId)).toBe(true); + const storedEvent = storedEvents.get(eventId); + expect(eventId.startsWith('_GET_stream')).toBe(true); + expect(storedEvent?.message).toMatchObject(notification); + }); + + it('should store and replay MCP server tool notifications', async () => { + // Establish a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId + } + }); + expect(sseResponse.status).toBe(200); // Send a server notification through the MCP server + await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'First notification from MCP server' }); + + // Read the notification from the SSE stream + const reader = sseResponse.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Verify the notification was sent with an event ID + expect(text).toContain('id: '); + expect(text).toContain('First notification from MCP server'); + + // Extract the event ID + const idMatch = text.match(/id: ([^\n]+)/); + expect(idMatch).toBeTruthy(); + const firstEventId = idMatch![1]; + + // Send a second notification + await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Second notification from MCP server' }); + + // Close the first SSE stream to simulate a disconnect + await reader!.cancel(); + + // Reconnect with the Last-Event-ID to get missed messages + const reconnectResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26', + 'last-event-id': firstEventId + } + }); + + expect(reconnectResponse.status).toBe(200); + + // Read the replayed notification + const reconnectReader = reconnectResponse.body?.getReader(); + const reconnectData = await reconnectReader!.read(); + const reconnectText = new TextDecoder().decode(reconnectData.value); + + // Verify we received the second notification that was sent after our stored eventId + expect(reconnectText).toContain('Second notification from MCP server'); + expect(reconnectText).toContain('id: '); + }); }); -// Test JSON Response Mode -describe("StreamableHTTPServerTransport with JSON Response Mode", () => { - let server: Server; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - let sessionId: string; - - beforeEach(async () => { - const result = await createTestServer({ sessionIdGenerator: (() => randomUUID()), enableJsonResponse: true }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - // Initialize and get session ID - const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - sessionId = initResponse.headers.get("mcp-session-id") as string; - }); - - afterEach(async () => { - await stopTestServer({ server, transport }); - }); - - it("should return JSON response for a single request", async () => { - const toolsListMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "tools/list", - params: {}, - id: "json-req-1", - }; - - const response = await sendPostRequest(baseUrl, toolsListMessage, sessionId); - - expect(response.status).toBe(200); - expect(response.headers.get("content-type")).toBe("application/json"); - - const result = await response.json(); - expect(result).toMatchObject({ - jsonrpc: "2.0", - result: expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ name: "greet" }) - ]) - }), - id: "json-req-1" - }); - }); - - it("should return JSON response for batch requests", async () => { - const batchMessages: JSONRPCMessage[] = [ - { jsonrpc: "2.0", method: "tools/list", params: {}, id: "batch-1" }, - { jsonrpc: "2.0", method: "tools/call", params: { name: "greet", arguments: { name: "JSON" } }, id: "batch-2" } - ]; - - const response = await sendPostRequest(baseUrl, batchMessages, sessionId); - - expect(response.status).toBe(200); - expect(response.headers.get("content-type")).toBe("application/json"); - - const results = await response.json(); - expect(Array.isArray(results)).toBe(true); - expect(results).toHaveLength(2); - - // Batch responses can come in any order - const listResponse = results.find((r: { id?: string }) => r.id === "batch-1"); - const callResponse = results.find((r: { id?: string }) => r.id === "batch-2"); - - expect(listResponse).toEqual(expect.objectContaining({ - jsonrpc: "2.0", - id: "batch-1", - result: expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ name: "greet" }) - ]) - }) - })); - - expect(callResponse).toEqual(expect.objectContaining({ - jsonrpc: "2.0", - id: "batch-2", - result: expect.objectContaining({ - content: expect.arrayContaining([ - expect.objectContaining({ type: "text", text: "Hello, JSON!" }) - ]) - }) - })); - }); +// Test stateless mode +describe('StreamableHTTPServerTransport in stateless mode', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const result = await createTestServer({ sessionIdGenerator: undefined }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + it('should operate without session ID validation', async () => { + // Initialize the server first + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(initResponse.status).toBe(200); + // Should NOT have session ID header in stateless mode + expect(initResponse.headers.get('mcp-session-id')).toBeNull(); + + // Try request without session ID - should work in stateless mode + const toolsResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); + + expect(toolsResponse.status).toBe(200); + }); + + it('should handle POST requests with various session IDs in stateless mode', async () => { + await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + // Try with a random session ID - should be accepted + const response1 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': 'random-id-1' + }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'tools/list', params: {}, id: 't1' }) + }); + expect(response1.status).toBe(200); + + // Try with another random session ID - should also be accepted + const response2 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': 'different-id-2' + }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'tools/list', params: {}, id: 't2' }) + }); + expect(response2.status).toBe(200); + }); + + it('should reject second SSE stream even in stateless mode', async () => { + // Despite no session ID requirement, the transport still only allows + // one standalone SSE stream at a time + + // Initialize the server first + await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + // Open first SSE stream + const stream1 = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-protocol-version': '2025-03-26' + } + }); + expect(stream1.status).toBe(200); + + // Open second SSE stream - should still be rejected, stateless mode still only allows one + const stream2 = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-protocol-version': '2025-03-26' + } + }); + expect(stream2.status).toBe(409); // Conflict - only one stream allowed + }); }); -// Test pre-parsed body handling -describe("StreamableHTTPServerTransport with pre-parsed body", () => { - let server: Server; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - let sessionId: string; - let parsedBody: unknown = null; - - beforeEach(async () => { - const result = await createTestServer({ - customRequestHandler: async (req, res) => { - try { - if (parsedBody !== null) { - await transport.handleRequest(req, res, parsedBody); - parsedBody = null; // Reset after use - } else { - await transport.handleRequest(req, res); - } - } catch (error) { - console.error("Error handling request:", error); - if (!res.headersSent) res.writeHead(500).end(); - } - }, - sessionIdGenerator: (() => randomUUID()) - }); - - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - // Initialize and get session ID - const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - sessionId = initResponse.headers.get("mcp-session-id") as string; - }); - - afterEach(async () => { - await stopTestServer({ server, transport }); - }); - - it("should accept pre-parsed request body", async () => { - // Set up the pre-parsed body - parsedBody = { - jsonrpc: "2.0", - method: "tools/list", - params: {}, - id: "preparsed-1", - }; +// Test onsessionclosed callback +describe('StreamableHTTPServerTransport onsessionclosed callback', () => { + it('should call onsessionclosed callback when session is closed via DELETE', async () => { + const mockCallback = jest.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(tempSessionId); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // Clean up + tempServer.close(); + }); + + it('should not call onsessionclosed callback when not provided', async () => { + // Create server without onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID() + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // DELETE the session - should not throw error + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Clean up + tempServer.close(); + }); + + it('should not call onsessionclosed callback for invalid session DELETE', async () => { + const mockCallback = jest.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a valid session + await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + + // Try to DELETE with invalid session ID + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': 'invalid-session-id', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(404); + expect(mockCallback).not.toHaveBeenCalled(); + + // Clean up + tempServer.close(); + }); + + it('should call onsessionclosed callback with correct session ID when multiple sessions exist', async () => { + const mockCallback = jest.fn(); + + // Create first server + const result1 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const server1 = result1.server; + const url1 = result1.baseUrl; + + // Create second server + const result2 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const server2 = result2.server; + const url2 = result2.baseUrl; + + // Initialize both servers + const initResponse1 = await sendPostRequest(url1, TEST_MESSAGES.initialize); + const sessionId1 = initResponse1.headers.get('mcp-session-id'); + + const initResponse2 = await sendPostRequest(url2, TEST_MESSAGES.initialize); + const sessionId2 = initResponse2.headers.get('mcp-session-id'); + + expect(sessionId1).toBeDefined(); + expect(sessionId2).toBeDefined(); + expect(sessionId1).not.toBe(sessionId2); - // Send an empty body since we'll use pre-parsed body - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - }, - // Empty body - we're testing pre-parsed body - body: "" + // DELETE first session + const deleteResponse1 = await fetch(url1, { + method: 'DELETE', + headers: { + 'mcp-session-id': sessionId1 || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse1.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId1); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // DELETE second session + const deleteResponse2 = await fetch(url2, { + method: 'DELETE', + headers: { + 'mcp-session-id': sessionId2 || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse2.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId2); + expect(mockCallback).toHaveBeenCalledTimes(2); + + // Clean up + server1.close(); + server2.close(); }); +}); + +// Test async callbacks for onsessioninitialized and onsessionclosed +describe('StreamableHTTPServerTransport async callbacks', () => { + it('should support async onsessioninitialized callback', async () => { + const initializationOrder: string[] = []; - expect(response.status).toBe(200); - expect(response.headers.get("content-type")).toBe("text/event-stream"); + // Create server with async onsessioninitialized callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (sessionId: string) => { + initializationOrder.push('async-start'); + // Simulate async operation + await new Promise(resolve => setTimeout(resolve, 10)); + initializationOrder.push('async-end'); + initializationOrder.push(sessionId); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; - const reader = response.body?.getReader(); - const { value } = await reader!.read(); - const text = new TextDecoder().decode(value); + // Initialize to trigger the callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); - // Verify the response used the pre-parsed body - expect(text).toContain('"id":"preparsed-1"'); - expect(text).toContain('"tools"'); - }); + // Give time for async callback to complete + await new Promise(resolve => setTimeout(resolve, 50)); - it("should handle pre-parsed batch messages", async () => { - parsedBody = [ - { jsonrpc: "2.0", method: "tools/list", params: {}, id: "batch-1" }, - { jsonrpc: "2.0", method: "tools/call", params: { name: "greet", arguments: { name: "PreParsed" } }, id: "batch-2" } - ]; + expect(initializationOrder).toEqual(['async-start', 'async-end', tempSessionId]); - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - }, - body: "" // Empty as we're using pre-parsed + // Clean up + tempServer.close(); }); - expect(response.status).toBe(200); + it('should support sync onsessioninitialized callback (backwards compatibility)', async () => { + const capturedSessionId: string[] = []; + + // Create server with sync onsessioninitialized callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: (sessionId: string) => { + capturedSessionId.push(sessionId); + } + }); - const reader = response.body?.getReader(); - const { value } = await reader!.read(); - const text = new TextDecoder().decode(value); - - expect(text).toContain('"id":"batch-1"'); - expect(text).toContain('"tools"'); - }); - - it("should prefer pre-parsed body over request body", async () => { - // Set pre-parsed to tools/list - parsedBody = { - jsonrpc: "2.0", - method: "tools/list", - params: {}, - id: "preparsed-wins", - }; - - // Send actual body with tools/call - should be ignored - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": sessionId, - }, - body: JSON.stringify({ - jsonrpc: "2.0", - method: "tools/call", - params: { name: "greet", arguments: { name: "Ignored" } }, - id: "ignored-id" - }) + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger the callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + expect(capturedSessionId).toEqual([tempSessionId]); + + // Clean up + tempServer.close(); }); - expect(response.status).toBe(200); - - const reader = response.body?.getReader(); - const { value } = await reader!.read(); - const text = new TextDecoder().decode(value); + it('should support async onsessionclosed callback', async () => { + const closureOrder: string[] = []; - // Should have processed the pre-parsed body - expect(text).toContain('"id":"preparsed-wins"'); - expect(text).toContain('"tools"'); - expect(text).not.toContain('"ignored-id"'); - }); -}); + // Create server with async onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: async (sessionId: string) => { + closureOrder.push('async-close-start'); + // Simulate async operation + await new Promise(resolve => setTimeout(resolve, 10)); + closureOrder.push('async-close-end'); + closureOrder.push(sessionId); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Give time for async callback to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(closureOrder).toEqual(['async-close-start', 'async-close-end', tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it('should propagate errors from async onsessioninitialized callback', async () => { + const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); + + // Create server with async onsessioninitialized callback that throws + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (_sessionId: string) => { + throw new Error('Async initialization error'); + } + }); -// Test resumability support -describe("StreamableHTTPServerTransport with resumability", () => { - let server: Server; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - let sessionId: string; - let mcpServer: McpServer; - const storedEvents: Map = new Map(); - - // Simple implementation of EventStore - const eventStore: EventStore = { - - async storeEvent(streamId: string, message: JSONRPCMessage): Promise { - const eventId = `${streamId}_${randomUUID()}`; - storedEvents.set(eventId, { eventId, message }); - return eventId; - }, - - async replayEventsAfter(lastEventId: EventId, { send }: { - send: (eventId: EventId, message: JSONRPCMessage) => Promise - }): Promise { - const streamId = lastEventId.split('_')[0]; - // Extract stream ID from the event ID - // For test simplicity, just return all events with matching streamId that aren't the lastEventId - for (const [eventId, { message }] of storedEvents.entries()) { - if (eventId.startsWith(streamId) && eventId !== lastEventId) { - await send(eventId, message); - } - } - return streamId; - }, - }; - - beforeEach(async () => { - storedEvents.clear(); - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - eventStore - }); - - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - mcpServer = result.mcpServer; - - // Verify resumability is enabled on the transport - expect((transport)['_eventStore']).toBeDefined(); - - // Initialize the server - const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - sessionId = initResponse.headers.get("mcp-session-id") as string; - expect(sessionId).toBeDefined(); - }); - - afterEach(async () => { - await stopTestServer({ server, transport }); - storedEvents.clear(); - }); - - it("should store and include event IDs in server SSE messages", async () => { - // Open a standalone SSE stream - const sseResponse = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(sseResponse.status).toBe(200); - expect(sseResponse.headers.get("content-type")).toBe("text/event-stream"); - - // Send a notification that should be stored with an event ID - const notification: JSONRPCMessage = { - jsonrpc: "2.0", - method: "notifications/message", - params: { level: "info", data: "Test notification with event ID" }, - }; + const tempServer = result.server; + const tempUrl = result.baseUrl; - // Send the notification via transport - await transport.send(notification); + // Initialize should fail when callback throws + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + expect(initResponse.status).toBe(400); - // Read from the stream and verify we got the notification with an event ID - const reader = sseResponse.body?.getReader(); - const { value } = await reader!.read(); - const text = new TextDecoder().decode(value); - - // The response should contain an event ID - expect(text).toContain('id: '); - expect(text).toContain('"method":"notifications/message"'); - - // Extract the event ID - const idMatch = text.match(/id: ([^\n]+)/); - expect(idMatch).toBeTruthy(); - - // Verify the event was stored - const eventId = idMatch![1]; - expect(storedEvents.has(eventId)).toBe(true); - const storedEvent = storedEvents.get(eventId); - expect(eventId.startsWith('_GET_stream')).toBe(true); - expect(storedEvent?.message).toMatchObject(notification); - }); - - - it("should store and replay MCP server tool notifications", async () => { - // Establish a standalone SSE stream - const sseResponse = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - }, - }); - expect(sseResponse.status).toBe(200); // Send a server notification through the MCP server - await mcpServer.server.sendLoggingMessage({ level: "info", data: "First notification from MCP server" }); - - // Read the notification from the SSE stream - const reader = sseResponse.body?.getReader(); - const { value } = await reader!.read(); - const text = new TextDecoder().decode(value); + // Clean up + consoleErrorSpy.mockRestore(); + tempServer.close(); + }); - // Verify the notification was sent with an event ID - expect(text).toContain('id: '); - expect(text).toContain('First notification from MCP server'); + it('should propagate errors from async onsessionclosed callback', async () => { + const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); - // Extract the event ID - const idMatch = text.match(/id: ([^\n]+)/); - expect(idMatch).toBeTruthy(); - const firstEventId = idMatch![1]; + // Create server with async onsessionclosed callback that throws + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: async (_sessionId: string) => { + throw new Error('Async closure error'); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); - // Send a second notification - await mcpServer.server.sendLoggingMessage({ level: "info", data: "Second notification from MCP server" }); + // DELETE should fail when callback throws + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); - // Close the first SSE stream to simulate a disconnect - await reader!.cancel(); + expect(deleteResponse.status).toBe(500); - // Reconnect with the Last-Event-ID to get missed messages - const reconnectResponse = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-session-id": sessionId, - "mcp-protocol-version": "2025-03-26", - "last-event-id": firstEventId - }, + // Clean up + consoleErrorSpy.mockRestore(); + tempServer.close(); }); - expect(reconnectResponse.status).toBe(200); + it('should handle both async callbacks together', async () => { + const events: string[] = []; + + // Create server with both async callbacks + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (sessionId: string) => { + await new Promise(resolve => setTimeout(resolve, 5)); + events.push(`initialized:${sessionId}`); + }, + onsessionclosed: async (sessionId: string) => { + await new Promise(resolve => setTimeout(resolve, 5)); + events.push(`closed:${sessionId}`); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; - // Read the replayed notification - const reconnectReader = reconnectResponse.body?.getReader(); - const reconnectData = await reconnectReader!.read(); - const reconnectText = new TextDecoder().decode(reconnectData.value); + // Initialize to trigger first callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); - // Verify we received the second notification that was sent after our stored eventId - expect(reconnectText).toContain('Second notification from MCP server'); - expect(reconnectText).toContain('id: '); - }); -}); + // Wait for async callback + await new Promise(resolve => setTimeout(resolve, 20)); -// Test stateless mode -describe("StreamableHTTPServerTransport in stateless mode", () => { - let server: Server; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - - beforeEach(async () => { - const result = await createTestServer({ sessionIdGenerator: undefined }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - }); - - afterEach(async () => { - await stopTestServer({ server, transport }); - }); - - it("should operate without session ID validation", async () => { - // Initialize the server first - const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - expect(initResponse.status).toBe(200); - // Should NOT have session ID header in stateless mode - expect(initResponse.headers.get("mcp-session-id")).toBeNull(); - - // Try request without session ID - should work in stateless mode - const toolsResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); - - expect(toolsResponse.status).toBe(200); - }); - - it("should handle POST requests with various session IDs in stateless mode", async () => { - await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - // Try with a random session ID - should be accepted - const response1 = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": "random-id-1", - }, - body: JSON.stringify({ jsonrpc: "2.0", method: "tools/list", params: {}, id: "t1" }), - }); - expect(response1.status).toBe(200); - - // Try with another random session ID - should also be accepted - const response2 = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - "mcp-session-id": "different-id-2", - }, - body: JSON.stringify({ jsonrpc: "2.0", method: "tools/list", params: {}, id: "t2" }), - }); - expect(response2.status).toBe(200); - }); - - it("should reject second SSE stream even in stateless mode", async () => { - // Despite no session ID requirement, the transport still only allows - // one standalone SSE stream at a time - - // Initialize the server first - await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); - - // Open first SSE stream - const stream1 = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-protocol-version": "2025-03-26" - }, - }); - expect(stream1.status).toBe(200); - - // Open second SSE stream - should still be rejected, stateless mode still only allows one - const stream2 = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - "mcp-protocol-version": "2025-03-26" - }, - }); - expect(stream2.status).toBe(409); // Conflict - only one stream allowed - }); -}); + expect(events).toContain(`initialized:${tempSessionId}`); -// Test onsessionclosed callback -describe("StreamableHTTPServerTransport onsessionclosed callback", () => { - it("should call onsessionclosed callback when session is closed via DELETE", async () => { - const mockCallback = jest.fn(); - - // Create server with onsessionclosed callback - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessionclosed: mockCallback, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to get a session ID - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - expect(tempSessionId).toBeDefined(); - - // DELETE the session - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": tempSessionId || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(200); - expect(mockCallback).toHaveBeenCalledWith(tempSessionId); - expect(mockCallback).toHaveBeenCalledTimes(1); - - // Clean up - tempServer.close(); - }); - - it("should not call onsessionclosed callback when not provided", async () => { - // Create server without onsessionclosed callback - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to get a session ID - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - - // DELETE the session - should not throw error - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": tempSessionId || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(200); - - // Clean up - tempServer.close(); - }); - - it("should not call onsessionclosed callback for invalid session DELETE", async () => { - const mockCallback = jest.fn(); - - // Create server with onsessionclosed callback - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessionclosed: mockCallback, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to get a valid session - await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - - // Try to DELETE with invalid session ID - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": "invalid-session-id", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(404); - expect(mockCallback).not.toHaveBeenCalled(); - - // Clean up - tempServer.close(); - }); - - it("should call onsessionclosed callback with correct session ID when multiple sessions exist", async () => { - const mockCallback = jest.fn(); - - // Create first server - const result1 = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessionclosed: mockCallback, - }); - - const server1 = result1.server; - const url1 = result1.baseUrl; - - // Create second server - const result2 = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessionclosed: mockCallback, - }); - - const server2 = result2.server; - const url2 = result2.baseUrl; - - // Initialize both servers - const initResponse1 = await sendPostRequest(url1, TEST_MESSAGES.initialize); - const sessionId1 = initResponse1.headers.get("mcp-session-id"); - - const initResponse2 = await sendPostRequest(url2, TEST_MESSAGES.initialize); - const sessionId2 = initResponse2.headers.get("mcp-session-id"); - - expect(sessionId1).toBeDefined(); - expect(sessionId2).toBeDefined(); - expect(sessionId1).not.toBe(sessionId2); - - // DELETE first session - const deleteResponse1 = await fetch(url1, { - method: "DELETE", - headers: { - "mcp-session-id": sessionId1 || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse1.status).toBe(200); - expect(mockCallback).toHaveBeenCalledWith(sessionId1); - expect(mockCallback).toHaveBeenCalledTimes(1); - - // DELETE second session - const deleteResponse2 = await fetch(url2, { - method: "DELETE", - headers: { - "mcp-session-id": sessionId2 || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse2.status).toBe(200); - expect(mockCallback).toHaveBeenCalledWith(sessionId2); - expect(mockCallback).toHaveBeenCalledTimes(2); - - // Clean up - server1.close(); - server2.close(); - }); -}); + // DELETE to trigger second callback + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); -// Test async callbacks for onsessioninitialized and onsessionclosed -describe("StreamableHTTPServerTransport async callbacks", () => { - it("should support async onsessioninitialized callback", async () => { - const initializationOrder: string[] = []; - - // Create server with async onsessioninitialized callback - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: async (sessionId: string) => { - initializationOrder.push('async-start'); - // Simulate async operation - await new Promise(resolve => setTimeout(resolve, 10)); - initializationOrder.push('async-end'); - initializationOrder.push(sessionId); - }, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to trigger the callback - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - - // Give time for async callback to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - expect(initializationOrder).toEqual(['async-start', 'async-end', tempSessionId]); - - // Clean up - tempServer.close(); - }); - - it("should support sync onsessioninitialized callback (backwards compatibility)", async () => { - const capturedSessionId: string[] = []; - - // Create server with sync onsessioninitialized callback - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: (sessionId: string) => { - capturedSessionId.push(sessionId); - }, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to trigger the callback - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - - expect(capturedSessionId).toEqual([tempSessionId]); - - // Clean up - tempServer.close(); - }); - - it("should support async onsessionclosed callback", async () => { - const closureOrder: string[] = []; - - // Create server with async onsessionclosed callback - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessionclosed: async (sessionId: string) => { - closureOrder.push('async-close-start'); - // Simulate async operation - await new Promise(resolve => setTimeout(resolve, 10)); - closureOrder.push('async-close-end'); - closureOrder.push(sessionId); - }, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to get a session ID - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - expect(tempSessionId).toBeDefined(); - - // DELETE the session - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": tempSessionId || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(200); - - // Give time for async callback to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - expect(closureOrder).toEqual(['async-close-start', 'async-close-end', tempSessionId]); - - // Clean up - tempServer.close(); - }); - - it("should propagate errors from async onsessioninitialized callback", async () => { - const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); - - // Create server with async onsessioninitialized callback that throws - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: async (_sessionId: string) => { - throw new Error('Async initialization error'); - }, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize should fail when callback throws - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - expect(initResponse.status).toBe(400); - - // Clean up - consoleErrorSpy.mockRestore(); - tempServer.close(); - }); - - it("should propagate errors from async onsessionclosed callback", async () => { - const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); - - // Create server with async onsessionclosed callback that throws - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessionclosed: async (_sessionId: string) => { - throw new Error('Async closure error'); - }, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to get a session ID - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - - // DELETE should fail when callback throws - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": tempSessionId || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(500); - - // Clean up - consoleErrorSpy.mockRestore(); - tempServer.close(); - }); - - it("should handle both async callbacks together", async () => { - const events: string[] = []; - - // Create server with both async callbacks - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: async (sessionId: string) => { - await new Promise(resolve => setTimeout(resolve, 5)); - events.push(`initialized:${sessionId}`); - }, - onsessionclosed: async (sessionId: string) => { - await new Promise(resolve => setTimeout(resolve, 5)); - events.push(`closed:${sessionId}`); - }, - }); - - const tempServer = result.server; - const tempUrl = result.baseUrl; - - // Initialize to trigger first callback - const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); - const tempSessionId = initResponse.headers.get("mcp-session-id"); - - // Wait for async callback - await new Promise(resolve => setTimeout(resolve, 20)); - - expect(events).toContain(`initialized:${tempSessionId}`); - - // DELETE to trigger second callback - const deleteResponse = await fetch(tempUrl, { - method: "DELETE", - headers: { - "mcp-session-id": tempSessionId || "", - "mcp-protocol-version": "2025-03-26", - }, - }); - - expect(deleteResponse.status).toBe(200); - - // Wait for async callback - await new Promise(resolve => setTimeout(resolve, 20)); - - expect(events).toContain(`closed:${tempSessionId}`); - expect(events).toHaveLength(2); - - // Clean up - tempServer.close(); - }); + expect(deleteResponse.status).toBe(200); + + // Wait for async callback + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(events).toContain(`closed:${tempSessionId}`); + expect(events).toHaveLength(2); + + // Clean up + tempServer.close(); + }); }); // Test DNS rebinding protection -describe("StreamableHTTPServerTransport DNS rebinding protection", () => { - let server: Server; - let transport: StreamableHTTPServerTransport; - let baseUrl: URL; - - afterEach(async () => { - if (server && transport) { - await stopTestServer({ server, transport }); - } - }); - - describe("Host header validation", () => { - it("should accept requests with allowed host headers", async () => { - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedHosts: ['localhost'], - enableDnsRebindingProtection: true, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - // Note: fetch() automatically sets Host header to match the URL - // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); - - expect(response.status).toBe(200); - }); - - it("should reject requests with disallowed host headers", async () => { - // Test DNS rebinding protection by creating a server that only allows example.com - // but we're connecting via localhost, so it should be rejected - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedHosts: ['example.com:3001'], - enableDnsRebindingProtection: true, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); - - expect(response.status).toBe(403); - const body = await response.json(); - expect(body.error.message).toContain("Invalid Host header:"); - }); - - it("should reject GET requests with disallowed host headers", async () => { - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedHosts: ['example.com:3001'], - enableDnsRebindingProtection: true, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - const response = await fetch(baseUrl, { - method: "GET", - headers: { - Accept: "text/event-stream", - }, - }); - - expect(response.status).toBe(403); - }); - }); - - describe("Origin header validation", () => { - it("should accept requests with allowed origin headers", async () => { - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedOrigins: ['http://localhost:3000', 'https://example.com'], - enableDnsRebindingProtection: true, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - Origin: "http://localhost:3000", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); - - expect(response.status).toBe(200); - }); - - it("should reject requests with disallowed origin headers", async () => { - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedOrigins: ['http://localhost:3000'], - enableDnsRebindingProtection: true, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - Origin: "http://evil.com", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); - - expect(response.status).toBe(403); - const body = await response.json(); - expect(body.error.message).toBe("Invalid Origin header: http://evil.com"); - }); - }); - - describe("enableDnsRebindingProtection option", () => { - it("should skip all validations when enableDnsRebindingProtection is false", async () => { - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedHosts: ['localhost'], - allowedOrigins: ['http://localhost:3000'], - enableDnsRebindingProtection: false, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - const response = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - Host: "evil.com", - Origin: "http://evil.com", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); - - // Should pass even with invalid headers because protection is disabled - expect(response.status).toBe(200); - }); - }); - - describe("Combined validations", () => { - it("should validate both host and origin when both are configured", async () => { - const result = await createTestServerWithDnsProtection({ - sessionIdGenerator: undefined, - allowedHosts: ['localhost'], - allowedOrigins: ['http://localhost:3001'], - enableDnsRebindingProtection: true, - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - - // Test with invalid origin (host will be automatically correct via fetch) - const response1 = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - Origin: "http://evil.com", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); - - expect(response1.status).toBe(403); - const body1 = await response1.json(); - expect(body1.error.message).toBe("Invalid Origin header: http://evil.com"); - - // Test with valid origin - const response2 = await fetch(baseUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "application/json, text/event-stream", - Origin: "http://localhost:3001", - }, - body: JSON.stringify(TEST_MESSAGES.initialize), - }); +describe('StreamableHTTPServerTransport DNS rebinding protection', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + + afterEach(async () => { + if (server && transport) { + await stopTestServer({ server, transport }); + } + }); - expect(response2.status).toBe(200); + describe('Host header validation', () => { + it('should accept requests with allowed host headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Note: fetch() automatically sets Host header to match the URL + // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(200); + }); + + it('should reject requests with disallowed host headers', async () => { + // Test DNS rebinding protection by creating a server that only allows example.com + // but we're connecting via localhost, so it should be rejected + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toContain('Invalid Host header:'); + }); + + it('should reject GET requests with disallowed host headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream' + } + }); + + expect(response.status).toBe(403); + }); + }); + + describe('Origin header validation', () => { + it('should accept requests with allowed origin headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://localhost:3000' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(200); + }); + + it('should reject requests with disallowed origin headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://evil.com' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toBe('Invalid Origin header: http://evil.com'); + }); + }); + + describe('enableDnsRebindingProtection option', () => { + it('should skip all validations when enableDnsRebindingProtection is false', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: false + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Host: 'evil.com', + Origin: 'http://evil.com' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + // Should pass even with invalid headers because protection is disabled + expect(response.status).toBe(200); + }); + }); + + describe('Combined validations', () => { + it('should validate both host and origin when both are configured', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + allowedOrigins: ['http://localhost:3001'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Test with invalid origin (host will be automatically correct via fetch) + const response1 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://evil.com' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response1.status).toBe(403); + const body1 = await response1.json(); + expect(body1.error.message).toBe('Invalid Origin header: http://evil.com'); + + // Test with valid origin + const response2 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://localhost:3001' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response2.status).toBe(200); + }); }); - }); }); /** * Helper to create test server with DNS rebinding protection options */ async function createTestServerWithDnsProtection(config: { - sessionIdGenerator: (() => string) | undefined; - allowedHosts?: string[]; - allowedOrigins?: string[]; - enableDnsRebindingProtection?: boolean; + sessionIdGenerator: (() => string) | undefined; + allowedHosts?: string[]; + allowedOrigins?: string[]; + enableDnsRebindingProtection?: boolean; }): Promise<{ - server: Server; - transport: StreamableHTTPServerTransport; - mcpServer: McpServer; - baseUrl: URL; + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; }> { - const mcpServer = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: { logging: {} } } - ); - - const port = await getFreePort(); - - if (config.allowedHosts) { - config.allowedHosts = config.allowedHosts.map(host => { - if (host.includes(':')) { - return host; - } - return `localhost:${port}`; - }); - } - - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: config.sessionIdGenerator, - allowedHosts: config.allowedHosts, - allowedOrigins: config.allowedOrigins, - enableDnsRebindingProtection: config.enableDnsRebindingProtection, - }); - - await mcpServer.connect(transport); - - const httpServer = createServer(async (req, res) => { - if (req.method === "POST") { - let body = ""; - req.on("data", (chunk) => (body += chunk)); - req.on("end", async () => { - const parsedBody = JSON.parse(body); - await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res, parsedBody); - }); - } else { - await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res); + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + const port = await getFreePort(); + + if (config.allowedHosts) { + config.allowedHosts = config.allowedHosts.map(host => { + if (host.includes(':')) { + return host; + } + return `localhost:${port}`; + }); } - }); - await new Promise((resolve) => { - httpServer.listen(port, () => resolve()); - }); + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + allowedHosts: config.allowedHosts, + allowedOrigins: config.allowedOrigins, + enableDnsRebindingProtection: config.enableDnsRebindingProtection + }); + + await mcpServer.connect(transport); + + const httpServer = createServer(async (req, res) => { + if (req.method === 'POST') { + let body = ''; + req.on('data', chunk => (body += chunk)); + req.on('end', async () => { + const parsedBody = JSON.parse(body); + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res, parsedBody); + }); + } else { + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res); + } + }); + + await new Promise(resolve => { + httpServer.listen(port, () => resolve()); + }); - const serverUrl = new URL(`http://localhost:${port}/`); + const serverUrl = new URL(`http://localhost:${port}/`); - return { - server: httpServer, - transport, - mcpServer, - baseUrl: serverUrl, - }; -} \ No newline at end of file + return { + server: httpServer, + transport, + mcpServer, + baseUrl: serverUrl + }; +} diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index c0da91704..d57e75cd7 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,12 +1,24 @@ -import { IncomingMessage, ServerResponse } from "node:http"; -import { Transport } from "../shared/transport.js"; -import { MessageExtraInfo, RequestInfo, isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId, SUPPORTED_PROTOCOL_VERSIONS, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from "../types.js"; -import getRawBody from "raw-body"; -import contentType from "content-type"; -import { randomUUID } from "node:crypto"; -import { AuthInfo } from "./auth/types.js"; - -const MAXIMUM_MESSAGE_SIZE = "4mb"; +import { IncomingMessage, ServerResponse } from 'node:http'; +import { Transport } from '../shared/transport.js'; +import { + MessageExtraInfo, + RequestInfo, + isInitializeRequest, + isJSONRPCError, + isJSONRPCRequest, + isJSONRPCResponse, + JSONRPCMessage, + JSONRPCMessageSchema, + RequestId, + SUPPORTED_PROTOCOL_VERSIONS, + DEFAULT_NEGOTIATED_PROTOCOL_VERSION +} from '../types.js'; +import getRawBody from 'raw-body'; +import contentType from 'content-type'; +import { randomUUID } from 'node:crypto'; +import { AuthInfo } from './auth/types.js'; + +const MAXIMUM_MESSAGE_SIZE = '4mb'; export type StreamId = string; export type EventId = string; @@ -15,753 +27,781 @@ export type EventId = string; * Interface for resumability support via event storage */ export interface EventStore { - /** - * Stores an event for later retrieval - * @param streamId ID of the stream the event belongs to - * @param message The JSON-RPC message to store - * @returns The generated event ID for the stored event - */ - storeEvent(streamId: StreamId, message: JSONRPCMessage): Promise; - - replayEventsAfter(lastEventId: EventId, { send }: { - send: (eventId: EventId, message: JSONRPCMessage) => Promise - }): Promise; + /** + * Stores an event for later retrieval + * @param streamId ID of the stream the event belongs to + * @param message The JSON-RPC message to store + * @returns The generated event ID for the stored event + */ + storeEvent(streamId: StreamId, message: JSONRPCMessage): Promise; + + replayEventsAfter( + lastEventId: EventId, + { + send + }: { + send: (eventId: EventId, message: JSONRPCMessage) => Promise; + } + ): Promise; } /** * Configuration options for StreamableHTTPServerTransport */ export interface StreamableHTTPServerTransportOptions { - /** - * Function that generates a session ID for the transport. - * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) - * - * Return undefined to disable session management. - */ - sessionIdGenerator: (() => string) | undefined; - - /** - * A callback for session initialization events - * This is called when the server initializes a new session. - * Useful in cases when you need to register multiple mcp sessions - * and need to keep track of them. - * @param sessionId The generated session ID - */ - onsessioninitialized?: (sessionId: string) => void | Promise; - - /** - * A callback for session close events - * This is called when the server closes a session due to a DELETE request. - * Useful in cases when you need to clean up resources associated with the session. - * Note that this is different from the transport closing, if you are handling - * HTTP requests from multiple nodes you might want to close each - * StreamableHTTPServerTransport after a request is completed while still keeping the - * session open/running. - * @param sessionId The session ID that was closed - */ - onsessionclosed?: (sessionId: string) => void | Promise; - - /** - * If true, the server will return JSON responses instead of starting an SSE stream. - * This can be useful for simple request/response scenarios without streaming. - * Default is false (SSE streams are preferred). - */ - enableJsonResponse?: boolean; - - /** - * Event store for resumability support - * If provided, resumability will be enabled, allowing clients to reconnect and resume messages - */ - eventStore?: EventStore; - - /** - * List of allowed host header values for DNS rebinding protection. - * If not specified, host validation is disabled. - */ - allowedHosts?: string[]; - - /** - * List of allowed origin header values for DNS rebinding protection. - * If not specified, origin validation is disabled. - */ - allowedOrigins?: string[]; - - /** - * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). - * Default is false for backwards compatibility. - */ - enableDnsRebindingProtection?: boolean; + /** + * Function that generates a session ID for the transport. + * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) + * + * Return undefined to disable session management. + */ + sessionIdGenerator: (() => string) | undefined; + + /** + * A callback for session initialization events + * This is called when the server initializes a new session. + * Useful in cases when you need to register multiple mcp sessions + * and need to keep track of them. + * @param sessionId The generated session ID + */ + onsessioninitialized?: (sessionId: string) => void | Promise; + + /** + * A callback for session close events + * This is called when the server closes a session due to a DELETE request. + * Useful in cases when you need to clean up resources associated with the session. + * Note that this is different from the transport closing, if you are handling + * HTTP requests from multiple nodes you might want to close each + * StreamableHTTPServerTransport after a request is completed while still keeping the + * session open/running. + * @param sessionId The session ID that was closed + */ + onsessionclosed?: (sessionId: string) => void | Promise; + + /** + * If true, the server will return JSON responses instead of starting an SSE stream. + * This can be useful for simple request/response scenarios without streaming. + * Default is false (SSE streams are preferred). + */ + enableJsonResponse?: boolean; + + /** + * Event store for resumability support + * If provided, resumability will be enabled, allowing clients to reconnect and resume messages + */ + eventStore?: EventStore; + + /** + * List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + */ + allowedHosts?: string[]; + + /** + * List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + */ + allowedOrigins?: string[]; + + /** + * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + */ + enableDnsRebindingProtection?: boolean; } /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It supports both SSE streaming and direct HTTP responses. - * + * * Usage example: - * + * * ```typescript * // Stateful mode - server sets the session ID * const statefulTransport = new StreamableHTTPServerTransport({ * sessionIdGenerator: () => randomUUID(), * }); - * + * * // Stateless mode - explicitly set session ID to undefined * const statelessTransport = new StreamableHTTPServerTransport({ * sessionIdGenerator: undefined, * }); - * + * * // Using with pre-parsed request body * app.post('/mcp', (req, res) => { * transport.handleRequest(req, res, req.body); * }); * ``` - * + * * In stateful mode: * - Session ID is generated and included in response headers * - Session ID is always included in initialization responses * - Requests with invalid session IDs are rejected with 404 Not Found * - Non-initialization requests without a session ID are rejected with 400 Bad Request * - State is maintained in-memory (connections, message history) - * + * * In stateless mode: * - No Session ID is included in any responses * - No session validation is performed */ export class StreamableHTTPServerTransport implements Transport { - // when sessionId is not set (undefined), it means the transport is in stateless mode - private sessionIdGenerator: (() => string) | undefined; - private _started: boolean = false; - private _streamMapping: Map = new Map(); - private _requestToStreamMapping: Map = new Map(); - private _requestResponseMap: Map = new Map(); - private _initialized: boolean = false; - private _enableJsonResponse: boolean = false; - private _standaloneSseStreamId: string = '_GET_stream'; - private _eventStore?: EventStore; - private _onsessioninitialized?: (sessionId: string) => void | Promise; - private _onsessionclosed?: (sessionId: string) => void | Promise; - private _allowedHosts?: string[]; - private _allowedOrigins?: string[]; - private _enableDnsRebindingProtection: boolean; - - sessionId?: string; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; - - constructor(options: StreamableHTTPServerTransportOptions) { - this.sessionIdGenerator = options.sessionIdGenerator; - this._enableJsonResponse = options.enableJsonResponse ?? false; - this._eventStore = options.eventStore; - this._onsessioninitialized = options.onsessioninitialized; - this._onsessionclosed = options.onsessionclosed; - this._allowedHosts = options.allowedHosts; - this._allowedOrigins = options.allowedOrigins; - this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; - } - - /** - * Starts the transport. This is required by the Transport interface but is a no-op - * for the Streamable HTTP transport as connections are managed per-request. - */ - async start(): Promise { - if (this._started) { - throw new Error("Transport already started"); - } - this._started = true; - } - - /** - * Validates request headers for DNS rebinding protection. - * @returns Error message if validation fails, undefined if validation passes. - */ - private validateRequestHeaders(req: IncomingMessage): string | undefined { - // Skip validation if protection is not enabled - if (!this._enableDnsRebindingProtection) { - return undefined; - } - - // Validate Host header if allowedHosts is configured - if (this._allowedHosts && this._allowedHosts.length > 0) { - const hostHeader = req.headers.host; - if (!hostHeader || !this._allowedHosts.includes(hostHeader)) { - return `Invalid Host header: ${hostHeader}`; - } + // when sessionId is not set (undefined), it means the transport is in stateless mode + private sessionIdGenerator: (() => string) | undefined; + private _started: boolean = false; + private _streamMapping: Map = new Map(); + private _requestToStreamMapping: Map = new Map(); + private _requestResponseMap: Map = new Map(); + private _initialized: boolean = false; + private _enableJsonResponse: boolean = false; + private _standaloneSseStreamId: string = '_GET_stream'; + private _eventStore?: EventStore; + private _onsessioninitialized?: (sessionId: string) => void | Promise; + private _onsessionclosed?: (sessionId: string) => void | Promise; + private _allowedHosts?: string[]; + private _allowedOrigins?: string[]; + private _enableDnsRebindingProtection: boolean; + + sessionId?: string; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + + constructor(options: StreamableHTTPServerTransportOptions) { + this.sessionIdGenerator = options.sessionIdGenerator; + this._enableJsonResponse = options.enableJsonResponse ?? false; + this._eventStore = options.eventStore; + this._onsessioninitialized = options.onsessioninitialized; + this._onsessionclosed = options.onsessionclosed; + this._allowedHosts = options.allowedHosts; + this._allowedOrigins = options.allowedOrigins; + this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; } - // Validate Origin header if allowedOrigins is configured - if (this._allowedOrigins && this._allowedOrigins.length > 0) { - const originHeader = req.headers.origin; - if (!originHeader || !this._allowedOrigins.includes(originHeader)) { - return `Invalid Origin header: ${originHeader}`; - } + /** + * Starts the transport. This is required by the Transport interface but is a no-op + * for the Streamable HTTP transport as connections are managed per-request. + */ + async start(): Promise { + if (this._started) { + throw new Error('Transport already started'); + } + this._started = true; } - return undefined; - } - - /** - * Handles an incoming HTTP request, whether GET or POST - */ - async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { - // Validate request headers for DNS rebinding protection - const validationError = this.validateRequestHeaders(req); - if (validationError) { - res.writeHead(403).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: validationError - }, - id: null - })); - this.onerror?.(new Error(validationError)); - return; - } + /** + * Validates request headers for DNS rebinding protection. + * @returns Error message if validation fails, undefined if validation passes. + */ + private validateRequestHeaders(req: IncomingMessage): string | undefined { + // Skip validation if protection is not enabled + if (!this._enableDnsRebindingProtection) { + return undefined; + } - if (req.method === "POST") { - await this.handlePostRequest(req, res, parsedBody); - } else if (req.method === "GET") { - await this.handleGetRequest(req, res); - } else if (req.method === "DELETE") { - await this.handleDeleteRequest(req, res); - } else { - await this.handleUnsupportedRequest(res); - } - } - - /** - * Handles GET requests for SSE stream - */ - private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { - // The client MUST include an Accept header, listing text/event-stream as a supported content type. - const acceptHeader = req.headers.accept; - if (!acceptHeader?.includes("text/event-stream")) { - res.writeHead(406).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Not Acceptable: Client must accept text/event-stream" - }, - id: null - })); - return; - } + // Validate Host header if allowedHosts is configured + if (this._allowedHosts && this._allowedHosts.length > 0) { + const hostHeader = req.headers.host; + if (!hostHeader || !this._allowedHosts.includes(hostHeader)) { + return `Invalid Host header: ${hostHeader}`; + } + } - // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it - // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - if (!this.validateSession(req, res)) { - return; - } - if (!this.validateProtocolVersion(req, res)) { - return; - } - // Handle resumability: check for Last-Event-ID header - if (this._eventStore) { - const lastEventId = req.headers['last-event-id'] as string | undefined; - if (lastEventId) { - await this.replayEvents(lastEventId, res); - return; - } - } + // Validate Origin header if allowedOrigins is configured + if (this._allowedOrigins && this._allowedOrigins.length > 0) { + const originHeader = req.headers.origin; + if (!originHeader || !this._allowedOrigins.includes(originHeader)) { + return `Invalid Origin header: ${originHeader}`; + } + } - // The server MUST either return Content-Type: text/event-stream in response to this HTTP GET, - // or else return HTTP 405 Method Not Allowed - const headers: Record = { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }; - - // After initialization, always include the session ID if we have one - if (this.sessionId !== undefined) { - headers["mcp-session-id"] = this.sessionId; + return undefined; } - // Check if there's already an active standalone SSE stream for this session - if (this._streamMapping.get(this._standaloneSseStreamId) !== undefined) { - // Only one GET SSE stream is allowed per session - res.writeHead(409).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Conflict: Only one SSE stream is allowed per session" - }, - id: null - })); - return; - } + /** + * Handles an incoming HTTP request, whether GET or POST + */ + async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { + // Validate request headers for DNS rebinding protection + const validationError = this.validateRequestHeaders(req); + if (validationError) { + res.writeHead(403).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: validationError + }, + id: null + }) + ); + this.onerror?.(new Error(validationError)); + return; + } - // We need to send headers immediately as messages will arrive much later, - // otherwise the client will just wait for the first message - res.writeHead(200, headers).flushHeaders(); - - // Assign the response to the standalone SSE stream - this._streamMapping.set(this._standaloneSseStreamId, res); - // Set up close handler for client disconnects - res.on("close", () => { - this._streamMapping.delete(this._standaloneSseStreamId); - }); - - // Add error handler for standalone SSE stream - res.on("error", (error) => { - this.onerror?.(error as Error); - }); - } - - /** - * Replays events that would have been sent after the specified event ID - * Only used when resumability is enabled - */ - private async replayEvents(lastEventId: string, res: ServerResponse): Promise { - if (!this._eventStore) { - return; - } - try { - const headers: Record = { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache, no-transform", - Connection: "keep-alive", - }; - - if (this.sessionId !== undefined) { - headers["mcp-session-id"] = this.sessionId; - } - res.writeHead(200, headers).flushHeaders(); - - const streamId = await this._eventStore?.replayEventsAfter(lastEventId, { - send: async (eventId: string, message: JSONRPCMessage) => { - if (!this.writeSSEEvent(res, message, eventId)) { - this.onerror?.(new Error("Failed replay events")); - res.end(); - } + if (req.method === 'POST') { + await this.handlePostRequest(req, res, parsedBody); + } else if (req.method === 'GET') { + await this.handleGetRequest(req, res); + } else if (req.method === 'DELETE') { + await this.handleDeleteRequest(req, res); + } else { + await this.handleUnsupportedRequest(res); } - }); - this._streamMapping.set(streamId, res); - - // Add error handler for replay stream - res.on("error", (error) => { - this.onerror?.(error as Error); - }); - } catch (error) { - this.onerror?.(error as Error); - } - } - - /** - * Writes an event to the SSE stream with proper formatting - */ - private writeSSEEvent(res: ServerResponse, message: JSONRPCMessage, eventId?: string): boolean { - let eventData = `event: message\n`; - // Include event ID if provided - this is important for resumability - if (eventId) { - eventData += `id: ${eventId}\n`; } - eventData += `data: ${JSON.stringify(message)}\n\n`; - - return res.write(eventData); - } - - /** - * Handles unsupported requests (PUT, PATCH, etc.) - */ - private async handleUnsupportedRequest(res: ServerResponse): Promise { - res.writeHead(405, { - "Allow": "GET, POST, DELETE" - }).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Method not allowed." - }, - id: null - })); - } - - /** - * Handles POST requests containing JSON-RPC messages - */ - private async handlePostRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { - try { - // Validate the Accept header - const acceptHeader = req.headers.accept; - // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. - if (!acceptHeader?.includes("application/json") || !acceptHeader.includes("text/event-stream")) { - res.writeHead(406).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Not Acceptable: Client must accept both application/json and text/event-stream" - }, - id: null - })); - return; - } - - const ct = req.headers["content-type"]; - if (!ct || !ct.includes("application/json")) { - res.writeHead(415).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Unsupported Media Type: Content-Type must be application/json" - }, - id: null - })); - return; - } - - const authInfo: AuthInfo | undefined = req.auth; - const requestInfo: RequestInfo = { headers: req.headers }; - - let rawMessage; - if (parsedBody !== undefined) { - rawMessage = parsedBody; - } else { - const parsedCt = contentType.parse(ct); - const body = await getRawBody(req, { - limit: MAXIMUM_MESSAGE_SIZE, - encoding: parsedCt.parameters.charset ?? "utf-8", - }); - rawMessage = JSON.parse(body.toString()); - } - - let messages: JSONRPCMessage[]; - - // handle batch and single messages - if (Array.isArray(rawMessage)) { - messages = rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)); - } else { - messages = [JSONRPCMessageSchema.parse(rawMessage)]; - } - - // Check if this is an initialization request - // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ - const isInitializationRequest = messages.some(isInitializeRequest); - if (isInitializationRequest) { - // If it's a server with session management and the session ID is already set we should reject the request - // to avoid re-initialization. - if (this._initialized && this.sessionId !== undefined) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32600, - message: "Invalid Request: Server already initialized" - }, - id: null - })); - return; - } - if (messages.length > 1) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32600, - message: "Invalid Request: Only one initialization request is allowed" - }, - id: null - })); - return; - } - this.sessionId = this.sessionIdGenerator?.(); - this._initialized = true; - // If we have a session ID and an onsessioninitialized handler, call it immediately - // This is needed in cases where the server needs to keep track of multiple sessions - if (this.sessionId && this._onsessioninitialized) { - await Promise.resolve(this._onsessioninitialized(this.sessionId)); + /** + * Handles GET requests for SSE stream + */ + private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { + // The client MUST include an Accept header, listing text/event-stream as a supported content type. + const acceptHeader = req.headers.accept; + if (!acceptHeader?.includes('text/event-stream')) { + res.writeHead(406).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Not Acceptable: Client must accept text/event-stream' + }, + id: null + }) + ); + return; } - } - if (!isInitializationRequest) { // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it + // clients using the Streamable HTTP transport MUST include it // in the Mcp-Session-Id header on all of their subsequent HTTP requests. if (!this.validateSession(req, res)) { - return; + return; } - // Mcp-Protocol-Version header is required for all requests after initialization. if (!this.validateProtocolVersion(req, res)) { - return; + return; + } + // Handle resumability: check for Last-Event-ID header + if (this._eventStore) { + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + await this.replayEvents(lastEventId, res); + return; + } } - } - - - // check if it contains requests - const hasRequests = messages.some(isJSONRPCRequest); - if (!hasRequests) { - // if it only contains notifications or responses, return 202 - res.writeHead(202).end(); + // The server MUST either return Content-Type: text/event-stream in response to this HTTP GET, + // or else return HTTP 405 Method Not Allowed + const headers: Record = { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }; - // handle each message - for (const message of messages) { - this.onmessage?.(message, { authInfo, requestInfo }); - } - } else if (hasRequests) { - // The default behavior is to use SSE streaming - // but in some cases server will return JSON responses - const streamId = randomUUID(); - if (!this._enableJsonResponse) { - const headers: Record = { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - }; - - // After initialization, always include the session ID if we have one - if (this.sessionId !== undefined) { - headers["mcp-session-id"] = this.sessionId; - } - - res.writeHead(200, headers); + // After initialization, always include the session ID if we have one + if (this.sessionId !== undefined) { + headers['mcp-session-id'] = this.sessionId; } - // Store the response for this request to send messages back through this connection - // We need to track by request ID to maintain the connection - for (const message of messages) { - if (isJSONRPCRequest(message)) { - this._streamMapping.set(streamId, res); - this._requestToStreamMapping.set(message.id, streamId); - } + + // Check if there's already an active standalone SSE stream for this session + if (this._streamMapping.get(this._standaloneSseStreamId) !== undefined) { + // Only one GET SSE stream is allowed per session + res.writeHead(409).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Conflict: Only one SSE stream is allowed per session' + }, + id: null + }) + ); + return; } + + // We need to send headers immediately as messages will arrive much later, + // otherwise the client will just wait for the first message + res.writeHead(200, headers).flushHeaders(); + + // Assign the response to the standalone SSE stream + this._streamMapping.set(this._standaloneSseStreamId, res); // Set up close handler for client disconnects - res.on("close", () => { - this._streamMapping.delete(streamId); + res.on('close', () => { + this._streamMapping.delete(this._standaloneSseStreamId); }); - // Add error handler for stream write errors - res.on("error", (error) => { - this.onerror?.(error as Error); + // Add error handler for standalone SSE stream + res.on('error', error => { + this.onerror?.(error as Error); }); + } - // handle each message - for (const message of messages) { - this.onmessage?.(message, { authInfo, requestInfo }); + /** + * Replays events that would have been sent after the specified event ID + * Only used when resumability is enabled + */ + private async replayEvents(lastEventId: string, res: ServerResponse): Promise { + if (!this._eventStore) { + return; } - // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses - // This will be handled by the send() method when responses are ready - } - } catch (error) { - // return JSON-RPC formatted error - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32700, - message: "Parse error", - data: String(error) - }, - id: null - })); - this.onerror?.(error as Error); - } - } - - /** - * Handles DELETE requests to terminate sessions - */ - private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { - if (!this.validateSession(req, res)) { - return; - } - if (!this.validateProtocolVersion(req, res)) { - return; - } - await Promise.resolve(this._onsessionclosed?.(this.sessionId!)); - await this.close(); - res.writeHead(200).end(); - } - - /** - * Validates session ID for non-initialization requests - * Returns true if the session is valid, false otherwise - */ - private validateSession(req: IncomingMessage, res: ServerResponse): boolean { - if (this.sessionIdGenerator === undefined) { - // If the sessionIdGenerator ID is not set, the session management is disabled - // and we don't need to validate the session ID - return true; - } - if (!this._initialized) { - // If the server has not been initialized yet, reject all requests - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Bad Request: Server not initialized" - }, - id: null - })); - return false; - } + try { + const headers: Record = { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }; + + if (this.sessionId !== undefined) { + headers['mcp-session-id'] = this.sessionId; + } + res.writeHead(200, headers).flushHeaders(); + + const streamId = await this._eventStore?.replayEventsAfter(lastEventId, { + send: async (eventId: string, message: JSONRPCMessage) => { + if (!this.writeSSEEvent(res, message, eventId)) { + this.onerror?.(new Error('Failed replay events')); + res.end(); + } + } + }); + this._streamMapping.set(streamId, res); - const sessionId = req.headers["mcp-session-id"]; - - if (!sessionId) { - // Non-initialization requests without a session ID should return 400 Bad Request - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Bad Request: Mcp-Session-Id header is required" - }, - id: null - })); - return false; - } else if (Array.isArray(sessionId)) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Bad Request: Mcp-Session-Id header must be a single value" - }, - id: null - })); - return false; - } - else if (sessionId !== this.sessionId) { - // Reject requests with invalid session ID with 404 Not Found - res.writeHead(404).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32001, - message: "Session not found" - }, - id: null - })); - return false; + // Add error handler for replay stream + res.on('error', error => { + this.onerror?.(error as Error); + }); + } catch (error) { + this.onerror?.(error as Error); + } } - return true; - } + /** + * Writes an event to the SSE stream with proper formatting + */ + private writeSSEEvent(res: ServerResponse, message: JSONRPCMessage, eventId?: string): boolean { + let eventData = `event: message\n`; + // Include event ID if provided - this is important for resumability + if (eventId) { + eventData += `id: ${eventId}\n`; + } + eventData += `data: ${JSON.stringify(message)}\n\n`; - private validateProtocolVersion(req: IncomingMessage, res: ServerResponse): boolean { - let protocolVersion = req.headers["mcp-protocol-version"] ?? DEFAULT_NEGOTIATED_PROTOCOL_VERSION; - if (Array.isArray(protocolVersion)) { - protocolVersion = protocolVersion[protocolVersion.length - 1]; + return res.write(eventData); } - if (!SUPPORTED_PROTOCOL_VERSIONS.includes(protocolVersion)) { - res.writeHead(400).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: `Bad Request: Unsupported protocol version (supported versions: ${SUPPORTED_PROTOCOL_VERSIONS.join(", ")})` - }, - id: null - })); - return false; + /** + * Handles unsupported requests (PUT, PATCH, etc.) + */ + private async handleUnsupportedRequest(res: ServerResponse): Promise { + res.writeHead(405, { + Allow: 'GET, POST, DELETE' + }).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Method not allowed.' + }, + id: null + }) + ); } - return true; - } - - async close(): Promise { - // Close all SSE connections - this._streamMapping.forEach((response) => { - response.end(); - }); - this._streamMapping.clear(); - - // Clear any pending responses - this._requestResponseMap.clear(); - this.onclose?.(); - } - - async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { - let requestId = options?.relatedRequestId; - if (isJSONRPCResponse(message) || isJSONRPCError(message)) { - // If the message is a response, use the request ID from the message - requestId = message.id; + + /** + * Handles POST requests containing JSON-RPC messages + */ + private async handlePostRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { + try { + // Validate the Accept header + const acceptHeader = req.headers.accept; + // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. + if (!acceptHeader?.includes('application/json') || !acceptHeader.includes('text/event-stream')) { + res.writeHead(406).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Not Acceptable: Client must accept both application/json and text/event-stream' + }, + id: null + }) + ); + return; + } + + const ct = req.headers['content-type']; + if (!ct || !ct.includes('application/json')) { + res.writeHead(415).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Unsupported Media Type: Content-Type must be application/json' + }, + id: null + }) + ); + return; + } + + const authInfo: AuthInfo | undefined = req.auth; + const requestInfo: RequestInfo = { headers: req.headers }; + + let rawMessage; + if (parsedBody !== undefined) { + rawMessage = parsedBody; + } else { + const parsedCt = contentType.parse(ct); + const body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: parsedCt.parameters.charset ?? 'utf-8' + }); + rawMessage = JSON.parse(body.toString()); + } + + let messages: JSONRPCMessage[]; + + // handle batch and single messages + if (Array.isArray(rawMessage)) { + messages = rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)); + } else { + messages = [JSONRPCMessageSchema.parse(rawMessage)]; + } + + // Check if this is an initialization request + // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ + const isInitializationRequest = messages.some(isInitializeRequest); + if (isInitializationRequest) { + // If it's a server with session management and the session ID is already set we should reject the request + // to avoid re-initialization. + if (this._initialized && this.sessionId !== undefined) { + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32600, + message: 'Invalid Request: Server already initialized' + }, + id: null + }) + ); + return; + } + if (messages.length > 1) { + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32600, + message: 'Invalid Request: Only one initialization request is allowed' + }, + id: null + }) + ); + return; + } + this.sessionId = this.sessionIdGenerator?.(); + this._initialized = true; + + // If we have a session ID and an onsessioninitialized handler, call it immediately + // This is needed in cases where the server needs to keep track of multiple sessions + if (this.sessionId && this._onsessioninitialized) { + await Promise.resolve(this._onsessioninitialized(this.sessionId)); + } + } + if (!isInitializationRequest) { + // If an Mcp-Session-Id is returned by the server during initialization, + // clients using the Streamable HTTP transport MUST include it + // in the Mcp-Session-Id header on all of their subsequent HTTP requests. + if (!this.validateSession(req, res)) { + return; + } + // Mcp-Protocol-Version header is required for all requests after initialization. + if (!this.validateProtocolVersion(req, res)) { + return; + } + } + + // check if it contains requests + const hasRequests = messages.some(isJSONRPCRequest); + + if (!hasRequests) { + // if it only contains notifications or responses, return 202 + res.writeHead(202).end(); + + // handle each message + for (const message of messages) { + this.onmessage?.(message, { authInfo, requestInfo }); + } + } else if (hasRequests) { + // The default behavior is to use SSE streaming + // but in some cases server will return JSON responses + const streamId = randomUUID(); + if (!this._enableJsonResponse) { + const headers: Record = { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive' + }; + + // After initialization, always include the session ID if we have one + if (this.sessionId !== undefined) { + headers['mcp-session-id'] = this.sessionId; + } + + res.writeHead(200, headers); + } + // Store the response for this request to send messages back through this connection + // We need to track by request ID to maintain the connection + for (const message of messages) { + if (isJSONRPCRequest(message)) { + this._streamMapping.set(streamId, res); + this._requestToStreamMapping.set(message.id, streamId); + } + } + // Set up close handler for client disconnects + res.on('close', () => { + this._streamMapping.delete(streamId); + }); + + // Add error handler for stream write errors + res.on('error', error => { + this.onerror?.(error as Error); + }); + + // handle each message + for (const message of messages) { + this.onmessage?.(message, { authInfo, requestInfo }); + } + // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses + // This will be handled by the send() method when responses are ready + } + } catch (error) { + // return JSON-RPC formatted error + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32700, + message: 'Parse error', + data: String(error) + }, + id: null + }) + ); + this.onerror?.(error as Error); + } } - // Check if this message should be sent on the standalone SSE stream (no request ID) - // Ignore notifications from tools (which have relatedRequestId set) - // Those will be sent via dedicated response SSE streams - if (requestId === undefined) { - // For standalone SSE streams, we can only send requests and notifications - if (isJSONRPCResponse(message) || isJSONRPCError(message)) { - throw new Error("Cannot send a response on a standalone SSE stream unless resuming a previous client request"); - } - const standaloneSse = this._streamMapping.get(this._standaloneSseStreamId) - if (standaloneSse === undefined) { - // The spec says the server MAY send messages on the stream, so it's ok to discard if no stream - return; - } - - // Generate and store event ID if event store is provided - let eventId: string | undefined; - if (this._eventStore) { - // Stores the event and gets the generated event ID - eventId = await this._eventStore.storeEvent(this._standaloneSseStreamId, message); - } - - // Send the message to the standalone SSE stream - this.writeSSEEvent(standaloneSse, message, eventId); - return; + /** + * Handles DELETE requests to terminate sessions + */ + private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { + if (!this.validateSession(req, res)) { + return; + } + if (!this.validateProtocolVersion(req, res)) { + return; + } + await Promise.resolve(this._onsessionclosed?.(this.sessionId!)); + await this.close(); + res.writeHead(200).end(); } - // Get the response for this request - const streamId = this._requestToStreamMapping.get(requestId); - const response = this._streamMapping.get(streamId!); - if (!streamId) { - throw new Error(`No connection established for request ID: ${String(requestId)}`); + /** + * Validates session ID for non-initialization requests + * Returns true if the session is valid, false otherwise + */ + private validateSession(req: IncomingMessage, res: ServerResponse): boolean { + if (this.sessionIdGenerator === undefined) { + // If the sessionIdGenerator ID is not set, the session management is disabled + // and we don't need to validate the session ID + return true; + } + if (!this._initialized) { + // If the server has not been initialized yet, reject all requests + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: Server not initialized' + }, + id: null + }) + ); + return false; + } + + const sessionId = req.headers['mcp-session-id']; + + if (!sessionId) { + // Non-initialization requests without a session ID should return 400 Bad Request + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: Mcp-Session-Id header is required' + }, + id: null + }) + ); + return false; + } else if (Array.isArray(sessionId)) { + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: Mcp-Session-Id header must be a single value' + }, + id: null + }) + ); + return false; + } else if (sessionId !== this.sessionId) { + // Reject requests with invalid session ID with 404 Not Found + res.writeHead(404).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32001, + message: 'Session not found' + }, + id: null + }) + ); + return false; + } + + return true; } - if (!this._enableJsonResponse) { - // For SSE responses, generate event ID if event store is provided - let eventId: string | undefined; - - if (this._eventStore) { - eventId = await this._eventStore.storeEvent(streamId, message); - } - if (response) { - // Write the event to the response stream - this.writeSSEEvent(response, message, eventId); - } + private validateProtocolVersion(req: IncomingMessage, res: ServerResponse): boolean { + let protocolVersion = req.headers['mcp-protocol-version'] ?? DEFAULT_NEGOTIATED_PROTOCOL_VERSION; + if (Array.isArray(protocolVersion)) { + protocolVersion = protocolVersion[protocolVersion.length - 1]; + } + + if (!SUPPORTED_PROTOCOL_VERSIONS.includes(protocolVersion)) { + res.writeHead(400).end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32000, + message: `Bad Request: Unsupported protocol version (supported versions: ${SUPPORTED_PROTOCOL_VERSIONS.join(', ')})` + }, + id: null + }) + ); + return false; + } + return true; } - if (isJSONRPCResponse(message) || isJSONRPCError(message)) { - this._requestResponseMap.set(requestId, message); - const relatedIds = Array.from(this._requestToStreamMapping.entries()) - .filter(([_, streamId]) => this._streamMapping.get(streamId) === response) - .map(([id]) => id); + async close(): Promise { + // Close all SSE connections + this._streamMapping.forEach(response => { + response.end(); + }); + this._streamMapping.clear(); + + // Clear any pending responses + this._requestResponseMap.clear(); + this.onclose?.(); + } - // Check if we have responses for all requests using this connection - const allResponsesReady = relatedIds.every(id => this._requestResponseMap.has(id)); + async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { + let requestId = options?.relatedRequestId; + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { + // If the message is a response, use the request ID from the message + requestId = message.id; + } - if (allResponsesReady) { - if (!response) { - throw new Error(`No connection established for request ID: ${String(requestId)}`); + // Check if this message should be sent on the standalone SSE stream (no request ID) + // Ignore notifications from tools (which have relatedRequestId set) + // Those will be sent via dedicated response SSE streams + if (requestId === undefined) { + // For standalone SSE streams, we can only send requests and notifications + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { + throw new Error('Cannot send a response on a standalone SSE stream unless resuming a previous client request'); + } + const standaloneSse = this._streamMapping.get(this._standaloneSseStreamId); + if (standaloneSse === undefined) { + // The spec says the server MAY send messages on the stream, so it's ok to discard if no stream + return; + } + + // Generate and store event ID if event store is provided + let eventId: string | undefined; + if (this._eventStore) { + // Stores the event and gets the generated event ID + eventId = await this._eventStore.storeEvent(this._standaloneSseStreamId, message); + } + + // Send the message to the standalone SSE stream + this.writeSSEEvent(standaloneSse, message, eventId); + return; } - if (this._enableJsonResponse) { - // All responses ready, send as JSON - const headers: Record = { - 'Content-Type': 'application/json', - }; - if (this.sessionId !== undefined) { - headers['mcp-session-id'] = this.sessionId; - } - const responses = relatedIds - .map(id => this._requestResponseMap.get(id)!); + // Get the response for this request + const streamId = this._requestToStreamMapping.get(requestId); + const response = this._streamMapping.get(streamId!); + if (!streamId) { + throw new Error(`No connection established for request ID: ${String(requestId)}`); + } - response.writeHead(200, headers); - if (responses.length === 1) { - response.end(JSON.stringify(responses[0])); - } else { - response.end(JSON.stringify(responses)); - } - } else { - // End the SSE stream - response.end(); + if (!this._enableJsonResponse) { + // For SSE responses, generate event ID if event store is provided + let eventId: string | undefined; + + if (this._eventStore) { + eventId = await this._eventStore.storeEvent(streamId, message); + } + if (response) { + // Write the event to the response stream + this.writeSSEEvent(response, message, eventId); + } } - // Clean up - for (const id of relatedIds) { - this._requestResponseMap.delete(id); - this._requestToStreamMapping.delete(id); + + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { + this._requestResponseMap.set(requestId, message); + const relatedIds = Array.from(this._requestToStreamMapping.entries()) + .filter(([_, streamId]) => this._streamMapping.get(streamId) === response) + .map(([id]) => id); + + // Check if we have responses for all requests using this connection + const allResponsesReady = relatedIds.every(id => this._requestResponseMap.has(id)); + + if (allResponsesReady) { + if (!response) { + throw new Error(`No connection established for request ID: ${String(requestId)}`); + } + if (this._enableJsonResponse) { + // All responses ready, send as JSON + const headers: Record = { + 'Content-Type': 'application/json' + }; + if (this.sessionId !== undefined) { + headers['mcp-session-id'] = this.sessionId; + } + + const responses = relatedIds.map(id => this._requestResponseMap.get(id)!); + + response.writeHead(200, headers); + if (responses.length === 1) { + response.end(JSON.stringify(responses[0])); + } else { + response.end(JSON.stringify(responses)); + } + } else { + // End the SSE stream + response.end(); + } + // Clean up + for (const id of relatedIds) { + this._requestResponseMap.delete(id); + this._requestToStreamMapping.delete(id); + } + } } - } } - } } - diff --git a/src/server/title.test.ts b/src/server/title.test.ts index 3f64570b8..9606fce44 100644 --- a/src/server/title.test.ts +++ b/src/server/title.test.ts @@ -1,236 +1,217 @@ -import { Server } from "./index.js"; -import { Client } from "../client/index.js"; -import { InMemoryTransport } from "../inMemory.js"; -import { z } from "zod"; -import { McpServer, ResourceTemplate } from "./mcp.js"; - -describe("Title field backwards compatibility", () => { - it("should work with tools that have title", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: {} } - ); - - // Register tool with title - server.registerTool( - "test-tool", - { - title: "Test Tool Display Name", - description: "A test tool", - inputSchema: { - value: z.string() - } - }, - async () => ({ content: [{ type: "text", text: "result" }] }) - ); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.server.connect(serverTransport); - await client.connect(clientTransport); - - const tools = await client.listTools(); - expect(tools.tools).toHaveLength(1); - expect(tools.tools[0].name).toBe("test-tool"); - expect(tools.tools[0].title).toBe("Test Tool Display Name"); - expect(tools.tools[0].description).toBe("A test tool"); - }); - - it("should work with tools without title", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: {} } - ); - - // Register tool without title - server.tool( - "test-tool", - "A test tool", - { value: z.string() }, - async () => ({ content: [{ type: "text", text: "result" }] }) - ); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.server.connect(serverTransport); - await client.connect(clientTransport); - - const tools = await client.listTools(); - expect(tools.tools).toHaveLength(1); - expect(tools.tools[0].name).toBe("test-tool"); - expect(tools.tools[0].title).toBeUndefined(); - expect(tools.tools[0].description).toBe("A test tool"); - }); - - it("should work with prompts that have title using update", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: {} } - ); - - // Register prompt with title by updating after creation - const prompt = server.prompt( - "test-prompt", - "A test prompt", - async () => ({ messages: [{ role: "user", content: { type: "text", text: "test" } }] }) - ); - prompt.update({ title: "Test Prompt Display Name" }); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.server.connect(serverTransport); - await client.connect(clientTransport); - - const prompts = await client.listPrompts(); - expect(prompts.prompts).toHaveLength(1); - expect(prompts.prompts[0].name).toBe("test-prompt"); - expect(prompts.prompts[0].title).toBe("Test Prompt Display Name"); - expect(prompts.prompts[0].description).toBe("A test prompt"); - }); - - it("should work with prompts using registerPrompt", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: {} } - ); - - // Register prompt with title using registerPrompt - server.registerPrompt( - "test-prompt", - { - title: "Test Prompt Display Name", - description: "A test prompt", - argsSchema: { input: z.string() } - }, - async ({ input }) => ({ - messages: [{ - role: "user", - content: { type: "text", text: `test: ${input}` } - }] - }) - ); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.server.connect(serverTransport); - await client.connect(clientTransport); - - const prompts = await client.listPrompts(); - expect(prompts.prompts).toHaveLength(1); - expect(prompts.prompts[0].name).toBe("test-prompt"); - expect(prompts.prompts[0].title).toBe("Test Prompt Display Name"); - expect(prompts.prompts[0].description).toBe("A test prompt"); - expect(prompts.prompts[0].arguments).toHaveLength(1); - }); - - it("should work with resources using registerResource", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: {} } - ); - - // Register resource with title using registerResource - server.registerResource( - "test-resource", - "https://example.com/test", - { - title: "Test Resource Display Name", - description: "A test resource", - mimeType: "text/plain" - }, - async () => ({ - contents: [{ - uri: "https://example.com/test", - text: "test content" - }] - }) - ); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.server.connect(serverTransport); - await client.connect(clientTransport); - - const resources = await client.listResources(); - expect(resources.resources).toHaveLength(1); - expect(resources.resources[0].name).toBe("test-resource"); - expect(resources.resources[0].title).toBe("Test Resource Display Name"); - expect(resources.resources[0].description).toBe("A test resource"); - expect(resources.resources[0].mimeType).toBe("text/plain"); - }); - - it("should work with dynamic resources using registerResource", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer( - { name: "test-server", version: "1.0.0" }, - { capabilities: {} } - ); - - // Register dynamic resource with title using registerResource - server.registerResource( - "user-profile", - new ResourceTemplate("users://{userId}/profile", { list: undefined }), - { - title: "User Profile", - description: "User profile information" - }, - async (uri, { userId }, _extra) => ({ - contents: [{ - uri: uri.href, - text: `Profile data for user ${userId}` - }] - }) - ); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.server.connect(serverTransport); - await client.connect(clientTransport); - - const resourceTemplates = await client.listResourceTemplates(); - expect(resourceTemplates.resourceTemplates).toHaveLength(1); - expect(resourceTemplates.resourceTemplates[0].name).toBe("user-profile"); - expect(resourceTemplates.resourceTemplates[0].title).toBe("User Profile"); - expect(resourceTemplates.resourceTemplates[0].description).toBe("User profile information"); - expect(resourceTemplates.resourceTemplates[0].uriTemplate).toBe("users://{userId}/profile"); - - // Test reading the resource - const readResult = await client.readResource({ uri: "users://123/profile" }); - expect(readResult.contents).toHaveLength(1); - expect(readResult.contents[0].text).toBe("Profile data for user 123"); - }); - - it("should support serverInfo with title", async () => { - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new Server( - { - name: "test-server", - version: "1.0.0", - title: "Test Server Display Name" - }, - { capabilities: {} } - ); - - const client = new Client({ name: "test-client", version: "1.0.0" }); - - await server.connect(serverTransport); - await client.connect(clientTransport); - - const serverInfo = client.getServerVersion(); - expect(serverInfo?.name).toBe("test-server"); - expect(serverInfo?.version).toBe("1.0.0"); - expect(serverInfo?.title).toBe("Test Server Display Name"); - }); -}); \ No newline at end of file +import { Server } from './index.js'; +import { Client } from '../client/index.js'; +import { InMemoryTransport } from '../inMemory.js'; +import { z } from 'zod'; +import { McpServer, ResourceTemplate } from './mcp.js'; + +describe('Title field backwards compatibility', () => { + it('should work with tools that have title', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register tool with title + server.registerTool( + 'test-tool', + { + title: 'Test Tool Display Name', + description: 'A test tool', + inputSchema: { + value: z.string() + } + }, + async () => ({ content: [{ type: 'text', text: 'result' }] }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const tools = await client.listTools(); + expect(tools.tools).toHaveLength(1); + expect(tools.tools[0].name).toBe('test-tool'); + expect(tools.tools[0].title).toBe('Test Tool Display Name'); + expect(tools.tools[0].description).toBe('A test tool'); + }); + + it('should work with tools without title', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register tool without title + server.tool('test-tool', 'A test tool', { value: z.string() }, async () => ({ content: [{ type: 'text', text: 'result' }] })); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const tools = await client.listTools(); + expect(tools.tools).toHaveLength(1); + expect(tools.tools[0].name).toBe('test-tool'); + expect(tools.tools[0].title).toBeUndefined(); + expect(tools.tools[0].description).toBe('A test tool'); + }); + + it('should work with prompts that have title using update', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register prompt with title by updating after creation + const prompt = server.prompt('test-prompt', 'A test prompt', async () => ({ + messages: [{ role: 'user', content: { type: 'text', text: 'test' } }] + })); + prompt.update({ title: 'Test Prompt Display Name' }); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const prompts = await client.listPrompts(); + expect(prompts.prompts).toHaveLength(1); + expect(prompts.prompts[0].name).toBe('test-prompt'); + expect(prompts.prompts[0].title).toBe('Test Prompt Display Name'); + expect(prompts.prompts[0].description).toBe('A test prompt'); + }); + + it('should work with prompts using registerPrompt', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register prompt with title using registerPrompt + server.registerPrompt( + 'test-prompt', + { + title: 'Test Prompt Display Name', + description: 'A test prompt', + argsSchema: { input: z.string() } + }, + async ({ input }) => ({ + messages: [ + { + role: 'user', + content: { type: 'text', text: `test: ${input}` } + } + ] + }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const prompts = await client.listPrompts(); + expect(prompts.prompts).toHaveLength(1); + expect(prompts.prompts[0].name).toBe('test-prompt'); + expect(prompts.prompts[0].title).toBe('Test Prompt Display Name'); + expect(prompts.prompts[0].description).toBe('A test prompt'); + expect(prompts.prompts[0].arguments).toHaveLength(1); + }); + + it('should work with resources using registerResource', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register resource with title using registerResource + server.registerResource( + 'test-resource', + 'https://example.com/test', + { + title: 'Test Resource Display Name', + description: 'A test resource', + mimeType: 'text/plain' + }, + async () => ({ + contents: [ + { + uri: 'https://example.com/test', + text: 'test content' + } + ] + }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const resources = await client.listResources(); + expect(resources.resources).toHaveLength(1); + expect(resources.resources[0].name).toBe('test-resource'); + expect(resources.resources[0].title).toBe('Test Resource Display Name'); + expect(resources.resources[0].description).toBe('A test resource'); + expect(resources.resources[0].mimeType).toBe('text/plain'); + }); + + it('should work with dynamic resources using registerResource', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register dynamic resource with title using registerResource + server.registerResource( + 'user-profile', + new ResourceTemplate('users://{userId}/profile', { list: undefined }), + { + title: 'User Profile', + description: 'User profile information' + }, + async (uri, { userId }, _extra) => ({ + contents: [ + { + uri: uri.href, + text: `Profile data for user ${userId}` + } + ] + }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const resourceTemplates = await client.listResourceTemplates(); + expect(resourceTemplates.resourceTemplates).toHaveLength(1); + expect(resourceTemplates.resourceTemplates[0].name).toBe('user-profile'); + expect(resourceTemplates.resourceTemplates[0].title).toBe('User Profile'); + expect(resourceTemplates.resourceTemplates[0].description).toBe('User profile information'); + expect(resourceTemplates.resourceTemplates[0].uriTemplate).toBe('users://{userId}/profile'); + + // Test reading the resource + const readResult = await client.readResource({ uri: 'users://123/profile' }); + expect(readResult.contents).toHaveLength(1); + expect(readResult.contents[0].text).toBe('Profile data for user 123'); + }); + + it('should support serverInfo with title', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0', + title: 'Test Server Display Name' + }, + { capabilities: {} } + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.connect(serverTransport); + await client.connect(clientTransport); + + const serverInfo = client.getServerVersion(); + expect(serverInfo?.name).toBe('test-server'); + expect(serverInfo?.version).toBe('1.0.0'); + expect(serverInfo?.title).toBe('Test Server Display Name'); + }); +}); diff --git a/src/shared/auth-utils.test.ts b/src/shared/auth-utils.test.ts index c1fa7bdf1..04ba98d74 100644 --- a/src/shared/auth-utils.test.ts +++ b/src/shared/auth-utils.test.ts @@ -1,61 +1,90 @@ import { resourceUrlFromServerUrl, checkResourceAllowed } from './auth-utils.js'; describe('auth-utils', () => { - describe('resourceUrlFromServerUrl', () => { - it('should remove fragments', () => { - expect(resourceUrlFromServerUrl(new URL('https://example.com/path#fragment')).href).toBe('https://example.com/path'); - expect(resourceUrlFromServerUrl(new URL('https://example.com#fragment')).href).toBe('https://example.com/'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1#fragment')).href).toBe('https://example.com/path?query=1'); - }); + describe('resourceUrlFromServerUrl', () => { + it('should remove fragments', () => { + expect(resourceUrlFromServerUrl(new URL('https://example.com/path#fragment')).href).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl(new URL('https://example.com#fragment')).href).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1#fragment')).href).toBe( + 'https://example.com/path?query=1' + ); + }); - it('should return URL unchanged if no fragment', () => { - expect(resourceUrlFromServerUrl(new URL('https://example.com')).href).toBe('https://example.com/'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path')).href).toBe('https://example.com/path'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1')).href).toBe('https://example.com/path?query=1'); - }); + it('should return URL unchanged if no fragment', () => { + expect(resourceUrlFromServerUrl(new URL('https://example.com')).href).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path')).href).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1')).href).toBe('https://example.com/path?query=1'); + }); - it('should keep everything else unchanged', () => { - // Case sensitivity preserved - expect(resourceUrlFromServerUrl(new URL('https://EXAMPLE.COM/PATH')).href).toBe('https://example.com/PATH'); - // Ports preserved - expect(resourceUrlFromServerUrl(new URL('https://example.com:443/path')).href).toBe('https://example.com/path'); - expect(resourceUrlFromServerUrl(new URL('https://example.com:8080/path')).href).toBe('https://example.com:8080/path'); - // Query parameters preserved - expect(resourceUrlFromServerUrl(new URL('https://example.com?foo=bar&baz=qux')).href).toBe('https://example.com/?foo=bar&baz=qux'); - // Trailing slashes preserved - expect(resourceUrlFromServerUrl(new URL('https://example.com/')).href).toBe('https://example.com/'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path/')).href).toBe('https://example.com/path/'); + it('should keep everything else unchanged', () => { + // Case sensitivity preserved + expect(resourceUrlFromServerUrl(new URL('https://EXAMPLE.COM/PATH')).href).toBe('https://example.com/PATH'); + // Ports preserved + expect(resourceUrlFromServerUrl(new URL('https://example.com:443/path')).href).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl(new URL('https://example.com:8080/path')).href).toBe('https://example.com:8080/path'); + // Query parameters preserved + expect(resourceUrlFromServerUrl(new URL('https://example.com?foo=bar&baz=qux')).href).toBe( + 'https://example.com/?foo=bar&baz=qux' + ); + // Trailing slashes preserved + expect(resourceUrlFromServerUrl(new URL('https://example.com/')).href).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path/')).href).toBe('https://example.com/path/'); + }); }); - }); - describe('resourceMatches', () => { - it('should match identical URLs', () => { - expect(checkResourceAllowed({ requestedResource: 'https://example.com/path', configuredResource: 'https://example.com/path' })).toBe(true); - expect(checkResourceAllowed({ requestedResource: 'https://example.com/', configuredResource: 'https://example.com/' })).toBe(true); - }); + describe('resourceMatches', () => { + it('should match identical URLs', () => { + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/path', configuredResource: 'https://example.com/path' }) + ).toBe(true); + expect(checkResourceAllowed({ requestedResource: 'https://example.com/', configuredResource: 'https://example.com/' })).toBe( + true + ); + }); - it('should not match URLs with different paths', () => { - expect(checkResourceAllowed({ requestedResource: 'https://example.com/path1', configuredResource: 'https://example.com/path2' })).toBe(false); - expect(checkResourceAllowed({ requestedResource: 'https://example.com/', configuredResource: 'https://example.com/path' })).toBe(false); - }); + it('should not match URLs with different paths', () => { + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/path1', configuredResource: 'https://example.com/path2' }) + ).toBe(false); + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/', configuredResource: 'https://example.com/path' }) + ).toBe(false); + }); - it('should not match URLs with different domains', () => { - expect(checkResourceAllowed({ requestedResource: 'https://example.com/path', configuredResource: 'https://example.org/path' })).toBe(false); - }); + it('should not match URLs with different domains', () => { + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/path', configuredResource: 'https://example.org/path' }) + ).toBe(false); + }); - it('should not match URLs with different ports', () => { - expect(checkResourceAllowed({ requestedResource: 'https://example.com:8080/path', configuredResource: 'https://example.com/path' })).toBe(false); - }); + it('should not match URLs with different ports', () => { + expect( + checkResourceAllowed({ requestedResource: 'https://example.com:8080/path', configuredResource: 'https://example.com/path' }) + ).toBe(false); + }); - it('should not match URLs where one path is a sub-path of another', () => { - expect(checkResourceAllowed({ requestedResource: 'https://example.com/mcpxxxx', configuredResource: 'https://example.com/mcp' })).toBe(false); - expect(checkResourceAllowed({ requestedResource: 'https://example.com/folder', configuredResource: 'https://example.com/folder/subfolder' })).toBe(false); - expect(checkResourceAllowed({ requestedResource: 'https://example.com/api/v1', configuredResource: 'https://example.com/api' })).toBe(true); - }); + it('should not match URLs where one path is a sub-path of another', () => { + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/mcpxxxx', configuredResource: 'https://example.com/mcp' }) + ).toBe(false); + expect( + checkResourceAllowed({ + requestedResource: 'https://example.com/folder', + configuredResource: 'https://example.com/folder/subfolder' + }) + ).toBe(false); + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/api/v1', configuredResource: 'https://example.com/api' }) + ).toBe(true); + }); - it('should handle trailing slashes vs no trailing slashes', () => { - expect(checkResourceAllowed({ requestedResource: 'https://example.com/mcp/', configuredResource: 'https://example.com/mcp' })).toBe(true); - expect(checkResourceAllowed({ requestedResource: 'https://example.com/folder', configuredResource: 'https://example.com/folder/' })).toBe(false); + it('should handle trailing slashes vs no trailing slashes', () => { + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/mcp/', configuredResource: 'https://example.com/mcp' }) + ).toBe(true); + expect( + checkResourceAllowed({ requestedResource: 'https://example.com/folder', configuredResource: 'https://example.com/folder/' }) + ).toBe(false); + }); }); - }); }); diff --git a/src/shared/auth-utils.ts b/src/shared/auth-utils.ts index 97a77c01d..c9863da43 100644 --- a/src/shared/auth-utils.ts +++ b/src/shared/auth-utils.ts @@ -7,10 +7,10 @@ * RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". * Keeps everything else unchanged (scheme, domain, port, path, query). */ -export function resourceUrlFromServerUrl(url: URL | string ): URL { - const resourceURL = typeof url === "string" ? new URL(url) : new URL(url.href); - resourceURL.hash = ''; // Remove fragment - return resourceURL; +export function resourceUrlFromServerUrl(url: URL | string): URL { + const resourceURL = typeof url === 'string' ? new URL(url) : new URL(url.href); + resourceURL.hash = ''; // Remove fragment + return resourceURL; } /** @@ -22,33 +22,34 @@ export function resourceUrlFromServerUrl(url: URL | string ): URL { * @param configuredResource The resource URL that has been configured * @returns true if the requested resource matches the configured resource, false otherwise */ - export function checkResourceAllowed( - { requestedResource, configuredResource }: { - requestedResource: URL | string; - configuredResource: URL | string - } - ): boolean { - const requested = typeof requestedResource === "string" ? new URL(requestedResource) : new URL(requestedResource.href); - const configured = typeof configuredResource === "string" ? new URL(configuredResource) : new URL(configuredResource.href); +export function checkResourceAllowed({ + requestedResource, + configuredResource +}: { + requestedResource: URL | string; + configuredResource: URL | string; +}): boolean { + const requested = typeof requestedResource === 'string' ? new URL(requestedResource) : new URL(requestedResource.href); + const configured = typeof configuredResource === 'string' ? new URL(configuredResource) : new URL(configuredResource.href); - // Compare the origin (scheme, domain, and port) - if (requested.origin !== configured.origin) { - return false; - } + // Compare the origin (scheme, domain, and port) + if (requested.origin !== configured.origin) { + return false; + } - // Handle cases like requested=/foo and configured=/foo/ - if (requested.pathname.length < configured.pathname.length) { - return false - } + // Handle cases like requested=/foo and configured=/foo/ + if (requested.pathname.length < configured.pathname.length) { + return false; + } - // Check if the requested path starts with the configured path - // Ensure both paths end with / for proper comparison - // This ensures that if we have paths like "/api" and "/api/users", - // we properly detect that "/api/users" is a subpath of "/api" - // By adding a trailing slash if missing, we avoid false positives - // where paths like "/api123" would incorrectly match "/api" - const requestedPath = requested.pathname.endsWith('/') ? requested.pathname : requested.pathname + '/'; - const configuredPath = configured.pathname.endsWith('/') ? configured.pathname : configured.pathname + '/'; + // Check if the requested path starts with the configured path + // Ensure both paths end with / for proper comparison + // This ensures that if we have paths like "/api" and "/api/users", + // we properly detect that "/api/users" is a subpath of "/api" + // By adding a trailing slash if missing, we avoid false positives + // where paths like "/api123" would incorrectly match "/api" + const requestedPath = requested.pathname.endsWith('/') ? requested.pathname : requested.pathname + '/'; + const configuredPath = configured.pathname.endsWith('/') ? configured.pathname : configured.pathname + '/'; - return requestedPath.startsWith(configuredPath); - } + return requestedPath.startsWith(configuredPath); +} diff --git a/src/shared/auth.test.ts b/src/shared/auth.test.ts index c1ed82ba2..41f9dc1a9 100644 --- a/src/shared/auth.test.ts +++ b/src/shared/auth.test.ts @@ -1,116 +1,111 @@ import { describe, it, expect } from '@jest/globals'; -import { - SafeUrlSchema, - OAuthMetadataSchema, - OpenIdProviderMetadataSchema, - OAuthClientMetadataSchema, -} from './auth.js'; +import { SafeUrlSchema, OAuthMetadataSchema, OpenIdProviderMetadataSchema, OAuthClientMetadataSchema } from './auth.js'; describe('SafeUrlSchema', () => { - it('accepts valid HTTPS URLs', () => { - expect(SafeUrlSchema.parse('https://example.com')).toBe('https://example.com'); - expect(SafeUrlSchema.parse('https://auth.example.com/oauth/authorize')).toBe('https://auth.example.com/oauth/authorize'); - }); - - it('accepts valid HTTP URLs', () => { - expect(SafeUrlSchema.parse('http://localhost:3000')).toBe('http://localhost:3000'); - }); - - it('rejects javascript: scheme URLs', () => { - expect(() => SafeUrlSchema.parse('javascript:alert(1)')).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); - expect(() => SafeUrlSchema.parse('JAVASCRIPT:alert(1)')).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); - }); - - it('rejects invalid URLs', () => { - expect(() => SafeUrlSchema.parse('not-a-url')).toThrow(); - expect(() => SafeUrlSchema.parse('')).toThrow(); - }); - - it('works with safeParse', () => { - expect(() => SafeUrlSchema.safeParse('not-a-url')).not.toThrow(); - }); + it('accepts valid HTTPS URLs', () => { + expect(SafeUrlSchema.parse('https://example.com')).toBe('https://example.com'); + expect(SafeUrlSchema.parse('https://auth.example.com/oauth/authorize')).toBe('https://auth.example.com/oauth/authorize'); + }); + + it('accepts valid HTTP URLs', () => { + expect(SafeUrlSchema.parse('http://localhost:3000')).toBe('http://localhost:3000'); + }); + + it('rejects javascript: scheme URLs', () => { + expect(() => SafeUrlSchema.parse('javascript:alert(1)')).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + expect(() => SafeUrlSchema.parse('JAVASCRIPT:alert(1)')).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); + + it('rejects invalid URLs', () => { + expect(() => SafeUrlSchema.parse('not-a-url')).toThrow(); + expect(() => SafeUrlSchema.parse('')).toThrow(); + }); + + it('works with safeParse', () => { + expect(() => SafeUrlSchema.safeParse('not-a-url')).not.toThrow(); + }); }); describe('OAuthMetadataSchema', () => { - it('validates complete OAuth metadata', () => { - const metadata = { - issuer: 'https://auth.example.com', - authorization_endpoint: 'https://auth.example.com/oauth/authorize', - token_endpoint: 'https://auth.example.com/oauth/token', - response_types_supported: ['code'], - scopes_supported: ['read', 'write'], - }; - - expect(() => OAuthMetadataSchema.parse(metadata)).not.toThrow(); - }); - - it('rejects metadata with javascript: URLs', () => { - const metadata = { - issuer: 'https://auth.example.com', - authorization_endpoint: 'javascript:alert(1)', - token_endpoint: 'https://auth.example.com/oauth/token', - response_types_supported: ['code'], - }; - - expect(() => OAuthMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); - }); - - it('requires mandatory fields', () => { - const incompleteMetadata = { - issuer: 'https://auth.example.com', - }; - - expect(() => OAuthMetadataSchema.parse(incompleteMetadata)).toThrow(); - }); + it('validates complete OAuth metadata', () => { + const metadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/oauth/authorize', + token_endpoint: 'https://auth.example.com/oauth/token', + response_types_supported: ['code'], + scopes_supported: ['read', 'write'] + }; + + expect(() => OAuthMetadataSchema.parse(metadata)).not.toThrow(); + }); + + it('rejects metadata with javascript: URLs', () => { + const metadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'javascript:alert(1)', + token_endpoint: 'https://auth.example.com/oauth/token', + response_types_supported: ['code'] + }; + + expect(() => OAuthMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); + + it('requires mandatory fields', () => { + const incompleteMetadata = { + issuer: 'https://auth.example.com' + }; + + expect(() => OAuthMetadataSchema.parse(incompleteMetadata)).toThrow(); + }); }); describe('OpenIdProviderMetadataSchema', () => { - it('validates complete OpenID Provider metadata', () => { - const metadata = { - issuer: 'https://auth.example.com', - authorization_endpoint: 'https://auth.example.com/oauth/authorize', - token_endpoint: 'https://auth.example.com/oauth/token', - jwks_uri: 'https://auth.example.com/.well-known/jwks.json', - response_types_supported: ['code'], - subject_types_supported: ['public'], - id_token_signing_alg_values_supported: ['RS256'], - }; - - expect(() => OpenIdProviderMetadataSchema.parse(metadata)).not.toThrow(); - }); - - it('rejects metadata with javascript: in jwks_uri', () => { - const metadata = { - issuer: 'https://auth.example.com', - authorization_endpoint: 'https://auth.example.com/oauth/authorize', - token_endpoint: 'https://auth.example.com/oauth/token', - jwks_uri: 'javascript:alert(1)', - response_types_supported: ['code'], - subject_types_supported: ['public'], - id_token_signing_alg_values_supported: ['RS256'], - }; - - expect(() => OpenIdProviderMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); - }); + it('validates complete OpenID Provider metadata', () => { + const metadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/oauth/authorize', + token_endpoint: 'https://auth.example.com/oauth/token', + jwks_uri: 'https://auth.example.com/.well-known/jwks.json', + response_types_supported: ['code'], + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'] + }; + + expect(() => OpenIdProviderMetadataSchema.parse(metadata)).not.toThrow(); + }); + + it('rejects metadata with javascript: in jwks_uri', () => { + const metadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/oauth/authorize', + token_endpoint: 'https://auth.example.com/oauth/token', + jwks_uri: 'javascript:alert(1)', + response_types_supported: ['code'], + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'] + }; + + expect(() => OpenIdProviderMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); }); describe('OAuthClientMetadataSchema', () => { - it('validates client metadata with safe URLs', () => { - const metadata = { - redirect_uris: ['https://app.example.com/callback'], - client_name: 'Test App', - client_uri: 'https://app.example.com', - }; - - expect(() => OAuthClientMetadataSchema.parse(metadata)).not.toThrow(); - }); - - it('rejects client metadata with javascript: redirect URIs', () => { - const metadata = { - redirect_uris: ['javascript:alert(1)'], - client_name: 'Test App', - }; - - expect(() => OAuthClientMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); - }); + it('validates client metadata with safe URLs', () => { + const metadata = { + redirect_uris: ['https://app.example.com/callback'], + client_name: 'Test App', + client_uri: 'https://app.example.com' + }; + + expect(() => OAuthClientMetadataSchema.parse(metadata)).not.toThrow(); + }); + + it('rejects client metadata with javascript: redirect URIs', () => { + const metadata = { + redirect_uris: ['javascript:alert(1)'], + client_name: 'Test App' + }; + + expect(() => OAuthClientMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); }); diff --git a/src/shared/auth.ts b/src/shared/auth.ts index 886eb1084..0e079646b 100644 --- a/src/shared/auth.ts +++ b/src/shared/auth.ts @@ -1,201 +1,191 @@ -import { z } from "zod"; +import { z } from 'zod'; /** * Reusable URL validation that disallows javascript: scheme */ -export const SafeUrlSchema = z.string().url() - .superRefine((val, ctx) => { - if (!URL.canParse(val)) { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: "URL must be parseable", - fatal: true, - }); - - return z.NEVER; - } - }).refine( - (url) => { - const u = new URL(url); - return u.protocol !== 'javascript:' && u.protocol !== 'data:' && u.protocol !== 'vbscript:'; - }, - { message: "URL cannot use javascript:, data:, or vbscript: scheme" } -); - +export const SafeUrlSchema = z + .string() + .url() + .superRefine((val, ctx) => { + if (!URL.canParse(val)) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'URL must be parseable', + fatal: true + }); + + return z.NEVER; + } + }) + .refine( + url => { + const u = new URL(url); + return u.protocol !== 'javascript:' && u.protocol !== 'data:' && u.protocol !== 'vbscript:'; + }, + { message: 'URL cannot use javascript:, data:, or vbscript: scheme' } + ); /** * RFC 9728 OAuth Protected Resource Metadata */ export const OAuthProtectedResourceMetadataSchema = z - .object({ - resource: z.string().url(), - authorization_servers: z.array(SafeUrlSchema).optional(), - jwks_uri: z.string().url().optional(), - scopes_supported: z.array(z.string()).optional(), - bearer_methods_supported: z.array(z.string()).optional(), - resource_signing_alg_values_supported: z.array(z.string()).optional(), - resource_name: z.string().optional(), - resource_documentation: z.string().optional(), - resource_policy_uri: z.string().url().optional(), - resource_tos_uri: z.string().url().optional(), - tls_client_certificate_bound_access_tokens: z.boolean().optional(), - authorization_details_types_supported: z.array(z.string()).optional(), - dpop_signing_alg_values_supported: z.array(z.string()).optional(), - dpop_bound_access_tokens_required: z.boolean().optional(), - }) - .passthrough(); + .object({ + resource: z.string().url(), + authorization_servers: z.array(SafeUrlSchema).optional(), + jwks_uri: z.string().url().optional(), + scopes_supported: z.array(z.string()).optional(), + bearer_methods_supported: z.array(z.string()).optional(), + resource_signing_alg_values_supported: z.array(z.string()).optional(), + resource_name: z.string().optional(), + resource_documentation: z.string().optional(), + resource_policy_uri: z.string().url().optional(), + resource_tos_uri: z.string().url().optional(), + tls_client_certificate_bound_access_tokens: z.boolean().optional(), + authorization_details_types_supported: z.array(z.string()).optional(), + dpop_signing_alg_values_supported: z.array(z.string()).optional(), + dpop_bound_access_tokens_required: z.boolean().optional() + }) + .passthrough(); /** * RFC 8414 OAuth 2.0 Authorization Server Metadata */ export const OAuthMetadataSchema = z - .object({ - issuer: z.string(), - authorization_endpoint: SafeUrlSchema, - token_endpoint: SafeUrlSchema, - registration_endpoint: SafeUrlSchema.optional(), - scopes_supported: z.array(z.string()).optional(), - response_types_supported: z.array(z.string()), - response_modes_supported: z.array(z.string()).optional(), - grant_types_supported: z.array(z.string()).optional(), - token_endpoint_auth_methods_supported: z.array(z.string()).optional(), - token_endpoint_auth_signing_alg_values_supported: z - .array(z.string()) - .optional(), - service_documentation: SafeUrlSchema.optional(), - revocation_endpoint: SafeUrlSchema.optional(), - revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), - revocation_endpoint_auth_signing_alg_values_supported: z - .array(z.string()) - .optional(), - introspection_endpoint: z.string().optional(), - introspection_endpoint_auth_methods_supported: z - .array(z.string()) - .optional(), - introspection_endpoint_auth_signing_alg_values_supported: z - .array(z.string()) - .optional(), - code_challenge_methods_supported: z.array(z.string()).optional(), - }) - .passthrough(); + .object({ + issuer: z.string(), + authorization_endpoint: SafeUrlSchema, + token_endpoint: SafeUrlSchema, + registration_endpoint: SafeUrlSchema.optional(), + scopes_supported: z.array(z.string()).optional(), + response_types_supported: z.array(z.string()), + response_modes_supported: z.array(z.string()).optional(), + grant_types_supported: z.array(z.string()).optional(), + token_endpoint_auth_methods_supported: z.array(z.string()).optional(), + token_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + service_documentation: SafeUrlSchema.optional(), + revocation_endpoint: SafeUrlSchema.optional(), + revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), + revocation_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + introspection_endpoint: z.string().optional(), + introspection_endpoint_auth_methods_supported: z.array(z.string()).optional(), + introspection_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + code_challenge_methods_supported: z.array(z.string()).optional() + }) + .passthrough(); /** * OpenID Connect Discovery 1.0 Provider Metadata * see: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata */ export const OpenIdProviderMetadataSchema = z - .object({ - issuer: z.string(), - authorization_endpoint: SafeUrlSchema, - token_endpoint: SafeUrlSchema, - userinfo_endpoint: SafeUrlSchema.optional(), - jwks_uri: SafeUrlSchema, - registration_endpoint: SafeUrlSchema.optional(), - scopes_supported: z.array(z.string()).optional(), - response_types_supported: z.array(z.string()), - response_modes_supported: z.array(z.string()).optional(), - grant_types_supported: z.array(z.string()).optional(), - acr_values_supported: z.array(z.string()).optional(), - subject_types_supported: z.array(z.string()), - id_token_signing_alg_values_supported: z.array(z.string()), - id_token_encryption_alg_values_supported: z.array(z.string()).optional(), - id_token_encryption_enc_values_supported: z.array(z.string()).optional(), - userinfo_signing_alg_values_supported: z.array(z.string()).optional(), - userinfo_encryption_alg_values_supported: z.array(z.string()).optional(), - userinfo_encryption_enc_values_supported: z.array(z.string()).optional(), - request_object_signing_alg_values_supported: z.array(z.string()).optional(), - request_object_encryption_alg_values_supported: z - .array(z.string()) - .optional(), - request_object_encryption_enc_values_supported: z - .array(z.string()) - .optional(), - token_endpoint_auth_methods_supported: z.array(z.string()).optional(), - token_endpoint_auth_signing_alg_values_supported: z - .array(z.string()) - .optional(), - display_values_supported: z.array(z.string()).optional(), - claim_types_supported: z.array(z.string()).optional(), - claims_supported: z.array(z.string()).optional(), - service_documentation: z.string().optional(), - claims_locales_supported: z.array(z.string()).optional(), - ui_locales_supported: z.array(z.string()).optional(), - claims_parameter_supported: z.boolean().optional(), - request_parameter_supported: z.boolean().optional(), - request_uri_parameter_supported: z.boolean().optional(), - require_request_uri_registration: z.boolean().optional(), - op_policy_uri: SafeUrlSchema.optional(), - op_tos_uri: SafeUrlSchema.optional(), - }) - .passthrough(); + .object({ + issuer: z.string(), + authorization_endpoint: SafeUrlSchema, + token_endpoint: SafeUrlSchema, + userinfo_endpoint: SafeUrlSchema.optional(), + jwks_uri: SafeUrlSchema, + registration_endpoint: SafeUrlSchema.optional(), + scopes_supported: z.array(z.string()).optional(), + response_types_supported: z.array(z.string()), + response_modes_supported: z.array(z.string()).optional(), + grant_types_supported: z.array(z.string()).optional(), + acr_values_supported: z.array(z.string()).optional(), + subject_types_supported: z.array(z.string()), + id_token_signing_alg_values_supported: z.array(z.string()), + id_token_encryption_alg_values_supported: z.array(z.string()).optional(), + id_token_encryption_enc_values_supported: z.array(z.string()).optional(), + userinfo_signing_alg_values_supported: z.array(z.string()).optional(), + userinfo_encryption_alg_values_supported: z.array(z.string()).optional(), + userinfo_encryption_enc_values_supported: z.array(z.string()).optional(), + request_object_signing_alg_values_supported: z.array(z.string()).optional(), + request_object_encryption_alg_values_supported: z.array(z.string()).optional(), + request_object_encryption_enc_values_supported: z.array(z.string()).optional(), + token_endpoint_auth_methods_supported: z.array(z.string()).optional(), + token_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + display_values_supported: z.array(z.string()).optional(), + claim_types_supported: z.array(z.string()).optional(), + claims_supported: z.array(z.string()).optional(), + service_documentation: z.string().optional(), + claims_locales_supported: z.array(z.string()).optional(), + ui_locales_supported: z.array(z.string()).optional(), + claims_parameter_supported: z.boolean().optional(), + request_parameter_supported: z.boolean().optional(), + request_uri_parameter_supported: z.boolean().optional(), + require_request_uri_registration: z.boolean().optional(), + op_policy_uri: SafeUrlSchema.optional(), + op_tos_uri: SafeUrlSchema.optional() + }) + .passthrough(); /** * OpenID Connect Discovery metadata that may include OAuth 2.0 fields * This schema represents the real-world scenario where OIDC providers * return a mix of OpenID Connect and OAuth 2.0 metadata fields */ -export const OpenIdProviderDiscoveryMetadataSchema = - OpenIdProviderMetadataSchema.merge( +export const OpenIdProviderDiscoveryMetadataSchema = OpenIdProviderMetadataSchema.merge( OAuthMetadataSchema.pick({ - code_challenge_methods_supported: true, + code_challenge_methods_supported: true }) - ); +); /** * OAuth 2.1 token response */ export const OAuthTokensSchema = z - .object({ - access_token: z.string(), - id_token: z.string().optional(), // Optional for OAuth 2.1, but necessary in OpenID Connect - token_type: z.string(), - expires_in: z.number().optional(), - scope: z.string().optional(), - refresh_token: z.string().optional(), - }) - .strip(); + .object({ + access_token: z.string(), + id_token: z.string().optional(), // Optional for OAuth 2.1, but necessary in OpenID Connect + token_type: z.string(), + expires_in: z.number().optional(), + scope: z.string().optional(), + refresh_token: z.string().optional() + }) + .strip(); /** * OAuth 2.1 error response */ -export const OAuthErrorResponseSchema = z - .object({ +export const OAuthErrorResponseSchema = z.object({ error: z.string(), error_description: z.string().optional(), - error_uri: z.string().optional(), - }); + error_uri: z.string().optional() +}); /** * RFC 7591 OAuth 2.0 Dynamic Client Registration metadata */ -export const OAuthClientMetadataSchema = z.object({ - redirect_uris: z.array(SafeUrlSchema), - token_endpoint_auth_method: z.string().optional(), - grant_types: z.array(z.string()).optional(), - response_types: z.array(z.string()).optional(), - client_name: z.string().optional(), - client_uri: SafeUrlSchema.optional(), - logo_uri: SafeUrlSchema.optional(), - scope: z.string().optional(), - contacts: z.array(z.string()).optional(), - tos_uri: SafeUrlSchema.optional(), - policy_uri: z.string().optional(), - jwks_uri: SafeUrlSchema.optional(), - jwks: z.any().optional(), - software_id: z.string().optional(), - software_version: z.string().optional(), - software_statement: z.string().optional(), -}).strip(); +export const OAuthClientMetadataSchema = z + .object({ + redirect_uris: z.array(SafeUrlSchema), + token_endpoint_auth_method: z.string().optional(), + grant_types: z.array(z.string()).optional(), + response_types: z.array(z.string()).optional(), + client_name: z.string().optional(), + client_uri: SafeUrlSchema.optional(), + logo_uri: SafeUrlSchema.optional(), + scope: z.string().optional(), + contacts: z.array(z.string()).optional(), + tos_uri: SafeUrlSchema.optional(), + policy_uri: z.string().optional(), + jwks_uri: SafeUrlSchema.optional(), + jwks: z.any().optional(), + software_id: z.string().optional(), + software_version: z.string().optional(), + software_statement: z.string().optional() + }) + .strip(); /** * RFC 7591 OAuth 2.0 Dynamic Client Registration client information */ -export const OAuthClientInformationSchema = z.object({ - client_id: z.string(), - client_secret: z.string().optional(), - client_id_issued_at: z.number().optional(), - client_secret_expires_at: z.number().optional(), -}).strip(); +export const OAuthClientInformationSchema = z + .object({ + client_id: z.string(), + client_secret: z.string().optional(), + client_id_issued_at: z.number().optional(), + client_secret_expires_at: z.number().optional() + }) + .strip(); /** * RFC 7591 OAuth 2.0 Dynamic Client Registration full response (client information plus metadata) @@ -205,18 +195,22 @@ export const OAuthClientInformationFullSchema = OAuthClientMetadataSchema.merge( /** * RFC 7591 OAuth 2.0 Dynamic Client Registration error response */ -export const OAuthClientRegistrationErrorSchema = z.object({ - error: z.string(), - error_description: z.string().optional(), -}).strip(); +export const OAuthClientRegistrationErrorSchema = z + .object({ + error: z.string(), + error_description: z.string().optional() + }) + .strip(); /** * RFC 7009 OAuth 2.0 Token Revocation request */ -export const OAuthTokenRevocationRequestSchema = z.object({ - token: z.string(), - token_type_hint: z.string().optional(), -}).strip(); +export const OAuthTokenRevocationRequestSchema = z + .object({ + token: z.string(), + token_type_hint: z.string().optional() + }) + .strip(); export type OAuthMetadata = z.infer; export type OpenIdProviderMetadata = z.infer; diff --git a/src/shared/metadataUtils.ts b/src/shared/metadataUtils.ts index 0119a6691..d58729298 100644 --- a/src/shared/metadataUtils.ts +++ b/src/shared/metadataUtils.ts @@ -1,4 +1,4 @@ -import { BaseMetadata } from "../types.js"; +import { BaseMetadata } from '../types.js'; /** * Utilities for working with BaseMetadata objects. @@ -11,19 +11,19 @@ import { BaseMetadata } from "../types.js"; * This implements the spec requirement: "if no title is provided, name should be used for display purposes" */ export function getDisplayName(metadata: BaseMetadata): string { - // First check for title (not undefined and not empty string) - if (metadata.title !== undefined && metadata.title !== '') { - return metadata.title; - } + // First check for title (not undefined and not empty string) + if (metadata.title !== undefined && metadata.title !== '') { + return metadata.title; + } - // Then check for annotations.title (only present in Tool objects) - if ('annotations' in metadata) { - const metadataWithAnnotations = metadata as BaseMetadata & { annotations?: { title?: string } }; - if (metadataWithAnnotations.annotations?.title) { - return metadataWithAnnotations.annotations.title; + // Then check for annotations.title (only present in Tool objects) + if ('annotations' in metadata) { + const metadataWithAnnotations = metadata as BaseMetadata & { annotations?: { title?: string } }; + if (metadataWithAnnotations.annotations?.title) { + return metadataWithAnnotations.annotations.title; + } } - } - // Finally fall back to name - return metadata.name; + // Finally fall back to name + return metadata.name; } diff --git a/src/shared/protocol-transport-handling.test.ts b/src/shared/protocol-transport-handling.test.ts index 3baa9b638..375a0ee78 100644 --- a/src/shared/protocol-transport-handling.test.ts +++ b/src/shared/protocol-transport-handling.test.ts @@ -1,189 +1,187 @@ -import { describe, expect, test, beforeEach } from "@jest/globals"; -import { Protocol } from "./protocol.js"; -import { Transport } from "./transport.js"; -import { Request, Notification, Result, JSONRPCMessage } from "../types.js"; -import { z } from "zod"; +import { describe, expect, test, beforeEach } from '@jest/globals'; +import { Protocol } from './protocol.js'; +import { Transport } from './transport.js'; +import { Request, Notification, Result, JSONRPCMessage } from '../types.js'; +import { z } from 'zod'; // Mock Transport class class MockTransport implements Transport { - id: string; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: unknown) => void; - sentMessages: JSONRPCMessage[] = []; - - constructor(id: string) { - this.id = id; - } - - async start(): Promise {} - - async close(): Promise { - this.onclose?.(); - } - - async send(message: JSONRPCMessage): Promise { - this.sentMessages.push(message); - } + id: string; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: unknown) => void; + sentMessages: JSONRPCMessage[] = []; + + constructor(id: string) { + this.id = id; + } + + async start(): Promise {} + + async close(): Promise { + this.onclose?.(); + } + + async send(message: JSONRPCMessage): Promise { + this.sentMessages.push(message); + } } -describe("Protocol transport handling bug", () => { - let protocol: Protocol; - let transportA: MockTransport; - let transportB: MockTransport; - - beforeEach(() => { - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })(); - - transportA = new MockTransport("A"); - transportB = new MockTransport("B"); - }); - - test("should send response to the correct transport when multiple clients are connected", async () => { - // Set up a request handler that simulates processing time - let resolveHandler: (value: Result) => void; - const handlerPromise = new Promise((resolve) => { - resolveHandler = resolve; - }); +describe('Protocol transport handling bug', () => { + let protocol: Protocol; + let transportA: MockTransport; + let transportB: MockTransport; - const TestRequestSchema = z.object({ - method: z.literal("test/method"), - params: z.object({ - from: z.string() - }).optional() - }); + beforeEach(() => { + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })(); - protocol.setRequestHandler( - TestRequestSchema, - async (request) => { - console.log(`Processing request from ${request.params?.from}`); - return handlerPromise; - } - ); - - // Client A connects and sends a request - await protocol.connect(transportA); - - const requestFromA = { - jsonrpc: "2.0" as const, - method: "test/method", - params: { from: "clientA" }, - id: 1 - }; - - // Simulate client A sending a request - transportA.onmessage?.(requestFromA); - - // While A's request is being processed, client B connects - // This overwrites the transport reference in the protocol - await protocol.connect(transportB); - - const requestFromB = { - jsonrpc: "2.0" as const, - method: "test/method", - params: { from: "clientB" }, - id: 2 - }; - - // Client B sends its own request - transportB.onmessage?.(requestFromB); - - // Now complete A's request - resolveHandler!({ data: "responseForA" } as Result); - - // Wait for async operations to complete - await new Promise(resolve => setTimeout(resolve, 10)); - - // Check where the responses went - console.log("Transport A received:", transportA.sentMessages); - console.log("Transport B received:", transportB.sentMessages); - - // FIXED: Each transport now receives its own response - - // Transport A should receive response for request ID 1 - expect(transportA.sentMessages.length).toBe(1); - expect(transportA.sentMessages[0]).toMatchObject({ - jsonrpc: "2.0", - id: 1, - result: { data: "responseForA" } - }); - - // Transport B should only receive its own response (when implemented) - expect(transportB.sentMessages.length).toBe(1); - expect(transportB.sentMessages[0]).toMatchObject({ - jsonrpc: "2.0", - id: 2, - result: { data: "responseForA" } // Same handler result in this test - }); - }); - - test("demonstrates the timing issue with multiple rapid connections", async () => { - const delays: number[] = []; - const results: { transport: string; response: JSONRPCMessage[] }[] = []; - - const DelayedRequestSchema = z.object({ - method: z.literal("test/delayed"), - params: z.object({ - delay: z.number(), - client: z.string() - }).optional() + transportA = new MockTransport('A'); + transportB = new MockTransport('B'); }); - // Set up handler with variable delay - protocol.setRequestHandler( - DelayedRequestSchema, - async (request, extra) => { - const delay = request.params?.delay || 0; - delays.push(delay); - - await new Promise(resolve => setTimeout(resolve, delay)); - - return { - processedBy: `handler-${extra.requestId}`, - delay: delay - } as Result; - } - ); - - // Rapid succession of connections and requests - await protocol.connect(transportA); - transportA.onmessage?.({ - jsonrpc: "2.0" as const, - method: "test/delayed", - params: { delay: 50, client: "A" }, - id: 1 + test('should send response to the correct transport when multiple clients are connected', async () => { + // Set up a request handler that simulates processing time + let resolveHandler: (value: Result) => void; + const handlerPromise = new Promise(resolve => { + resolveHandler = resolve; + }); + + const TestRequestSchema = z.object({ + method: z.literal('test/method'), + params: z + .object({ + from: z.string() + }) + .optional() + }); + + protocol.setRequestHandler(TestRequestSchema, async request => { + console.log(`Processing request from ${request.params?.from}`); + return handlerPromise; + }); + + // Client A connects and sends a request + await protocol.connect(transportA); + + const requestFromA = { + jsonrpc: '2.0' as const, + method: 'test/method', + params: { from: 'clientA' }, + id: 1 + }; + + // Simulate client A sending a request + transportA.onmessage?.(requestFromA); + + // While A's request is being processed, client B connects + // This overwrites the transport reference in the protocol + await protocol.connect(transportB); + + const requestFromB = { + jsonrpc: '2.0' as const, + method: 'test/method', + params: { from: 'clientB' }, + id: 2 + }; + + // Client B sends its own request + transportB.onmessage?.(requestFromB); + + // Now complete A's request + resolveHandler!({ data: 'responseForA' } as Result); + + // Wait for async operations to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Check where the responses went + console.log('Transport A received:', transportA.sentMessages); + console.log('Transport B received:', transportB.sentMessages); + + // FIXED: Each transport now receives its own response + + // Transport A should receive response for request ID 1 + expect(transportA.sentMessages.length).toBe(1); + expect(transportA.sentMessages[0]).toMatchObject({ + jsonrpc: '2.0', + id: 1, + result: { data: 'responseForA' } + }); + + // Transport B should only receive its own response (when implemented) + expect(transportB.sentMessages.length).toBe(1); + expect(transportB.sentMessages[0]).toMatchObject({ + jsonrpc: '2.0', + id: 2, + result: { data: 'responseForA' } // Same handler result in this test + }); }); - // Connect B while A is processing - setTimeout(async () => { - await protocol.connect(transportB); - transportB.onmessage?.({ - jsonrpc: "2.0" as const, - method: "test/delayed", - params: { delay: 10, client: "B" }, - id: 2 - }); - }, 10); - - // Wait for all processing - await new Promise(resolve => setTimeout(resolve, 100)); - - // Collect results - if (transportA.sentMessages.length > 0) { - results.push({ transport: "A", response: transportA.sentMessages }); - } - if (transportB.sentMessages.length > 0) { - results.push({ transport: "B", response: transportB.sentMessages }); - } - - console.log("Timing test results:", results); - - // FIXED: Each transport receives its own responses - expect(transportA.sentMessages.length).toBe(1); - expect(transportB.sentMessages.length).toBe(1); - }); -}); \ No newline at end of file + test('demonstrates the timing issue with multiple rapid connections', async () => { + const delays: number[] = []; + const results: { transport: string; response: JSONRPCMessage[] }[] = []; + + const DelayedRequestSchema = z.object({ + method: z.literal('test/delayed'), + params: z + .object({ + delay: z.number(), + client: z.string() + }) + .optional() + }); + + // Set up handler with variable delay + protocol.setRequestHandler(DelayedRequestSchema, async (request, extra) => { + const delay = request.params?.delay || 0; + delays.push(delay); + + await new Promise(resolve => setTimeout(resolve, delay)); + + return { + processedBy: `handler-${extra.requestId}`, + delay: delay + } as Result; + }); + + // Rapid succession of connections and requests + await protocol.connect(transportA); + transportA.onmessage?.({ + jsonrpc: '2.0' as const, + method: 'test/delayed', + params: { delay: 50, client: 'A' }, + id: 1 + }); + + // Connect B while A is processing + setTimeout(async () => { + await protocol.connect(transportB); + transportB.onmessage?.({ + jsonrpc: '2.0' as const, + method: 'test/delayed', + params: { delay: 10, client: 'B' }, + id: 2 + }); + }, 10); + + // Wait for all processing + await new Promise(resolve => setTimeout(resolve, 100)); + + // Collect results + if (transportA.sentMessages.length > 0) { + results.push({ transport: 'A', response: transportA.sentMessages }); + } + if (transportB.sentMessages.length > 0) { + results.push({ transport: 'B', response: transportB.sentMessages }); + } + + console.log('Timing test results:', results); + + // FIXED: Each transport receives its own responses + expect(transportA.sentMessages.length).toBe(1); + expect(transportB.sentMessages.length).toBe(1); + }); +}); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index f4e74c8bb..1c098eafa 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,740 +1,744 @@ -import { ZodType, z } from "zod"; -import { - ClientCapabilities, - ErrorCode, - McpError, - Notification, - Request, - Result, - ServerCapabilities, -} from "../types.js"; -import { Protocol, mergeCapabilities } from "./protocol.js"; -import { Transport } from "./transport.js"; +import { ZodType, z } from 'zod'; +import { ClientCapabilities, ErrorCode, McpError, Notification, Request, Result, ServerCapabilities } from '../types.js'; +import { Protocol, mergeCapabilities } from './protocol.js'; +import { Transport } from './transport.js'; // Mock Transport class class MockTransport implements Transport { - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: unknown) => void; - - async start(): Promise {} - async close(): Promise { - this.onclose?.(); - } - async send(_message: unknown): Promise {} -} + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: unknown) => void; -describe("protocol tests", () => { - let protocol: Protocol; - let transport: MockTransport; - let sendSpy: jest.SpyInstance; - - beforeEach(() => { - transport = new MockTransport(); - sendSpy = jest.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })(); - }); - - test("should throw a timeout error if the request exceeds the timeout", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - try { - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - await protocol.request(request, mockSchema, { - timeout: 0, - }); - } catch (error) { - expect(error).toBeInstanceOf(McpError); - if (error instanceof McpError) { - expect(error.code).toBe(ErrorCode.RequestTimeout); - } + async start(): Promise {} + async close(): Promise { + this.onclose?.(); } - }); - - test("should invoke onclose when the connection is closed", async () => { - const oncloseMock = jest.fn(); - protocol.onclose = oncloseMock; - await protocol.connect(transport); - await transport.close(); - expect(oncloseMock).toHaveBeenCalled(); - }); - - test("should not overwrite existing hooks when connecting transports", async () => { - const oncloseMock = jest.fn(); - const onerrorMock = jest.fn(); - const onmessageMock = jest.fn(); - transport.onclose = oncloseMock; - transport.onerror = onerrorMock; - transport.onmessage = onmessageMock; - await protocol.connect(transport); - transport.onclose(); - transport.onerror(new Error()); - transport.onmessage(""); - expect(oncloseMock).toHaveBeenCalled(); - expect(onerrorMock).toHaveBeenCalled(); - expect(onmessageMock).toHaveBeenCalled(); - }); - - describe("_meta preservation with onprogress", () => { - test("should preserve existing _meta when adding progressToken", async () => { - await protocol.connect(transport); - const request = { - method: "example", - params: { - data: "test", - _meta: { - customField: "customValue", - anotherField: 123 - } - } - }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - - protocol.request(request, mockSchema, { - onprogress: onProgressMock, - }); - - expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ - method: "example", - params: { - data: "test", - _meta: { - customField: "customValue", - anotherField: 123, - progressToken: expect.any(Number) - } - }, - jsonrpc: "2.0", - id: expect.any(Number) - }), expect.any(Object)); - }); + async send(_message: unknown): Promise {} +} - test("should create _meta with progressToken when no _meta exists", async () => { - await protocol.connect(transport); - const request = { - method: "example", - params: { - data: "test" - } - }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - - protocol.request(request, mockSchema, { - onprogress: onProgressMock, - }); - - expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ - method: "example", - params: { - data: "test", - _meta: { - progressToken: expect.any(Number) - } - }, - jsonrpc: "2.0", - id: expect.any(Number) - }), expect.any(Object)); +describe('protocol tests', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: jest.SpyInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = jest.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })(); }); - test("should not modify _meta when onprogress is not provided", async () => { - await protocol.connect(transport); - const request = { - method: "example", - params: { - data: "test", - _meta: { - customField: "customValue" - } + test('should throw a timeout error if the request exceeds the timeout', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + try { + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + await protocol.request(request, mockSchema, { + timeout: 0 + }); + } catch (error) { + expect(error).toBeInstanceOf(McpError); + if (error instanceof McpError) { + expect(error.code).toBe(ErrorCode.RequestTimeout); + } } - }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - - protocol.request(request, mockSchema); - - expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ - method: "example", - params: { - data: "test", - _meta: { - customField: "customValue" - } - }, - jsonrpc: "2.0", - id: expect.any(Number) - }), expect.any(Object)); }); - test("should handle params being undefined with onprogress", async () => { - await protocol.connect(transport); - const request = { - method: "example" - }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - - protocol.request(request, mockSchema, { - onprogress: onProgressMock, - }); - - expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ - method: "example", - params: { - _meta: { - progressToken: expect.any(Number) - } - }, - jsonrpc: "2.0", - id: expect.any(Number) - }), expect.any(Object)); + test('should invoke onclose when the connection is closed', async () => { + const oncloseMock = jest.fn(); + protocol.onclose = oncloseMock; + await protocol.connect(transport); + await transport.close(); + expect(oncloseMock).toHaveBeenCalled(); }); - }); - describe("progress notification timeout behavior", () => { - beforeEach(() => { - jest.useFakeTimers(); - }); - afterEach(() => { - jest.useRealTimers(); + test('should not overwrite existing hooks when connecting transports', async () => { + const oncloseMock = jest.fn(); + const onerrorMock = jest.fn(); + const onmessageMock = jest.fn(); + transport.onclose = oncloseMock; + transport.onerror = onerrorMock; + transport.onmessage = onmessageMock; + await protocol.connect(transport); + transport.onclose(); + transport.onerror(new Error()); + transport.onmessage(''); + expect(oncloseMock).toHaveBeenCalled(); + expect(onerrorMock).toHaveBeenCalled(); + expect(onmessageMock).toHaveBeenCalled(); }); - test("should not reset timeout when resetTimeoutOnProgress is false", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - resetTimeoutOnProgress: false, - onprogress: onProgressMock, - }); - - jest.advanceTimersByTime(800); - - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 50, - total: 100, - }, + describe('_meta preservation with onprogress', () => { + test('should preserve existing _meta when adding progressToken', async () => { + await protocol.connect(transport); + const request = { + method: 'example', + params: { + data: 'test', + _meta: { + customField: 'customValue', + anotherField: 123 + } + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'example', + params: { + data: 'test', + _meta: { + customField: 'customValue', + anotherField: 123, + progressToken: expect.any(Number) + } + }, + jsonrpc: '2.0', + id: expect.any(Number) + }), + expect.any(Object) + ); }); - } - await Promise.resolve(); - - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 50, - total: 100, - }); - - jest.advanceTimersByTime(201); - - await expect(requestPromise).rejects.toThrow("Request timed out"); - }); - test("should reset timeout when progress notification is received", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - resetTimeoutOnProgress: true, - onprogress: onProgressMock, - }); - jest.advanceTimersByTime(800); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 50, - total: 100, - }, + test('should create _meta with progressToken when no _meta exists', async () => { + await protocol.connect(transport); + const request = { + method: 'example', + params: { + data: 'test' + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'example', + params: { + data: 'test', + _meta: { + progressToken: expect.any(Number) + } + }, + jsonrpc: '2.0', + id: expect.any(Number) + }), + expect.any(Object) + ); }); - } - await Promise.resolve(); - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 50, - total: 100, - }); - jest.advanceTimersByTime(800); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - id: 0, - result: { result: "success" }, + + test('should not modify _meta when onprogress is not provided', async () => { + await protocol.connect(transport); + const request = { + method: 'example', + params: { + data: 'test', + _meta: { + customField: 'customValue' + } + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + + protocol.request(request, mockSchema); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'example', + params: { + data: 'test', + _meta: { + customField: 'customValue' + } + }, + jsonrpc: '2.0', + id: expect.any(Number) + }), + expect.any(Object) + ); + }); + + test('should handle params being undefined with onprogress', async () => { + await protocol.connect(transport); + const request = { + method: 'example' + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'example', + params: { + _meta: { + progressToken: expect.any(Number) + } + }, + jsonrpc: '2.0', + id: expect.any(Number) + }), + expect.any(Object) + ); }); - } - await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: "success" }); }); - test("should respect maxTotalTimeout", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - maxTotalTimeout: 150, - resetTimeoutOnProgress: true, - onprogress: onProgressMock, - }); - - // First progress notification should work - jest.advanceTimersByTime(80); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 50, - total: 100, - }, + describe('progress notification timeout behavior', () => { + beforeEach(() => { + jest.useFakeTimers(); }); - } - await Promise.resolve(); - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 50, - total: 100, - }); - jest.advanceTimersByTime(80); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 75, - total: 100, - }, + afterEach(() => { + jest.useRealTimers(); }); - } - await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded"); - expect(onProgressMock).toHaveBeenCalledTimes(1); - }); - test("should timeout if no progress received within timeout period", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 100, - resetTimeoutOnProgress: true, - }); - jest.advanceTimersByTime(101); - await expect(requestPromise).rejects.toThrow("Request timed out"); - }); + test('should not reset timeout when resetTimeoutOnProgress is false', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: false, + onprogress: onProgressMock + }); + + jest.advanceTimersByTime(800); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 50, + total: 100 + } + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + + jest.advanceTimersByTime(201); + + await expect(requestPromise).rejects.toThrow('Request timed out'); + }); - test("should handle multiple progress notifications correctly", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - resetTimeoutOnProgress: true, - onprogress: onProgressMock, - }); - - // Simulate multiple progress updates - for (let i = 1; i <= 3; i++) { - jest.advanceTimersByTime(800); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: i * 25, - total: 100, - }, - }); - } - await Promise.resolve(); - expect(onProgressMock).toHaveBeenNthCalledWith(i, { - progress: i * 25, - total: 100, + test('should reset timeout when progress notification is received', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock + }); + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 50, + total: 100 + } + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 0, + result: { result: 'success' } + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: 'success' }); }); - } - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - id: 0, - result: { result: "success" }, + + test('should respect maxTotalTimeout', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + maxTotalTimeout: 150, + resetTimeoutOnProgress: true, + onprogress: onProgressMock + }); + + // First progress notification should work + jest.advanceTimersByTime(80); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 50, + total: 100 + } + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + jest.advanceTimersByTime(80); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 75, + total: 100 + } + }); + } + await expect(requestPromise).rejects.toThrow('Maximum total timeout exceeded'); + expect(onProgressMock).toHaveBeenCalledTimes(1); }); - } - await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: "success" }); - }); - test("should handle progress notifications with message field", async () => { - await protocol.connect(transport); - const request = { method: "example", params: {} }; - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string(), - }); - const onProgressMock = jest.fn(); - - const requestPromise = protocol.request(request, mockSchema, { - timeout: 1000, - onprogress: onProgressMock, - }); - - jest.advanceTimersByTime(200); - - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 25, - total: 100, - message: "Initializing process...", - }, + test('should timeout if no progress received within timeout period', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 100, + resetTimeoutOnProgress: true + }); + jest.advanceTimersByTime(101); + await expect(requestPromise).rejects.toThrow('Request timed out'); }); - } - await Promise.resolve(); - - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 25, - total: 100, - message: "Initializing process...", - }); - - jest.advanceTimersByTime(200); - - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - method: "notifications/progress", - params: { - progressToken: 0, - progress: 75, - total: 100, - message: "Processing data...", - }, + + test('should handle multiple progress notifications correctly', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock + }); + + // Simulate multiple progress updates + for (let i = 1; i <= 3; i++) { + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: i * 25, + total: 100 + } + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenNthCalledWith(i, { + progress: i * 25, + total: 100 + }); + } + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 0, + result: { result: 'success' } + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: 'success' }); }); - } - await Promise.resolve(); - - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 75, - total: 100, - message: "Processing data...", - }); - - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: "2.0", - id: 0, - result: { result: "success" }, + + test('should handle progress notifications with message field', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onProgressMock = jest.fn(); + + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + onprogress: onProgressMock + }); + + jest.advanceTimersByTime(200); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 25, + total: 100, + message: 'Initializing process...' + } + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 25, + total: 100, + message: 'Initializing process...' + }); + + jest.advanceTimersByTime(200); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 75, + total: 100, + message: 'Processing data...' + } + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 75, + total: 100, + message: 'Processing data...' + }); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 0, + result: { result: 'success' } + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: 'success' }); }); - } - await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: "success" }); - }); - }); - - describe("Debounced Notifications", () => { - // We need to flush the microtask queue to test the debouncing logic. - // This helper function does that. - const flushMicrotasks = () => new Promise(resolve => setImmediate(resolve)); - - it("should NOT debounce a notification that has parameters", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); - await protocol.connect(transport); - - // ACT - // These notifications are configured for debouncing but contain params, so they should be sent immediately. - await protocol.notification({ method: 'test/debounced_with_params', params: { data: 1 } }); - await protocol.notification({ method: 'test/debounced_with_params', params: { data: 2 } }); - - // ASSERT - // Both should have been sent immediately to avoid data loss. - expect(sendSpy).toHaveBeenCalledTimes(2); - expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ params: { data: 1 } }), undefined); - expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ params: { data: 2 } }), undefined); }); - it("should NOT debounce a notification that has a relatedRequestId", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); - await protocol.connect(transport); - - // ACT - await protocol.notification({ method: 'test/debounced_with_options' }, { relatedRequestId: 'req-1' }); - await protocol.notification({ method: 'test/debounced_with_options' }, { relatedRequestId: 'req-2' }); - - // ASSERT - expect(sendSpy).toHaveBeenCalledTimes(2); - expect(sendSpy).toHaveBeenCalledWith(expect.any(Object), { relatedRequestId: 'req-1' }); - expect(sendSpy).toHaveBeenCalledWith(expect.any(Object), { relatedRequestId: 'req-2' }); - }); + describe('Debounced Notifications', () => { + // We need to flush the microtask queue to test the debouncing logic. + // This helper function does that. + const flushMicrotasks = () => new Promise(resolve => setImmediate(resolve)); + + it('should NOT debounce a notification that has parameters', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); + await protocol.connect(transport); + + // ACT + // These notifications are configured for debouncing but contain params, so they should be sent immediately. + await protocol.notification({ method: 'test/debounced_with_params', params: { data: 1 } }); + await protocol.notification({ method: 'test/debounced_with_params', params: { data: 2 } }); + + // ASSERT + // Both should have been sent immediately to avoid data loss. + expect(sendSpy).toHaveBeenCalledTimes(2); + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ params: { data: 1 } }), undefined); + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ params: { data: 2 } }), undefined); + }); - it("should clear pending debounced notifications on connection close", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); - await protocol.connect(transport); + it('should NOT debounce a notification that has a relatedRequestId', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); + await protocol.connect(transport); + + // ACT + await protocol.notification({ method: 'test/debounced_with_options' }, { relatedRequestId: 'req-1' }); + await protocol.notification({ method: 'test/debounced_with_options' }, { relatedRequestId: 'req-2' }); + + // ASSERT + expect(sendSpy).toHaveBeenCalledTimes(2); + expect(sendSpy).toHaveBeenCalledWith(expect.any(Object), { relatedRequestId: 'req-1' }); + expect(sendSpy).toHaveBeenCalledWith(expect.any(Object), { relatedRequestId: 'req-2' }); + }); - // ACT - // Schedule a notification but don't flush the microtask queue. - protocol.notification({ method: 'test/debounced' }); + it('should clear pending debounced notifications on connection close', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); - // Close the connection. This should clear the pending set. - await protocol.close(); + // ACT + // Schedule a notification but don't flush the microtask queue. + protocol.notification({ method: 'test/debounced' }); - // Now, flush the microtask queue. - await flushMicrotasks(); + // Close the connection. This should clear the pending set. + await protocol.close(); - // ASSERT - // The send should never have happened because the transport was cleared. - expect(sendSpy).not.toHaveBeenCalled(); - }); + // Now, flush the microtask queue. + await flushMicrotasks(); - it("should debounce multiple synchronous calls when params property is omitted", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); - await protocol.connect(transport); - - // ACT - // This is the more idiomatic way to write a notification with no params. - protocol.notification({ method: 'test/debounced' }); - protocol.notification({ method: 'test/debounced' }); - protocol.notification({ method: 'test/debounced' }); - - expect(sendSpy).not.toHaveBeenCalled(); - await flushMicrotasks(); - - // ASSERT - expect(sendSpy).toHaveBeenCalledTimes(1); - // The final sent object might not even have the `params` key, which is fine. - // We can check that it was called and that the params are "falsy". - const sentNotification = sendSpy.mock.calls[0][0]; - expect(sentNotification.method).toBe('test/debounced'); - expect(sentNotification.params).toBeUndefined(); - }); + // ASSERT + // The send should never have happened because the transport was cleared. + expect(sendSpy).not.toHaveBeenCalled(); + }); - it("should debounce calls when params is explicitly undefined", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); - await protocol.connect(transport); - - // ACT - protocol.notification({ method: 'test/debounced', params: undefined }); - protocol.notification({ method: 'test/debounced', params: undefined }); - await flushMicrotasks(); - - // ASSERT - expect(sendSpy).toHaveBeenCalledTimes(1); - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'test/debounced', - params: undefined - }), - undefined - ); - }); + it('should debounce multiple synchronous calls when params property is omitted', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT + // This is the more idiomatic way to write a notification with no params. + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + + expect(sendSpy).not.toHaveBeenCalled(); + await flushMicrotasks(); + + // ASSERT + expect(sendSpy).toHaveBeenCalledTimes(1); + // The final sent object might not even have the `params` key, which is fine. + // We can check that it was called and that the params are "falsy". + const sentNotification = sendSpy.mock.calls[0][0]; + expect(sentNotification.method).toBe('test/debounced'); + expect(sentNotification.params).toBeUndefined(); + }); - it("should send non-debounced notifications immediately and multiple times", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method - await protocol.connect(transport); - - // ACT - // Call a non-debounced notification method multiple times. - await protocol.notification({ method: 'test/immediate' }); - await protocol.notification({ method: 'test/immediate' }); - - // ASSERT - // Since this method is not in the debounce list, it should be sent every time. - expect(sendSpy).toHaveBeenCalledTimes(2); - }); + it('should debounce calls when params is explicitly undefined', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT + protocol.notification({ method: 'test/debounced', params: undefined }); + protocol.notification({ method: 'test/debounced', params: undefined }); + await flushMicrotasks(); + + // ASSERT + expect(sendSpy).toHaveBeenCalledTimes(1); + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'test/debounced', + params: undefined + }), + undefined + ); + }); + + it('should send non-debounced notifications immediately and multiple times', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method + await protocol.connect(transport); + + // ACT + // Call a non-debounced notification method multiple times. + await protocol.notification({ method: 'test/immediate' }); + await protocol.notification({ method: 'test/immediate' }); + + // ASSERT + // Since this method is not in the debounce list, it should be sent every time. + expect(sendSpy).toHaveBeenCalledTimes(2); + }); - it("should not debounce any notifications if the option is not provided", async () => { - // ARRANGE - // Use the default protocol from beforeEach, which has no debounce options. - await protocol.connect(transport); + it('should not debounce any notifications if the option is not provided', async () => { + // ARRANGE + // Use the default protocol from beforeEach, which has no debounce options. + await protocol.connect(transport); - // ACT - await protocol.notification({ method: 'any/method' }); - await protocol.notification({ method: 'any/method' }); + // ACT + await protocol.notification({ method: 'any/method' }); + await protocol.notification({ method: 'any/method' }); - // ASSERT - // Without the config, behavior should be immediate sending. - expect(sendSpy).toHaveBeenCalledTimes(2); - }); + // ASSERT + // Without the config, behavior should be immediate sending. + expect(sendSpy).toHaveBeenCalledTimes(2); + }); - it("should handle sequential batches of debounced notifications correctly", async () => { - // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); - await protocol.connect(transport); - - // ACT (Batch 1) - protocol.notification({ method: 'test/debounced' }); - protocol.notification({ method: 'test/debounced' }); - await flushMicrotasks(); - - // ASSERT (Batch 1) - expect(sendSpy).toHaveBeenCalledTimes(1); - - // ACT (Batch 2) - // After the first batch has been sent, a new batch should be possible. - protocol.notification({ method: 'test/debounced' }); - protocol.notification({ method: 'test/debounced' }); - await flushMicrotasks(); - - // ASSERT (Batch 2) - // The total number of sends should now be 2. - expect(sendSpy).toHaveBeenCalledTimes(2); + it('should handle sequential batches of debounced notifications correctly', async () => { + // ARRANGE + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })({ debouncedNotificationMethods: ['test/debounced'] }); + await protocol.connect(transport); + + // ACT (Batch 1) + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + await flushMicrotasks(); + + // ASSERT (Batch 1) + expect(sendSpy).toHaveBeenCalledTimes(1); + + // ACT (Batch 2) + // After the first batch has been sent, a new batch should be possible. + protocol.notification({ method: 'test/debounced' }); + protocol.notification({ method: 'test/debounced' }); + await flushMicrotasks(); + + // ASSERT (Batch 2) + // The total number of sends should now be 2. + expect(sendSpy).toHaveBeenCalledTimes(2); + }); }); - }); }); -describe("mergeCapabilities", () => { - it("should merge client capabilities", () => { - const base: ClientCapabilities = { - sampling: {}, - roots: { - listChanged: true, - }, - }; - - const additional: ClientCapabilities = { - experimental: { - feature: true, - }, - elicitation: {}, - roots: { - newProp: true, - }, - }; - - const merged = mergeCapabilities(base, additional); - expect(merged).toEqual({ - sampling: {}, - elicitation: {}, - roots: { - listChanged: true, - newProp: true, - }, - experimental: { - feature: true, - }, +describe('mergeCapabilities', () => { + it('should merge client capabilities', () => { + const base: ClientCapabilities = { + sampling: {}, + roots: { + listChanged: true + } + }; + + const additional: ClientCapabilities = { + experimental: { + feature: true + }, + elicitation: {}, + roots: { + newProp: true + } + }; + + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({ + sampling: {}, + elicitation: {}, + roots: { + listChanged: true, + newProp: true + }, + experimental: { + feature: true + } + }); + }); + + it('should merge server capabilities', () => { + const base: ServerCapabilities = { + logging: {}, + prompts: { + listChanged: true + } + }; + + const additional: ServerCapabilities = { + resources: { + subscribe: true + }, + prompts: { + newProp: true + } + }; + + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({ + logging: {}, + prompts: { + listChanged: true, + newProp: true + }, + resources: { + subscribe: true + } + }); }); - }); - - it("should merge server capabilities", () => { - const base: ServerCapabilities = { - logging: {}, - prompts: { - listChanged: true, - }, - }; - - const additional: ServerCapabilities = { - resources: { - subscribe: true, - }, - prompts: { - newProp: true, - }, - }; - - const merged = mergeCapabilities(base, additional); - expect(merged).toEqual({ - logging: {}, - prompts: { - listChanged: true, - newProp: true, - }, - resources: { - subscribe: true, - }, + + it('should override existing values with additional values', () => { + const base: ServerCapabilities = { + prompts: { + listChanged: false + } + }; + + const additional: ServerCapabilities = { + prompts: { + listChanged: true + } + }; + + const merged = mergeCapabilities(base, additional); + expect(merged.prompts!.listChanged).toBe(true); + }); + + it('should handle empty objects', () => { + const base = {}; + const additional = {}; + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({}); }); - }); - - it("should override existing values with additional values", () => { - const base: ServerCapabilities = { - prompts: { - listChanged: false, - }, - }; - - const additional: ServerCapabilities = { - prompts: { - listChanged: true, - }, - }; - - const merged = mergeCapabilities(base, additional); - expect(merged.prompts!.listChanged).toBe(true); - }); - - it("should handle empty objects", () => { - const base = {}; - const additional = {}; - const merged = mergeCapabilities(base, additional); - expect(merged).toEqual({}); - }); }); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 7df190ba1..5cb969418 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,32 +1,32 @@ -import { ZodLiteral, ZodObject, ZodType, z } from "zod"; +import { ZodLiteral, ZodObject, ZodType, z } from 'zod'; import { - CancelledNotificationSchema, - ClientCapabilities, - ErrorCode, - isJSONRPCError, - isJSONRPCRequest, - isJSONRPCResponse, - isJSONRPCNotification, - JSONRPCError, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - McpError, - Notification, - PingRequestSchema, - Progress, - ProgressNotification, - ProgressNotificationSchema, - Request, - RequestId, - Result, - ServerCapabilities, - RequestMeta, - MessageExtraInfo, - RequestInfo, -} from "../types.js"; -import { Transport, TransportSendOptions } from "./transport.js"; -import { AuthInfo } from "../server/auth/types.js"; + CancelledNotificationSchema, + ClientCapabilities, + ErrorCode, + isJSONRPCError, + isJSONRPCRequest, + isJSONRPCResponse, + isJSONRPCNotification, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + McpError, + Notification, + PingRequestSchema, + Progress, + ProgressNotification, + ProgressNotificationSchema, + Request, + RequestId, + Result, + ServerCapabilities, + RequestMeta, + MessageExtraInfo, + RequestInfo +} from '../types.js'; +import { Transport, TransportSendOptions } from './transport.js'; +import { AuthInfo } from '../server/auth/types.js'; /** * Callback for progress notifications. @@ -37,21 +37,21 @@ export type ProgressCallback = (progress: Progress) => void; * Additional initialization options. */ export type ProtocolOptions = { - /** - * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. - * - * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. - * - * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true. - */ - enforceStrictCapabilities?: boolean; - /** - * An array of notification method names that should be automatically debounced. - * Any notifications with a method in this list will be coalesced if they - * occur in the same tick of the event loop. - * e.g., ['notifications/tools/list_changed'] - */ - debouncedNotificationMethods?: string[]; + /** + * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. + * + * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. + * + * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true. + */ + enforceStrictCapabilities?: boolean; + /** + * An array of notification method names that should be automatically debounced. + * Any notifications with a method in this list will be coalesced if they + * occur in the same tick of the event loop. + * e.g., ['notifications/tools/list_changed'] + */ + debouncedNotificationMethods?: string[]; }; /** @@ -63,53 +63,52 @@ export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60000; * Options that can be given per request. */ export type RequestOptions = { - /** - * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. - */ - onprogress?: ProgressCallback; - - /** - * Can be used to cancel an in-flight request. This will cause an AbortError to be raised from request(). - */ - signal?: AbortSignal; - - /** - * A timeout (in milliseconds) for this request. If exceeded, an McpError with code `RequestTimeout` will be raised from request(). - * - * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout. - */ - timeout?: number; - - /** - * If true, receiving a progress notification will reset the request timeout. - * This is useful for long-running operations that send periodic progress updates. - * Default: false - */ - resetTimeoutOnProgress?: boolean; - - /** - * Maximum total time (in milliseconds) to wait for a response. - * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications. - * If not specified, there is no maximum total timeout. - */ - maxTotalTimeout?: number; + /** + * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + */ + onprogress?: ProgressCallback; + + /** + * Can be used to cancel an in-flight request. This will cause an AbortError to be raised from request(). + */ + signal?: AbortSignal; + + /** + * A timeout (in milliseconds) for this request. If exceeded, an McpError with code `RequestTimeout` will be raised from request(). + * + * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout. + */ + timeout?: number; + + /** + * If true, receiving a progress notification will reset the request timeout. + * This is useful for long-running operations that send periodic progress updates. + * Default: false + */ + resetTimeoutOnProgress?: boolean; + + /** + * Maximum total time (in milliseconds) to wait for a response. + * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications. + * If not specified, there is no maximum total timeout. + */ + maxTotalTimeout?: number; } & TransportSendOptions; /** * Options that can be given per notification. */ export type NotificationOptions = { - /** - * May be used to indicate to the transport which incoming request to associate this outgoing notification with. - */ - relatedRequestId?: RequestId; -} + /** + * May be used to indicate to the transport which incoming request to associate this outgoing notification with. + */ + relatedRequestId?: RequestId; +}; /** * Extra data given to request handlers. */ -export type RequestHandlerExtra = { +export type RequestHandlerExtra = { /** * An abort signal used to communicate if the request was cancelled from the sender's side. */ @@ -143,640 +142,567 @@ export type RequestHandlerExtra Promise; /** * Sends a request that relates to the current request being handled. - * + * * This is used by certain transports to correctly associate related messages. */ sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; - }; +}; /** * Information about a request's timeout state */ type TimeoutInfo = { - timeoutId: ReturnType; - startTime: number; - timeout: number; - maxTotalTimeout?: number; - resetTimeoutOnProgress: boolean; - onTimeout: () => void; + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + resetTimeoutOnProgress: boolean; + onTimeout: () => void; }; /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ -export abstract class Protocol< - SendRequestT extends Request, - SendNotificationT extends Notification, - SendResultT extends Result, -> { - private _transport?: Transport; - private _requestMessageId = 0; - private _requestHandlers: Map< - string, - ( - request: JSONRPCRequest, - extra: RequestHandlerExtra, - ) => Promise - > = new Map(); - private _requestHandlerAbortControllers: Map = - new Map(); - private _notificationHandlers: Map< - string, - (notification: JSONRPCNotification) => Promise - > = new Map(); - private _responseHandlers: Map< - number, - (response: JSONRPCResponse | Error) => void - > = new Map(); - private _progressHandlers: Map = new Map(); - private _timeoutInfo: Map = new Map(); - private _pendingDebouncedNotifications = new Set(); - - /** - * Callback for when the connection is closed for any reason. - * - * This is invoked when close() is called as well. - */ - onclose?: () => void; - - /** - * Callback for when an error occurs. - * - * Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band. - */ - onerror?: (error: Error) => void; - - /** - * A handler to invoke for any request types that do not have their own handler installed. - */ - fallbackRequestHandler?: ( - request: JSONRPCRequest, - extra: RequestHandlerExtra - ) => Promise; - - /** - * A handler to invoke for any notification types that do not have their own handler installed. - */ - fallbackNotificationHandler?: (notification: Notification) => Promise; - - constructor(private _options?: ProtocolOptions) { - this.setNotificationHandler(CancelledNotificationSchema, (notification) => { - const controller = this._requestHandlerAbortControllers.get( - notification.params.requestId, - ); - controller?.abort(notification.params.reason); - }); - - this.setNotificationHandler(ProgressNotificationSchema, (notification) => { - this._onprogress(notification as unknown as ProgressNotification); - }); - - this.setRequestHandler( - PingRequestSchema, - // Automatic pong by default. - (_request) => ({}) as SendResultT, - ); - } - - private _setupTimeout( - messageId: number, - timeout: number, - maxTotalTimeout: number | undefined, - onTimeout: () => void, - resetTimeoutOnProgress: boolean = false - ) { - this._timeoutInfo.set(messageId, { - timeoutId: setTimeout(onTimeout, timeout), - startTime: Date.now(), - timeout, - maxTotalTimeout, - resetTimeoutOnProgress, - onTimeout - }); - } - - private _resetTimeout(messageId: number): boolean { - const info = this._timeoutInfo.get(messageId); - if (!info) return false; - - const totalElapsed = Date.now() - info.startTime; - if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { - this._timeoutInfo.delete(messageId); - throw new McpError( - ErrorCode.RequestTimeout, - "Maximum total timeout exceeded", - { maxTotalTimeout: info.maxTotalTimeout, totalElapsed } - ); - } +export abstract class Protocol { + private _transport?: Transport; + private _requestMessageId = 0; + private _requestHandlers: Map< + string, + (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise + > = new Map(); + private _requestHandlerAbortControllers: Map = new Map(); + private _notificationHandlers: Map Promise> = new Map(); + private _responseHandlers: Map void> = new Map(); + private _progressHandlers: Map = new Map(); + private _timeoutInfo: Map = new Map(); + private _pendingDebouncedNotifications = new Set(); - clearTimeout(info.timeoutId); - info.timeoutId = setTimeout(info.onTimeout, info.timeout); - return true; - } + /** + * Callback for when the connection is closed for any reason. + * + * This is invoked when close() is called as well. + */ + onclose?: () => void; - private _cleanupTimeout(messageId: number) { - const info = this._timeoutInfo.get(messageId); - if (info) { - clearTimeout(info.timeoutId); - this._timeoutInfo.delete(messageId); - } - } - - /** - * Attaches to the given transport, starts it, and starts listening for messages. - * - * The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. - */ - async connect(transport: Transport): Promise { - this._transport = transport; - const _onclose = this.transport?.onclose; - this._transport.onclose = () => { - _onclose?.(); - this._onclose(); - }; - - const _onerror = this.transport?.onerror; - this._transport.onerror = (error: Error) => { - _onerror?.(error); - this._onerror(error); - }; - - const _onmessage = this._transport?.onmessage; - this._transport.onmessage = (message, extra) => { - _onmessage?.(message, extra); - if (isJSONRPCResponse(message) || isJSONRPCError(message)) { - this._onresponse(message); - } else if (isJSONRPCRequest(message)) { - this._onrequest(message, extra); - } else if (isJSONRPCNotification(message)) { - this._onnotification(message); - } else { - this._onerror( - new Error(`Unknown message type: ${JSON.stringify(message)}`), + /** + * Callback for when an error occurs. + * + * Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band. + */ + onerror?: (error: Error) => void; + + /** + * A handler to invoke for any request types that do not have their own handler installed. + */ + fallbackRequestHandler?: (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise; + + /** + * A handler to invoke for any notification types that do not have their own handler installed. + */ + fallbackNotificationHandler?: (notification: Notification) => Promise; + + constructor(private _options?: ProtocolOptions) { + this.setNotificationHandler(CancelledNotificationSchema, notification => { + const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); + controller?.abort(notification.params.reason); + }); + + this.setNotificationHandler(ProgressNotificationSchema, notification => { + this._onprogress(notification as unknown as ProgressNotification); + }); + + this.setRequestHandler( + PingRequestSchema, + // Automatic pong by default. + _request => ({}) as SendResultT ); - } - }; - - await this._transport.start(); - } - - private _onclose(): void { - const responseHandlers = this._responseHandlers; - this._responseHandlers = new Map(); - this._progressHandlers.clear(); - this._pendingDebouncedNotifications.clear(); - this._transport = undefined; - this.onclose?.(); - - const error = new McpError(ErrorCode.ConnectionClosed, "Connection closed"); - for (const handler of responseHandlers.values()) { - handler(error); } - } - private _onerror(error: Error): void { - this.onerror?.(error); - } + private _setupTimeout( + messageId: number, + timeout: number, + maxTotalTimeout: number | undefined, + onTimeout: () => void, + resetTimeoutOnProgress: boolean = false + ) { + this._timeoutInfo.set(messageId, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout, + resetTimeoutOnProgress, + onTimeout + }); + } - private _onnotification(notification: JSONRPCNotification): void { - const handler = - this._notificationHandlers.get(notification.method) ?? - this.fallbackNotificationHandler; + private _resetTimeout(messageId: number): boolean { + const info = this._timeoutInfo.get(messageId); + if (!info) return false; + + const totalElapsed = Date.now() - info.startTime; + if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { + this._timeoutInfo.delete(messageId); + throw new McpError(ErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { + maxTotalTimeout: info.maxTotalTimeout, + totalElapsed + }); + } - // Ignore notifications not being subscribed to. - if (handler === undefined) { - return; + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + return true; } - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => handler(notification)) - .catch((error) => - this._onerror( - new Error(`Uncaught error in notification handler: ${error}`), - ), - ); - } - - private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = - this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; - - // Capture the current transport at request time to ensure responses go to the correct client - const capturedTransport = this._transport; - - if (handler === undefined) { - capturedTransport - ?.send({ - jsonrpc: "2.0", - id: request.id, - error: { - code: ErrorCode.MethodNotFound, - message: "Method not found", - }, - }) - .catch((error) => - this._onerror( - new Error(`Failed to send an error response: ${error}`), - ), - ); - return; + private _cleanupTimeout(messageId: number) { + const info = this._timeoutInfo.get(messageId); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(messageId); + } } - const abortController = new AbortController(); - this._requestHandlerAbortControllers.set(request.id, abortController); - - const fullExtra: RequestHandlerExtra = { - signal: abortController.signal, - sessionId: capturedTransport?.sessionId, - _meta: request.params?._meta, - sendNotification: - (notification) => - this.notification(notification, { relatedRequestId: request.id }), - sendRequest: (r, resultSchema, options?) => - this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), - authInfo: extra?.authInfo, - requestId: request.id, - requestInfo: extra?.requestInfo - }; - - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => handler(request, fullExtra)) - .then( - (result) => { - if (abortController.signal.aborted) { - return; - } + /** + * Attaches to the given transport, starts it, and starts listening for messages. + * + * The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + */ + async connect(transport: Transport): Promise { + this._transport = transport; + const _onclose = this.transport?.onclose; + this._transport.onclose = () => { + _onclose?.(); + this._onclose(); + }; - return capturedTransport?.send({ - result, - jsonrpc: "2.0", - id: request.id, - }); - }, - (error) => { - if (abortController.signal.aborted) { - return; - } - - return capturedTransport?.send({ - jsonrpc: "2.0", - id: request.id, - error: { - code: Number.isSafeInteger(error["code"]) - ? error["code"] - : ErrorCode.InternalError, - message: error.message ?? "Internal error", - }, - }); - }, - ) - .catch((error) => - this._onerror(new Error(`Failed to send response: ${error}`)), - ) - .finally(() => { - this._requestHandlerAbortControllers.delete(request.id); - }); - } - - private _onprogress(notification: ProgressNotification): void { - const { progressToken, ...params } = notification.params; - const messageId = Number(progressToken); - - const handler = this._progressHandlers.get(messageId); - if (!handler) { - this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); - return; - } + const _onerror = this.transport?.onerror; + this._transport.onerror = (error: Error) => { + _onerror?.(error); + this._onerror(error); + }; - const responseHandler = this._responseHandlers.get(messageId); - const timeoutInfo = this._timeoutInfo.get(messageId); + const _onmessage = this._transport?.onmessage; + this._transport.onmessage = (message, extra) => { + _onmessage?.(message, extra); + if (isJSONRPCResponse(message) || isJSONRPCError(message)) { + this._onresponse(message); + } else if (isJSONRPCRequest(message)) { + this._onrequest(message, extra); + } else if (isJSONRPCNotification(message)) { + this._onnotification(message); + } else { + this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); + } + }; - if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { - try { - this._resetTimeout(messageId); - } catch (error) { - responseHandler(error as Error); - return; - } + await this._transport.start(); } - handler(params); - } - - private _onresponse(response: JSONRPCResponse | JSONRPCError): void { - const messageId = Number(response.id); - const handler = this._responseHandlers.get(messageId); - if (handler === undefined) { - this._onerror( - new Error( - `Received a response for an unknown message ID: ${JSON.stringify(response)}`, - ), - ); - return; + private _onclose(): void { + const responseHandlers = this._responseHandlers; + this._responseHandlers = new Map(); + this._progressHandlers.clear(); + this._pendingDebouncedNotifications.clear(); + this._transport = undefined; + this.onclose?.(); + + const error = new McpError(ErrorCode.ConnectionClosed, 'Connection closed'); + for (const handler of responseHandlers.values()) { + handler(error); + } } - this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); - this._cleanupTimeout(messageId); - - if (isJSONRPCResponse(response)) { - handler(response); - } else { - const error = new McpError( - response.error.code, - response.error.message, - response.error.data, - ); - handler(error); + private _onerror(error: Error): void { + this.onerror?.(error); } - } - - get transport(): Transport | undefined { - return this._transport; - } - - /** - * Closes the connection. - */ - async close(): Promise { - await this._transport?.close(); - } - - /** - * A method to check if a capability is supported by the remote side, for the given method to be called. - * - * This should be implemented by subclasses. - */ - protected abstract assertCapabilityForMethod( - method: SendRequestT["method"], - ): void; - - /** - * A method to check if a notification is supported by the local side, for the given method to be sent. - * - * This should be implemented by subclasses. - */ - protected abstract assertNotificationCapability( - method: SendNotificationT["method"], - ): void; - - /** - * A method to check if a request handler is supported by the local side, for the given method to be handled. - * - * This should be implemented by subclasses. - */ - protected abstract assertRequestHandlerCapability(method: string): void; - - /** - * Sends a request and wait for a response. - * - * Do not use this method to emit notifications! Use notification() instead. - */ - request>( - request: SendRequestT, - resultSchema: T, - options?: RequestOptions, - ): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; - - return new Promise((resolve, reject) => { - if (!this._transport) { - reject(new Error("Not connected")); - return; - } - - if (this._options?.enforceStrictCapabilities === true) { - this.assertCapabilityForMethod(request.method); - } - - options?.signal?.throwIfAborted(); - - const messageId = this._requestMessageId++; - const jsonrpcRequest: JSONRPCRequest = { - ...request, - jsonrpc: "2.0", - id: messageId, - }; - - if (options?.onprogress) { - this._progressHandlers.set(messageId, options.onprogress); - jsonrpcRequest.params = { - ...request.params, - _meta: { - ...(request.params?._meta || {}), - progressToken: messageId - }, - }; - } - const cancel = (reason: unknown) => { - this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); - this._cleanupTimeout(messageId); + private _onnotification(notification: JSONRPCNotification): void { + const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; - this._transport - ?.send({ - jsonrpc: "2.0", - method: "notifications/cancelled", - params: { - requestId: messageId, - reason: String(reason), - }, - }, { relatedRequestId, resumptionToken, onresumptiontoken }) - .catch((error) => - this._onerror(new Error(`Failed to send cancellation: ${error}`)), - ); - - reject(reason); - }; - - this._responseHandlers.set(messageId, (response) => { - if (options?.signal?.aborted) { - return; + // Ignore notifications not being subscribed to. + if (handler === undefined) { + return; } - if (response instanceof Error) { - return reject(response); + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. + Promise.resolve() + .then(() => handler(notification)) + .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); + } + + private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { + const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; + + // Capture the current transport at request time to ensure responses go to the correct client + const capturedTransport = this._transport; + + if (handler === undefined) { + capturedTransport + ?.send({ + jsonrpc: '2.0', + id: request.id, + error: { + code: ErrorCode.MethodNotFound, + message: 'Method not found' + } + }) + .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); + return; } - try { - const result = resultSchema.parse(response.result); - resolve(result); - } catch (error) { - reject(error); + const abortController = new AbortController(); + this._requestHandlerAbortControllers.set(request.id, abortController); + + const fullExtra: RequestHandlerExtra = { + signal: abortController.signal, + sessionId: capturedTransport?.sessionId, + _meta: request.params?._meta, + sendNotification: notification => this.notification(notification, { relatedRequestId: request.id }), + sendRequest: (r, resultSchema, options?) => this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), + authInfo: extra?.authInfo, + requestId: request.id, + requestInfo: extra?.requestInfo + }; + + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. + Promise.resolve() + .then(() => handler(request, fullExtra)) + .then( + result => { + if (abortController.signal.aborted) { + return; + } + + return capturedTransport?.send({ + result, + jsonrpc: '2.0', + id: request.id + }); + }, + error => { + if (abortController.signal.aborted) { + return; + } + + return capturedTransport?.send({ + jsonrpc: '2.0', + id: request.id, + error: { + code: Number.isSafeInteger(error['code']) ? error['code'] : ErrorCode.InternalError, + message: error.message ?? 'Internal error' + } + }); + } + ) + .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) + .finally(() => { + this._requestHandlerAbortControllers.delete(request.id); + }); + } + + private _onprogress(notification: ProgressNotification): void { + const { progressToken, ...params } = notification.params; + const messageId = Number(progressToken); + + const handler = this._progressHandlers.get(messageId); + if (!handler) { + this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); + return; } - }); - options?.signal?.addEventListener("abort", () => { - cancel(options?.signal?.reason); - }); + const responseHandler = this._responseHandlers.get(messageId); + const timeoutInfo = this._timeoutInfo.get(messageId); - const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; - const timeoutHandler = () => cancel(new McpError( - ErrorCode.RequestTimeout, - "Request timed out", - { timeout } - )); + if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { + try { + this._resetTimeout(messageId); + } catch (error) { + responseHandler(error as Error); + return; + } + } - this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); + handler(params); + } + + private _onresponse(response: JSONRPCResponse | JSONRPCError): void { + const messageId = Number(response.id); + const handler = this._responseHandlers.get(messageId); + if (handler === undefined) { + this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); + return; + } - this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch((error) => { + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); this._cleanupTimeout(messageId); - reject(error); - }); - }); - } - - /** - * Emits a notification, which is a one-way message that does not expect a response. - */ - async notification(notification: SendNotificationT, options?: NotificationOptions): Promise { - if (!this._transport) { - throw new Error("Not connected"); + + if (isJSONRPCResponse(response)) { + handler(response); + } else { + const error = new McpError(response.error.code, response.error.message, response.error.data); + handler(error); + } } - this.assertNotificationCapability(notification.method); + get transport(): Transport | undefined { + return this._transport; + } - const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; - // A notification can only be debounced if it's in the list AND it's "simple" - // (i.e., has no parameters and no related request ID that could be lost). - const canDebounce = debouncedMethods.includes(notification.method) - && !notification.params - && !(options?.relatedRequestId); + /** + * Closes the connection. + */ + async close(): Promise { + await this._transport?.close(); + } + + /** + * A method to check if a capability is supported by the remote side, for the given method to be called. + * + * This should be implemented by subclasses. + */ + protected abstract assertCapabilityForMethod(method: SendRequestT['method']): void; - if (canDebounce) { - // If a notification of this type is already scheduled, do nothing. - if (this._pendingDebouncedNotifications.has(notification.method)) { - return; - } + /** + * A method to check if a notification is supported by the local side, for the given method to be sent. + * + * This should be implemented by subclasses. + */ + protected abstract assertNotificationCapability(method: SendNotificationT['method']): void; - // Mark this notification type as pending. - this._pendingDebouncedNotifications.add(notification.method); + /** + * A method to check if a request handler is supported by the local side, for the given method to be handled. + * + * This should be implemented by subclasses. + */ + protected abstract assertRequestHandlerCapability(method: string): void; - // Schedule the actual send to happen in the next microtask. - // This allows all synchronous calls in the current event loop tick to be coalesced. - Promise.resolve().then(() => { - // Un-mark the notification so the next one can be scheduled. - this._pendingDebouncedNotifications.delete(notification.method); + /** + * Sends a request and wait for a response. + * + * Do not use this method to emit notifications! Use notification() instead. + */ + request>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + + return new Promise((resolve, reject) => { + if (!this._transport) { + reject(new Error('Not connected')); + return; + } + + if (this._options?.enforceStrictCapabilities === true) { + this.assertCapabilityForMethod(request.method); + } + + options?.signal?.throwIfAborted(); + + const messageId = this._requestMessageId++; + const jsonrpcRequest: JSONRPCRequest = { + ...request, + jsonrpc: '2.0', + id: messageId + }; + + if (options?.onprogress) { + this._progressHandlers.set(messageId, options.onprogress); + jsonrpcRequest.params = { + ...request.params, + _meta: { + ...(request.params?._meta || {}), + progressToken: messageId + } + }; + } + + const cancel = (reason: unknown) => { + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); + + this._transport + ?.send( + { + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: messageId, + reason: String(reason) + } + }, + { relatedRequestId, resumptionToken, onresumptiontoken } + ) + .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`))); + + reject(reason); + }; + + this._responseHandlers.set(messageId, response => { + if (options?.signal?.aborted) { + return; + } + + if (response instanceof Error) { + return reject(response); + } + + try { + const result = resultSchema.parse(response.result); + resolve(result); + } catch (error) { + reject(error); + } + }); + + options?.signal?.addEventListener('abort', () => { + cancel(options?.signal?.reason); + }); + + const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; + const timeoutHandler = () => cancel(new McpError(ErrorCode.RequestTimeout, 'Request timed out', { timeout })); + + this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); + + this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + this._cleanupTimeout(messageId); + reject(error); + }); + }); + } - // SAFETY CHECK: If the connection was closed while this was pending, abort. + /** + * Emits a notification, which is a one-way message that does not expect a response. + */ + async notification(notification: SendNotificationT, options?: NotificationOptions): Promise { if (!this._transport) { - return; + throw new Error('Not connected'); + } + + this.assertNotificationCapability(notification.method); + + const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; + // A notification can only be debounced if it's in the list AND it's "simple" + // (i.e., has no parameters and no related request ID that could be lost). + const canDebounce = debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId; + + if (canDebounce) { + // If a notification of this type is already scheduled, do nothing. + if (this._pendingDebouncedNotifications.has(notification.method)) { + return; + } + + // Mark this notification type as pending. + this._pendingDebouncedNotifications.add(notification.method); + + // Schedule the actual send to happen in the next microtask. + // This allows all synchronous calls in the current event loop tick to be coalesced. + Promise.resolve().then(() => { + // Un-mark the notification so the next one can be scheduled. + this._pendingDebouncedNotifications.delete(notification.method); + + // SAFETY CHECK: If the connection was closed while this was pending, abort. + if (!this._transport) { + return; + } + + const jsonrpcNotification: JSONRPCNotification = { + ...notification, + jsonrpc: '2.0' + }; + // Send the notification, but don't await it here to avoid blocking. + // Handle potential errors with a .catch(). + this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); + }); + + // Return immediately. + return; } const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: "2.0", + ...notification, + jsonrpc: '2.0' }; - // Send the notification, but don't await it here to avoid blocking. - // Handle potential errors with a .catch(). - this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); - }); - // Return immediately. - return; + await this._transport.send(jsonrpcNotification, options); } - const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: "2.0", - }; - - await this._transport.send(jsonrpcNotification, options); - } - - /** - * Registers a handler to invoke when this protocol object receives a request with the given method. - * - * Note that this will replace any previous request handler for the same method. - */ - setRequestHandler< - T extends ZodObject<{ - method: ZodLiteral; - }>, - >( - requestSchema: T, - handler: ( - request: z.infer, - extra: RequestHandlerExtra, - ) => SendResultT | Promise, - ): void { - const method = requestSchema.shape.method.value; - this.assertRequestHandlerCapability(method); - - this._requestHandlers.set(method, (request, extra) => { - return Promise.resolve(handler(requestSchema.parse(request), extra)); - }); - } - - /** - * Removes the request handler for the given method. - */ - removeRequestHandler(method: string): void { - this._requestHandlers.delete(method); - } - - /** - * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. - */ - assertCanSetRequestHandler(method: string): void { - if (this._requestHandlers.has(method)) { - throw new Error( - `A request handler for ${method} already exists, which would be overridden`, - ); + /** + * Registers a handler to invoke when this protocol object receives a request with the given method. + * + * Note that this will replace any previous request handler for the same method. + */ + setRequestHandler< + T extends ZodObject<{ + method: ZodLiteral; + }> + >( + requestSchema: T, + handler: (request: z.infer, extra: RequestHandlerExtra) => SendResultT | Promise + ): void { + const method = requestSchema.shape.method.value; + this.assertRequestHandlerCapability(method); + + this._requestHandlers.set(method, (request, extra) => { + return Promise.resolve(handler(requestSchema.parse(request), extra)); + }); + } + + /** + * Removes the request handler for the given method. + */ + removeRequestHandler(method: string): void { + this._requestHandlers.delete(method); + } + + /** + * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. + */ + assertCanSetRequestHandler(method: string): void { + if (this._requestHandlers.has(method)) { + throw new Error(`A request handler for ${method} already exists, which would be overridden`); + } + } + + /** + * Registers a handler to invoke when this protocol object receives a notification with the given method. + * + * Note that this will replace any previous notification handler for the same method. + */ + setNotificationHandler< + T extends ZodObject<{ + method: ZodLiteral; + }> + >(notificationSchema: T, handler: (notification: z.infer) => void | Promise): void { + this._notificationHandlers.set(notificationSchema.shape.method.value, notification => + Promise.resolve(handler(notificationSchema.parse(notification))) + ); + } + + /** + * Removes the notification handler for the given method. + */ + removeNotificationHandler(method: string): void { + this._notificationHandlers.delete(method); } - } - - /** - * Registers a handler to invoke when this protocol object receives a notification with the given method. - * - * Note that this will replace any previous notification handler for the same method. - */ - setNotificationHandler< - T extends ZodObject<{ - method: ZodLiteral; - }>, - >( - notificationSchema: T, - handler: (notification: z.infer) => void | Promise, - ): void { - this._notificationHandlers.set( - notificationSchema.shape.method.value, - (notification) => - Promise.resolve(handler(notificationSchema.parse(notification))), - ); - } - - /** - * Removes the notification handler for the given method. - */ - removeNotificationHandler(method: string): void { - this._notificationHandlers.delete(method); - } } -export function mergeCapabilities< - T extends ServerCapabilities | ClientCapabilities, ->(base: T, additional: T): T { - return Object.entries(additional).reduce( - (acc, [key, value]) => { - if (value && typeof value === "object") { - acc[key] = acc[key] ? { ...acc[key], ...value } : value; - } else { - acc[key] = value; - } - return acc; - }, - { ...base }, - ); +export function mergeCapabilities(base: T, additional: T): T { + return Object.entries(additional).reduce( + (acc, [key, value]) => { + if (value && typeof value === 'object') { + acc[key] = acc[key] ? { ...acc[key], ...value } : value; + } else { + acc[key] = value; + } + return acc; + }, + { ...base } + ); } diff --git a/src/shared/stdio.test.ts b/src/shared/stdio.test.ts index b12279664..e41c938b6 100644 --- a/src/shared/stdio.test.ts +++ b/src/shared/stdio.test.ts @@ -1,35 +1,35 @@ -import { JSONRPCMessage } from "../types.js"; -import { ReadBuffer } from "./stdio.js"; +import { JSONRPCMessage } from '../types.js'; +import { ReadBuffer } from './stdio.js'; const testMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "foobar", + jsonrpc: '2.0', + method: 'foobar' }; -test("should have no messages after initialization", () => { - const readBuffer = new ReadBuffer(); - expect(readBuffer.readMessage()).toBeNull(); +test('should have no messages after initialization', () => { + const readBuffer = new ReadBuffer(); + expect(readBuffer.readMessage()).toBeNull(); }); -test("should only yield a message after a newline", () => { - const readBuffer = new ReadBuffer(); +test('should only yield a message after a newline', () => { + const readBuffer = new ReadBuffer(); - readBuffer.append(Buffer.from(JSON.stringify(testMessage))); - expect(readBuffer.readMessage()).toBeNull(); + readBuffer.append(Buffer.from(JSON.stringify(testMessage))); + expect(readBuffer.readMessage()).toBeNull(); - readBuffer.append(Buffer.from("\n")); - expect(readBuffer.readMessage()).toEqual(testMessage); - expect(readBuffer.readMessage()).toBeNull(); + readBuffer.append(Buffer.from('\n')); + expect(readBuffer.readMessage()).toEqual(testMessage); + expect(readBuffer.readMessage()).toBeNull(); }); -test("should be reusable after clearing", () => { - const readBuffer = new ReadBuffer(); +test('should be reusable after clearing', () => { + const readBuffer = new ReadBuffer(); - readBuffer.append(Buffer.from("foobar")); - readBuffer.clear(); - expect(readBuffer.readMessage()).toBeNull(); + readBuffer.append(Buffer.from('foobar')); + readBuffer.clear(); + expect(readBuffer.readMessage()).toBeNull(); - readBuffer.append(Buffer.from(JSON.stringify(testMessage))); - readBuffer.append(Buffer.from("\n")); - expect(readBuffer.readMessage()).toEqual(testMessage); + readBuffer.append(Buffer.from(JSON.stringify(testMessage))); + readBuffer.append(Buffer.from('\n')); + expect(readBuffer.readMessage()).toEqual(testMessage); }); diff --git a/src/shared/stdio.ts b/src/shared/stdio.ts index 52bde646f..fe14612bd 100644 --- a/src/shared/stdio.ts +++ b/src/shared/stdio.ts @@ -1,39 +1,39 @@ -import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js'; /** * Buffers a continuous stdio stream into discrete JSON-RPC messages. */ export class ReadBuffer { - private _buffer?: Buffer; + private _buffer?: Buffer; - append(chunk: Buffer): void { - this._buffer = this._buffer ? Buffer.concat([this._buffer, chunk]) : chunk; - } - - readMessage(): JSONRPCMessage | null { - if (!this._buffer) { - return null; + append(chunk: Buffer): void { + this._buffer = this._buffer ? Buffer.concat([this._buffer, chunk]) : chunk; } - const index = this._buffer.indexOf("\n"); - if (index === -1) { - return null; - } + readMessage(): JSONRPCMessage | null { + if (!this._buffer) { + return null; + } - const line = this._buffer.toString("utf8", 0, index).replace(/\r$/, ''); - this._buffer = this._buffer.subarray(index + 1); - return deserializeMessage(line); - } + const index = this._buffer.indexOf('\n'); + if (index === -1) { + return null; + } - clear(): void { - this._buffer = undefined; - } + const line = this._buffer.toString('utf8', 0, index).replace(/\r$/, ''); + this._buffer = this._buffer.subarray(index + 1); + return deserializeMessage(line); + } + + clear(): void { + this._buffer = undefined; + } } export function deserializeMessage(line: string): JSONRPCMessage { - return JSONRPCMessageSchema.parse(JSON.parse(line)); + return JSONRPCMessageSchema.parse(JSON.parse(line)); } export function serializeMessage(message: JSONRPCMessage): string { - return JSON.stringify(message) + "\n"; + return JSON.stringify(message) + '\n'; } diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 386b6bae5..c64f622b7 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,4 +1,4 @@ -import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js"; +import { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types.js'; export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; @@ -6,80 +6,80 @@ export type FetchLike = (url: string | URL, init?: RequestInit) => Promise void; -} + /** + * A callback that is invoked when the resumption token changes, if supported by the transport. + * + * This allows clients to persist the latest token for potential reconnection. + */ + onresumptiontoken?: (token: string) => void; +}; /** * Describes the minimal contract for a MCP transport that a client or server can communicate over. */ export interface Transport { - /** - * Starts processing messages on the transport, including any connection steps that might need to be taken. - * - * This method should only be called after callbacks are installed, or else messages may be lost. - * - * NOTE: This method should not be called explicitly when using Client, Server, or Protocol classes, as they will implicitly call start(). - */ - start(): Promise; + /** + * Starts processing messages on the transport, including any connection steps that might need to be taken. + * + * This method should only be called after callbacks are installed, or else messages may be lost. + * + * NOTE: This method should not be called explicitly when using Client, Server, or Protocol classes, as they will implicitly call start(). + */ + start(): Promise; - /** - * Sends a JSON-RPC message (request or response). - * - * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. - */ - send(message: JSONRPCMessage, options?: TransportSendOptions): Promise; + /** + * Sends a JSON-RPC message (request or response). + * + * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. + */ + send(message: JSONRPCMessage, options?: TransportSendOptions): Promise; - /** - * Closes the connection. - */ - close(): Promise; + /** + * Closes the connection. + */ + close(): Promise; - /** - * Callback for when the connection is closed for any reason. - * - * This should be invoked when close() is called as well. - */ - onclose?: () => void; + /** + * Callback for when the connection is closed for any reason. + * + * This should be invoked when close() is called as well. + */ + onclose?: () => void; - /** - * Callback for when an error occurs. - * - * Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band. - */ - onerror?: (error: Error) => void; + /** + * Callback for when an error occurs. + * + * Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band. + */ + onerror?: (error: Error) => void; - /** - * Callback for when a message (request or response) is received over the connection. - * - * Includes the requestInfo and authInfo if the transport is authenticated. - * - * The requestInfo can be used to get the original request information (headers, etc.) - */ - onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + /** + * Callback for when a message (request or response) is received over the connection. + * + * Includes the requestInfo and authInfo if the transport is authenticated. + * + * The requestInfo can be used to get the original request information (headers, etc.) + */ + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; - /** - * The session ID generated for this connection. - */ - sessionId?: string; + /** + * The session ID generated for this connection. + */ + sessionId?: string; - /** - * Sets the protocol version used for the connection (called when the initialize response is received). - */ - setProtocolVersion?: (version: string) => void; + /** + * Sets the protocol version used for the connection (called when the initialize response is received). + */ + setProtocolVersion?: (version: string) => void; } diff --git a/src/shared/uriTemplate.test.ts b/src/shared/uriTemplate.test.ts index 8ec4fb736..043f9325d 100644 --- a/src/shared/uriTemplate.test.ts +++ b/src/shared/uriTemplate.test.ts @@ -1,276 +1,288 @@ -import { UriTemplate } from "./uriTemplate.js"; - -describe("UriTemplate", () => { - describe("isTemplate", () => { - it("should return true for strings containing template expressions", () => { - expect(UriTemplate.isTemplate("{foo}")).toBe(true); - expect(UriTemplate.isTemplate("/users/{id}")).toBe(true); - expect(UriTemplate.isTemplate("http://example.com/{path}/{file}")).toBe(true); - expect(UriTemplate.isTemplate("/search{?q,limit}")).toBe(true); +import { UriTemplate } from './uriTemplate.js'; + +describe('UriTemplate', () => { + describe('isTemplate', () => { + it('should return true for strings containing template expressions', () => { + expect(UriTemplate.isTemplate('{foo}')).toBe(true); + expect(UriTemplate.isTemplate('/users/{id}')).toBe(true); + expect(UriTemplate.isTemplate('http://example.com/{path}/{file}')).toBe(true); + expect(UriTemplate.isTemplate('/search{?q,limit}')).toBe(true); + }); + + it('should return false for strings without template expressions', () => { + expect(UriTemplate.isTemplate('')).toBe(false); + expect(UriTemplate.isTemplate('plain string')).toBe(false); + expect(UriTemplate.isTemplate('http://example.com/foo/bar')).toBe(false); + expect(UriTemplate.isTemplate('{}')).toBe(false); // Empty braces don't count + expect(UriTemplate.isTemplate('{ }')).toBe(false); // Just whitespace doesn't count + }); + }); + + describe('simple string expansion', () => { + it('should expand simple string variables', () => { + const template = new UriTemplate('http://example.com/users/{username}'); + expect(template.expand({ username: 'fred' })).toBe('http://example.com/users/fred'); + expect(template.variableNames).toEqual(['username']); + }); + + it('should handle multiple variables', () => { + const template = new UriTemplate('{x,y}'); + expect(template.expand({ x: '1024', y: '768' })).toBe('1024,768'); + expect(template.variableNames).toEqual(['x', 'y']); + }); + + it('should encode reserved characters', () => { + const template = new UriTemplate('{var}'); + expect(template.expand({ var: 'value with spaces' })).toBe('value%20with%20spaces'); + }); + }); + + describe('reserved expansion', () => { + it('should not encode reserved characters with + operator', () => { + const template = new UriTemplate('{+path}/here'); + expect(template.expand({ path: '/foo/bar' })).toBe('/foo/bar/here'); + expect(template.variableNames).toEqual(['path']); + }); + }); + + describe('fragment expansion', () => { + it('should add # prefix and not encode reserved chars', () => { + const template = new UriTemplate('X{#var}'); + expect(template.expand({ var: '/test' })).toBe('X#/test'); + expect(template.variableNames).toEqual(['var']); + }); + }); + + describe('label expansion', () => { + it('should add . prefix', () => { + const template = new UriTemplate('X{.var}'); + expect(template.expand({ var: 'test' })).toBe('X.test'); + expect(template.variableNames).toEqual(['var']); + }); + }); + + describe('path expansion', () => { + it('should add / prefix', () => { + const template = new UriTemplate('X{/var}'); + expect(template.expand({ var: 'test' })).toBe('X/test'); + expect(template.variableNames).toEqual(['var']); + }); + }); + + describe('query expansion', () => { + it('should add ? prefix and name=value format', () => { + const template = new UriTemplate('X{?var}'); + expect(template.expand({ var: 'test' })).toBe('X?var=test'); + expect(template.variableNames).toEqual(['var']); + }); + }); + + describe('form continuation expansion', () => { + it('should add & prefix and name=value format', () => { + const template = new UriTemplate('X{&var}'); + expect(template.expand({ var: 'test' })).toBe('X&var=test'); + expect(template.variableNames).toEqual(['var']); + }); + }); + + describe('matching', () => { + it('should match simple strings and extract variables', () => { + const template = new UriTemplate('http://example.com/users/{username}'); + const match = template.match('http://example.com/users/fred'); + expect(match).toEqual({ username: 'fred' }); + }); + + it('should match multiple variables', () => { + const template = new UriTemplate('/users/{username}/posts/{postId}'); + const match = template.match('/users/fred/posts/123'); + expect(match).toEqual({ username: 'fred', postId: '123' }); + }); + + it('should return null for non-matching URIs', () => { + const template = new UriTemplate('/users/{username}'); + const match = template.match('/posts/123'); + expect(match).toBeNull(); + }); + + it('should handle exploded arrays', () => { + const template = new UriTemplate('{/list*}'); + const match = template.match('/red,green,blue'); + expect(match).toEqual({ list: ['red', 'green', 'blue'] }); + }); + }); + + describe('edge cases', () => { + it('should handle empty variables', () => { + const template = new UriTemplate('{empty}'); + expect(template.expand({})).toBe(''); + expect(template.expand({ empty: '' })).toBe(''); + }); + + it('should handle undefined variables', () => { + const template = new UriTemplate('{a}{b}{c}'); + expect(template.expand({ b: '2' })).toBe('2'); + }); + + it('should handle special characters in variable names', () => { + const template = new UriTemplate('{$var_name}'); + expect(template.expand({ $var_name: 'value' })).toBe('value'); + }); + }); + + describe('complex patterns', () => { + it('should handle nested path segments', () => { + const template = new UriTemplate('/api/{version}/{resource}/{id}'); + expect( + template.expand({ + version: 'v1', + resource: 'users', + id: '123' + }) + ).toBe('/api/v1/users/123'); + expect(template.variableNames).toEqual(['version', 'resource', 'id']); + }); + + it('should handle query parameters with arrays', () => { + const template = new UriTemplate('/search{?tags*}'); + expect( + template.expand({ + tags: ['nodejs', 'typescript', 'testing'] + }) + ).toBe('/search?tags=nodejs,typescript,testing'); + expect(template.variableNames).toEqual(['tags']); + }); + + it('should handle multiple query parameters', () => { + const template = new UriTemplate('/search{?q,page,limit}'); + expect( + template.expand({ + q: 'test', + page: '1', + limit: '10' + }) + ).toBe('/search?q=test&page=1&limit=10'); + expect(template.variableNames).toEqual(['q', 'page', 'limit']); + }); + }); + + describe('matching complex patterns', () => { + it('should match nested path segments', () => { + const template = new UriTemplate('/api/{version}/{resource}/{id}'); + const match = template.match('/api/v1/users/123'); + expect(match).toEqual({ + version: 'v1', + resource: 'users', + id: '123' + }); + expect(template.variableNames).toEqual(['version', 'resource', 'id']); + }); + + it('should match query parameters', () => { + const template = new UriTemplate('/search{?q}'); + const match = template.match('/search?q=test'); + expect(match).toEqual({ q: 'test' }); + expect(template.variableNames).toEqual(['q']); + }); + + it('should match multiple query parameters', () => { + const template = new UriTemplate('/search{?q,page}'); + const match = template.match('/search?q=test&page=1'); + expect(match).toEqual({ q: 'test', page: '1' }); + expect(template.variableNames).toEqual(['q', 'page']); + }); + + it('should handle partial matches correctly', () => { + const template = new UriTemplate('/users/{id}'); + expect(template.match('/users/123/extra')).toBeNull(); + expect(template.match('/users')).toBeNull(); + }); + }); + + describe('security and edge cases', () => { + it('should handle extremely long input strings', () => { + const longString = 'x'.repeat(100000); + const template = new UriTemplate(`/api/{param}`); + expect(template.expand({ param: longString })).toBe(`/api/${longString}`); + expect(template.match(`/api/${longString}`)).toEqual({ param: longString }); + }); + + it('should handle deeply nested template expressions', () => { + const template = new UriTemplate('{a}{b}{c}{d}{e}{f}{g}{h}{i}{j}'.repeat(1000)); + expect(() => + template.expand({ + a: '1', + b: '2', + c: '3', + d: '4', + e: '5', + f: '6', + g: '7', + h: '8', + i: '9', + j: '0' + }) + ).not.toThrow(); + }); + + it('should handle malformed template expressions', () => { + expect(() => new UriTemplate('{unclosed')).toThrow(); + expect(() => new UriTemplate('{}')).not.toThrow(); + expect(() => new UriTemplate('{,}')).not.toThrow(); + expect(() => new UriTemplate('{a}{')).toThrow(); + }); + + it('should handle pathological regex patterns', () => { + const template = new UriTemplate('/api/{param}'); + // Create a string that could cause catastrophic backtracking + const input = '/api/' + 'a'.repeat(100000); + expect(() => template.match(input)).not.toThrow(); + }); + + it('should handle invalid UTF-8 sequences', () => { + const template = new UriTemplate('/api/{param}'); + const invalidUtf8 = '���'; + expect(() => template.expand({ param: invalidUtf8 })).not.toThrow(); + expect(() => template.match(`/api/${invalidUtf8}`)).not.toThrow(); + }); + + it('should handle template/URI length mismatches', () => { + const template = new UriTemplate('/api/{param}'); + expect(template.match('/api/')).toBeNull(); + expect(template.match('/api')).toBeNull(); + expect(template.match('/api/value/extra')).toBeNull(); + }); + + it('should handle repeated operators', () => { + const template = new UriTemplate('{?a}{?b}{?c}'); + expect(template.expand({ a: '1', b: '2', c: '3' })).toBe('?a=1&b=2&c=3'); + expect(template.variableNames).toEqual(['a', 'b', 'c']); + }); + + it('should handle overlapping variable names', () => { + const template = new UriTemplate('{var}{vara}'); + expect(template.expand({ var: '1', vara: '2' })).toBe('12'); + expect(template.variableNames).toEqual(['var', 'vara']); + }); + + it('should handle empty segments', () => { + const template = new UriTemplate('///{a}////{b}////'); + expect(template.expand({ a: '1', b: '2' })).toBe('///1////2////'); + expect(template.match('///1////2////')).toEqual({ a: '1', b: '2' }); + expect(template.variableNames).toEqual(['a', 'b']); + }); + + it('should handle maximum template expression limit', () => { + // Create a template with many expressions + const expressions = Array(10000).fill('{param}').join(''); + expect(() => new UriTemplate(expressions)).not.toThrow(); + }); + + it('should handle maximum variable name length', () => { + const longName = 'a'.repeat(10000); + const template = new UriTemplate(`{${longName}}`); + const vars: Record = {}; + vars[longName] = 'value'; + expect(() => template.expand(vars)).not.toThrow(); + }); }); - - it("should return false for strings without template expressions", () => { - expect(UriTemplate.isTemplate("")).toBe(false); - expect(UriTemplate.isTemplate("plain string")).toBe(false); - expect(UriTemplate.isTemplate("http://example.com/foo/bar")).toBe(false); - expect(UriTemplate.isTemplate("{}")).toBe(false); // Empty braces don't count - expect(UriTemplate.isTemplate("{ }")).toBe(false); // Just whitespace doesn't count - }); - }); - - describe("simple string expansion", () => { - it("should expand simple string variables", () => { - const template = new UriTemplate("http://example.com/users/{username}"); - expect(template.expand({ username: "fred" })).toBe( - "http://example.com/users/fred", - ); - expect(template.variableNames).toEqual(['username']) - }); - - it("should handle multiple variables", () => { - const template = new UriTemplate("{x,y}"); - expect(template.expand({ x: "1024", y: "768" })).toBe("1024,768"); - expect(template.variableNames).toEqual(['x', 'y']) - }); - - it("should encode reserved characters", () => { - const template = new UriTemplate("{var}"); - expect(template.expand({ var: "value with spaces" })).toBe( - "value%20with%20spaces", - ); - }); - }); - - describe("reserved expansion", () => { - it("should not encode reserved characters with + operator", () => { - const template = new UriTemplate("{+path}/here"); - expect(template.expand({ path: "/foo/bar" })).toBe("/foo/bar/here"); - expect(template.variableNames).toEqual(['path']) - }); - }); - - describe("fragment expansion", () => { - it("should add # prefix and not encode reserved chars", () => { - const template = new UriTemplate("X{#var}"); - expect(template.expand({ var: "/test" })).toBe("X#/test"); - expect(template.variableNames).toEqual(['var']) - }); - }); - - describe("label expansion", () => { - it("should add . prefix", () => { - const template = new UriTemplate("X{.var}"); - expect(template.expand({ var: "test" })).toBe("X.test"); - expect(template.variableNames).toEqual(['var']) - }); - }); - - describe("path expansion", () => { - it("should add / prefix", () => { - const template = new UriTemplate("X{/var}"); - expect(template.expand({ var: "test" })).toBe("X/test"); - expect(template.variableNames).toEqual(['var']) - }); - }); - - describe("query expansion", () => { - it("should add ? prefix and name=value format", () => { - const template = new UriTemplate("X{?var}"); - expect(template.expand({ var: "test" })).toBe("X?var=test"); - expect(template.variableNames).toEqual(['var']) - }); - }); - - describe("form continuation expansion", () => { - it("should add & prefix and name=value format", () => { - const template = new UriTemplate("X{&var}"); - expect(template.expand({ var: "test" })).toBe("X&var=test"); - expect(template.variableNames).toEqual(['var']) - }); - }); - - describe("matching", () => { - it("should match simple strings and extract variables", () => { - const template = new UriTemplate("http://example.com/users/{username}"); - const match = template.match("http://example.com/users/fred"); - expect(match).toEqual({ username: "fred" }); - }); - - it("should match multiple variables", () => { - const template = new UriTemplate("/users/{username}/posts/{postId}"); - const match = template.match("/users/fred/posts/123"); - expect(match).toEqual({ username: "fred", postId: "123" }); - }); - - it("should return null for non-matching URIs", () => { - const template = new UriTemplate("/users/{username}"); - const match = template.match("/posts/123"); - expect(match).toBeNull(); - }); - - it("should handle exploded arrays", () => { - const template = new UriTemplate("{/list*}"); - const match = template.match("/red,green,blue"); - expect(match).toEqual({ list: ["red", "green", "blue"] }); - }); - }); - - describe("edge cases", () => { - it("should handle empty variables", () => { - const template = new UriTemplate("{empty}"); - expect(template.expand({})).toBe(""); - expect(template.expand({ empty: "" })).toBe(""); - }); - - it("should handle undefined variables", () => { - const template = new UriTemplate("{a}{b}{c}"); - expect(template.expand({ b: "2" })).toBe("2"); - }); - - it("should handle special characters in variable names", () => { - const template = new UriTemplate("{$var_name}"); - expect(template.expand({ "$var_name": "value" })).toBe("value"); - }); - }); - - describe("complex patterns", () => { - it("should handle nested path segments", () => { - const template = new UriTemplate("/api/{version}/{resource}/{id}"); - expect(template.expand({ - version: "v1", - resource: "users", - id: "123" - })).toBe("/api/v1/users/123"); - expect(template.variableNames).toEqual(['version', 'resource', 'id']) - }); - - it("should handle query parameters with arrays", () => { - const template = new UriTemplate("/search{?tags*}"); - expect(template.expand({ - tags: ["nodejs", "typescript", "testing"] - })).toBe("/search?tags=nodejs,typescript,testing"); - expect(template.variableNames).toEqual(['tags']) - }); - - it("should handle multiple query parameters", () => { - const template = new UriTemplate("/search{?q,page,limit}"); - expect(template.expand({ - q: "test", - page: "1", - limit: "10" - })).toBe("/search?q=test&page=1&limit=10"); - expect(template.variableNames).toEqual(['q', 'page', 'limit']) - }); - }); - - describe("matching complex patterns", () => { - it("should match nested path segments", () => { - const template = new UriTemplate("/api/{version}/{resource}/{id}"); - const match = template.match("/api/v1/users/123"); - expect(match).toEqual({ - version: "v1", - resource: "users", - id: "123" - }); - expect(template.variableNames).toEqual(['version', 'resource', 'id']) - }); - - it("should match query parameters", () => { - const template = new UriTemplate("/search{?q}"); - const match = template.match("/search?q=test"); - expect(match).toEqual({ q: "test" }); - expect(template.variableNames).toEqual(['q']) - }); - - it("should match multiple query parameters", () => { - const template = new UriTemplate("/search{?q,page}"); - const match = template.match("/search?q=test&page=1"); - expect(match).toEqual({ q: "test", page: "1" }); - expect(template.variableNames).toEqual(['q', 'page']) - }); - - it("should handle partial matches correctly", () => { - const template = new UriTemplate("/users/{id}"); - expect(template.match("/users/123/extra")).toBeNull(); - expect(template.match("/users")).toBeNull(); - }); - }); - - describe("security and edge cases", () => { - it("should handle extremely long input strings", () => { - const longString = "x".repeat(100000); - const template = new UriTemplate(`/api/{param}`); - expect(template.expand({ param: longString })).toBe(`/api/${longString}`); - expect(template.match(`/api/${longString}`)).toEqual({ param: longString }); - }); - - it("should handle deeply nested template expressions", () => { - const template = new UriTemplate("{a}{b}{c}{d}{e}{f}{g}{h}{i}{j}".repeat(1000)); - expect(() => template.expand({ - a: "1", b: "2", c: "3", d: "4", e: "5", - f: "6", g: "7", h: "8", i: "9", j: "0" - })).not.toThrow(); - }); - - it("should handle malformed template expressions", () => { - expect(() => new UriTemplate("{unclosed")).toThrow(); - expect(() => new UriTemplate("{}")).not.toThrow(); - expect(() => new UriTemplate("{,}")).not.toThrow(); - expect(() => new UriTemplate("{a}{")).toThrow(); - }); - - it("should handle pathological regex patterns", () => { - const template = new UriTemplate("/api/{param}"); - // Create a string that could cause catastrophic backtracking - const input = "/api/" + "a".repeat(100000); - expect(() => template.match(input)).not.toThrow(); - }); - - it("should handle invalid UTF-8 sequences", () => { - const template = new UriTemplate("/api/{param}"); - const invalidUtf8 = "���"; - expect(() => template.expand({ param: invalidUtf8 })).not.toThrow(); - expect(() => template.match(`/api/${invalidUtf8}`)).not.toThrow(); - }); - - it("should handle template/URI length mismatches", () => { - const template = new UriTemplate("/api/{param}"); - expect(template.match("/api/")).toBeNull(); - expect(template.match("/api")).toBeNull(); - expect(template.match("/api/value/extra")).toBeNull(); - }); - - it("should handle repeated operators", () => { - const template = new UriTemplate("{?a}{?b}{?c}"); - expect(template.expand({ a: "1", b: "2", c: "3" })).toBe("?a=1&b=2&c=3"); - expect(template.variableNames).toEqual(['a', 'b', 'c']) - }); - - it("should handle overlapping variable names", () => { - const template = new UriTemplate("{var}{vara}"); - expect(template.expand({ var: "1", vara: "2" })).toBe("12"); - expect(template.variableNames).toEqual(['var', 'vara']) - }); - - it("should handle empty segments", () => { - const template = new UriTemplate("///{a}////{b}////"); - expect(template.expand({ a: "1", b: "2" })).toBe("///1////2////"); - expect(template.match("///1////2////")).toEqual({ a: "1", b: "2" }); - expect(template.variableNames).toEqual(['a', 'b']) - }); - - it("should handle maximum template expression limit", () => { - // Create a template with many expressions - const expressions = Array(10000).fill("{param}").join(""); - expect(() => new UriTemplate(expressions)).not.toThrow(); - }); - - it("should handle maximum variable name length", () => { - const longName = "a".repeat(10000); - const template = new UriTemplate(`{${longName}}`); - const vars: Record = {}; - vars[longName] = "value"; - expect(() => template.expand(vars)).not.toThrow(); - }); - }); }); diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts index 982589ac8..1dd57f56f 100644 --- a/src/shared/uriTemplate.ts +++ b/src/shared/uriTemplate.ts @@ -8,309 +8,280 @@ const MAX_TEMPLATE_EXPRESSIONS = 10000; const MAX_REGEX_LENGTH = 1000000; // 1MB export class UriTemplate { - /** - * Returns true if the given string contains any URI template expressions. - * A template expression is a sequence of characters enclosed in curly braces, - * like {foo} or {?bar}. - */ - static isTemplate(str: string): boolean { - // Look for any sequence of characters between curly braces - // that isn't just whitespace - return /\{[^}\s]+\}/.test(str); - } - - private static validateLength( - str: string, - max: number, - context: string, - ): void { - if (str.length > max) { - throw new Error( - `${context} exceeds maximum length of ${max} characters (got ${str.length})`, - ); + /** + * Returns true if the given string contains any URI template expressions. + * A template expression is a sequence of characters enclosed in curly braces, + * like {foo} or {?bar}. + */ + static isTemplate(str: string): boolean { + // Look for any sequence of characters between curly braces + // that isn't just whitespace + return /\{[^}\s]+\}/.test(str); } - } - private readonly template: string; - private readonly parts: Array< - | string - | { name: string; operator: string; names: string[]; exploded: boolean } - >; - - get variableNames(): string[] { - return this.parts.flatMap((part) => typeof part === 'string' ? [] : part.names); - } - - constructor(template: string) { - UriTemplate.validateLength(template, MAX_TEMPLATE_LENGTH, "Template"); - this.template = template; - this.parts = this.parse(template); - } - - toString(): string { - return this.template; - } - - private parse( - template: string, - ): Array< - | string - | { name: string; operator: string; names: string[]; exploded: boolean } - > { - const parts: Array< - | string - | { name: string; operator: string; names: string[]; exploded: boolean } - > = []; - let currentText = ""; - let i = 0; - let expressionCount = 0; - - while (i < template.length) { - if (template[i] === "{") { - if (currentText) { - parts.push(currentText); - currentText = ""; - } - const end = template.indexOf("}", i); - if (end === -1) throw new Error("Unclosed template expression"); - - expressionCount++; - if (expressionCount > MAX_TEMPLATE_EXPRESSIONS) { - throw new Error( - `Template contains too many expressions (max ${MAX_TEMPLATE_EXPRESSIONS})`, - ); - } - const expr = template.slice(i + 1, end); - const operator = this.getOperator(expr); - const exploded = expr.includes("*"); - const names = this.getNames(expr); - const name = names[0]; - - // Validate variable name length - for (const name of names) { - UriTemplate.validateLength( - name, - MAX_VARIABLE_LENGTH, - "Variable name", - ); + private static validateLength(str: string, max: number, context: string): void { + if (str.length > max) { + throw new Error(`${context} exceeds maximum length of ${max} characters (got ${str.length})`); } - - parts.push({ name, operator, names, exploded }); - i = end + 1; - } else { - currentText += template[i]; - i++; - } } + private readonly template: string; + private readonly parts: Array; - if (currentText) { - parts.push(currentText); + get variableNames(): string[] { + return this.parts.flatMap(part => (typeof part === 'string' ? [] : part.names)); } - return parts; - } - - private getOperator(expr: string): string { - const operators = ["+", "#", ".", "/", "?", "&"]; - return operators.find((op) => expr.startsWith(op)) || ""; - } - - private getNames(expr: string): string[] { - const operator = this.getOperator(expr); - return expr - .slice(operator.length) - .split(",") - .map((name) => name.replace("*", "").trim()) - .filter((name) => name.length > 0); - } - - private encodeValue(value: string, operator: string): string { - UriTemplate.validateLength(value, MAX_VARIABLE_LENGTH, "Variable value"); - if (operator === "+" || operator === "#") { - return encodeURI(value); + constructor(template: string) { + UriTemplate.validateLength(template, MAX_TEMPLATE_LENGTH, 'Template'); + this.template = template; + this.parts = this.parse(template); } - return encodeURIComponent(value); - } - - private expandPart( - part: { - name: string; - operator: string; - names: string[]; - exploded: boolean; - }, - variables: Variables, - ): string { - if (part.operator === "?" || part.operator === "&") { - const pairs = part.names - .map((name) => { - const value = variables[name]; - if (value === undefined) return ""; - const encoded = Array.isArray(value) - ? value.map((v) => this.encodeValue(v, part.operator)).join(",") - : this.encodeValue(value.toString(), part.operator); - return `${name}=${encoded}`; - }) - .filter((pair) => pair.length > 0); - - if (pairs.length === 0) return ""; - const separator = part.operator === "?" ? "?" : "&"; - return separator + pairs.join("&"); + + toString(): string { + return this.template; } - if (part.names.length > 1) { - const values = part.names - .map((name) => variables[name]) - .filter((v) => v !== undefined); - if (values.length === 0) return ""; - return values.map((v) => (Array.isArray(v) ? v[0] : v)).join(","); + private parse(template: string): Array { + const parts: Array = []; + let currentText = ''; + let i = 0; + let expressionCount = 0; + + while (i < template.length) { + if (template[i] === '{') { + if (currentText) { + parts.push(currentText); + currentText = ''; + } + const end = template.indexOf('}', i); + if (end === -1) throw new Error('Unclosed template expression'); + + expressionCount++; + if (expressionCount > MAX_TEMPLATE_EXPRESSIONS) { + throw new Error(`Template contains too many expressions (max ${MAX_TEMPLATE_EXPRESSIONS})`); + } + + const expr = template.slice(i + 1, end); + const operator = this.getOperator(expr); + const exploded = expr.includes('*'); + const names = this.getNames(expr); + const name = names[0]; + + // Validate variable name length + for (const name of names) { + UriTemplate.validateLength(name, MAX_VARIABLE_LENGTH, 'Variable name'); + } + + parts.push({ name, operator, names, exploded }); + i = end + 1; + } else { + currentText += template[i]; + i++; + } + } + + if (currentText) { + parts.push(currentText); + } + + return parts; } - const value = variables[part.name]; - if (value === undefined) return ""; - - const values = Array.isArray(value) ? value : [value]; - const encoded = values.map((v) => this.encodeValue(v, part.operator)); - - switch (part.operator) { - case "": - return encoded.join(","); - case "+": - return encoded.join(","); - case "#": - return "#" + encoded.join(","); - case ".": - return "." + encoded.join("."); - case "/": - return "/" + encoded.join("/"); - default: - return encoded.join(","); + private getOperator(expr: string): string { + const operators = ['+', '#', '.', '/', '?', '&']; + return operators.find(op => expr.startsWith(op)) || ''; } - } - - expand(variables: Variables): string { - let result = ""; - let hasQueryParam = false; - - for (const part of this.parts) { - if (typeof part === "string") { - result += part; - continue; - } - - const expanded = this.expandPart(part, variables); - if (!expanded) continue; - - // Convert ? to & if we already have a query parameter - if ((part.operator === "?" || part.operator === "&") && hasQueryParam) { - result += expanded.replace("?", "&"); - } else { - result += expanded; - } - - if (part.operator === "?" || part.operator === "&") { - hasQueryParam = true; - } + + private getNames(expr: string): string[] { + const operator = this.getOperator(expr); + return expr + .slice(operator.length) + .split(',') + .map(name => name.replace('*', '').trim()) + .filter(name => name.length > 0); } - return result; - } + private encodeValue(value: string, operator: string): string { + UriTemplate.validateLength(value, MAX_VARIABLE_LENGTH, 'Variable value'); + if (operator === '+' || operator === '#') { + return encodeURI(value); + } + return encodeURIComponent(value); + } - private escapeRegExp(str: string): string { - return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); - } + private expandPart( + part: { + name: string; + operator: string; + names: string[]; + exploded: boolean; + }, + variables: Variables + ): string { + if (part.operator === '?' || part.operator === '&') { + const pairs = part.names + .map(name => { + const value = variables[name]; + if (value === undefined) return ''; + const encoded = Array.isArray(value) + ? value.map(v => this.encodeValue(v, part.operator)).join(',') + : this.encodeValue(value.toString(), part.operator); + return `${name}=${encoded}`; + }) + .filter(pair => pair.length > 0); + + if (pairs.length === 0) return ''; + const separator = part.operator === '?' ? '?' : '&'; + return separator + pairs.join('&'); + } - private partToRegExp(part: { - name: string; - operator: string; - names: string[]; - exploded: boolean; - }): Array<{ pattern: string; name: string }> { - const patterns: Array<{ pattern: string; name: string }> = []; + if (part.names.length > 1) { + const values = part.names.map(name => variables[name]).filter(v => v !== undefined); + if (values.length === 0) return ''; + return values.map(v => (Array.isArray(v) ? v[0] : v)).join(','); + } - // Validate variable name length for matching - for (const name of part.names) { - UriTemplate.validateLength(name, MAX_VARIABLE_LENGTH, "Variable name"); + const value = variables[part.name]; + if (value === undefined) return ''; + + const values = Array.isArray(value) ? value : [value]; + const encoded = values.map(v => this.encodeValue(v, part.operator)); + + switch (part.operator) { + case '': + return encoded.join(','); + case '+': + return encoded.join(','); + case '#': + return '#' + encoded.join(','); + case '.': + return '.' + encoded.join('.'); + case '/': + return '/' + encoded.join('/'); + default: + return encoded.join(','); + } } - if (part.operator === "?" || part.operator === "&") { - for (let i = 0; i < part.names.length; i++) { - const name = part.names[i]; - const prefix = i === 0 ? "\\" + part.operator : "&"; - patterns.push({ - pattern: prefix + this.escapeRegExp(name) + "=([^&]+)", - name, - }); - } - return patterns; + expand(variables: Variables): string { + let result = ''; + let hasQueryParam = false; + + for (const part of this.parts) { + if (typeof part === 'string') { + result += part; + continue; + } + + const expanded = this.expandPart(part, variables); + if (!expanded) continue; + + // Convert ? to & if we already have a query parameter + if ((part.operator === '?' || part.operator === '&') && hasQueryParam) { + result += expanded.replace('?', '&'); + } else { + result += expanded; + } + + if (part.operator === '?' || part.operator === '&') { + hasQueryParam = true; + } + } + + return result; } - let pattern: string; - const name = part.name; - - switch (part.operator) { - case "": - pattern = part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)"; - break; - case "+": - case "#": - pattern = "(.+)"; - break; - case ".": - pattern = "\\.([^/,]+)"; - break; - case "/": - pattern = "/" + (part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)"); - break; - default: - pattern = "([^/]+)"; + private escapeRegExp(str: string): string { + return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); } - patterns.push({ pattern, name }); - return patterns; - } - - match(uri: string): Variables | null { - UriTemplate.validateLength(uri, MAX_TEMPLATE_LENGTH, "URI"); - let pattern = "^"; - const names: Array<{ name: string; exploded: boolean }> = []; - - for (const part of this.parts) { - if (typeof part === "string") { - pattern += this.escapeRegExp(part); - } else { - const patterns = this.partToRegExp(part); - for (const { pattern: partPattern, name } of patterns) { - pattern += partPattern; - names.push({ name, exploded: part.exploded }); + private partToRegExp(part: { + name: string; + operator: string; + names: string[]; + exploded: boolean; + }): Array<{ pattern: string; name: string }> { + const patterns: Array<{ pattern: string; name: string }> = []; + + // Validate variable name length for matching + for (const name of part.names) { + UriTemplate.validateLength(name, MAX_VARIABLE_LENGTH, 'Variable name'); + } + + if (part.operator === '?' || part.operator === '&') { + for (let i = 0; i < part.names.length; i++) { + const name = part.names[i]; + const prefix = i === 0 ? '\\' + part.operator : '&'; + patterns.push({ + pattern: prefix + this.escapeRegExp(name) + '=([^&]+)', + name + }); + } + return patterns; + } + + let pattern: string; + const name = part.name; + + switch (part.operator) { + case '': + pattern = part.exploded ? '([^/]+(?:,[^/]+)*)' : '([^/,]+)'; + break; + case '+': + case '#': + pattern = '(.+)'; + break; + case '.': + pattern = '\\.([^/,]+)'; + break; + case '/': + pattern = '/' + (part.exploded ? '([^/]+(?:,[^/]+)*)' : '([^/,]+)'); + break; + default: + pattern = '([^/]+)'; } - } - } - pattern += "$"; - UriTemplate.validateLength( - pattern, - MAX_REGEX_LENGTH, - "Generated regex pattern", - ); - const regex = new RegExp(pattern); - const match = uri.match(regex); - - if (!match) return null; - - const result: Variables = {}; - for (let i = 0; i < names.length; i++) { - const { name, exploded } = names[i]; - const value = match[i + 1]; - const cleanName = name.replace("*", ""); - - if (exploded && value.includes(",")) { - result[cleanName] = value.split(","); - } else { - result[cleanName] = value; - } + patterns.push({ pattern, name }); + return patterns; } - return result; - } + match(uri: string): Variables | null { + UriTemplate.validateLength(uri, MAX_TEMPLATE_LENGTH, 'URI'); + let pattern = '^'; + const names: Array<{ name: string; exploded: boolean }> = []; + + for (const part of this.parts) { + if (typeof part === 'string') { + pattern += this.escapeRegExp(part); + } else { + const patterns = this.partToRegExp(part); + for (const { pattern: partPattern, name } of patterns) { + pattern += partPattern; + names.push({ name, exploded: part.exploded }); + } + } + } + + pattern += '$'; + UriTemplate.validateLength(pattern, MAX_REGEX_LENGTH, 'Generated regex pattern'); + const regex = new RegExp(pattern); + const match = uri.match(regex); + + if (!match) return null; + + const result: Variables = {}; + for (let i = 0; i < names.length; i++) { + const { name, exploded } = names[i]; + const value = match[i + 1]; + const cleanName = name.replace('*', ''); + + if (exploded && value.includes(',')) { + result[cleanName] = value.split(','); + } else { + result[cleanName] = value; + } + } + + return result; + } } diff --git a/src/spec.types.test.ts b/src/spec.types.test.ts index 5aa497f4a..3c835431a 100644 --- a/src/spec.types.test.ts +++ b/src/spec.types.test.ts @@ -5,716 +5,493 @@ * - Runtime checks to verify each Spec type has a static check * (note: a few don't have SDK types, see MISSING_SDK_TYPES below) */ -import * as SDKTypes from "./types.js"; -import * as SpecTypes from "../spec.types.js"; -import fs from "node:fs"; +import * as SDKTypes from './types.js'; +import * as SpecTypes from '../spec.types.js'; +import fs from 'node:fs'; /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-unsafe-function-type */ // Removes index signatures added by ZodObject.passthrough(). type RemovePassthrough = T extends object - ? T extends Array - ? Array> - : T extends Function - ? T - : {[K in keyof T as string extends K ? never : K]: RemovePassthrough} + ? T extends Array + ? Array> + : T extends Function + ? T + : { [K in keyof T as string extends K ? never : K]: RemovePassthrough } : T; // Adds the `jsonrpc` property to a type, to match the on-wire format of notifications. -type WithJSONRPC = T & { jsonrpc: "2.0" }; +type WithJSONRPC = T & { jsonrpc: '2.0' }; // Adds the `jsonrpc` and `id` properties to a type, to match the on-wire format of requests. -type WithJSONRPCRequest = T & { jsonrpc: "2.0"; id: SDKTypes.RequestId }; +type WithJSONRPCRequest = T & { jsonrpc: '2.0'; id: SDKTypes.RequestId }; -type IsUnknown = [unknown] extends [T] ? [T] extends [unknown] ? true : false : false; +type IsUnknown = [unknown] extends [T] ? ([T] extends [unknown] ? true : false) : false; // Turns {x?: unknown} into {x: unknown} but keeps {_meta?: unknown} unchanged (and leaves other optional properties unchanged, e.g. {x?: string}). // This works around an apparent quirk of ZodObject.unknown() (makes fields optional) type MakeUnknownsNotOptional = - IsUnknown extends true - ? unknown - : (T extends object - ? (T extends Array - ? Array> - : (T extends Function - ? T - : Pick & { - // Start with empty object to avoid duplicates - // Make unknown properties required (except _meta) - [K in keyof T as '_meta' extends K ? never : IsUnknown extends true ? K : never]-?: unknown; - } & - Pick extends true ? never : K - }[keyof T]> & { - // Recurse on the picked properties - [K in keyof Pick extends true ? never : K}[keyof T]>]: MakeUnknownsNotOptional - })) - : T); + IsUnknown extends true + ? unknown + : T extends object + ? T extends Array + ? Array> + : T extends Function + ? T + : Pick & { + // Start with empty object to avoid duplicates + // Make unknown properties required (except _meta) + [K in keyof T as '_meta' extends K ? never : IsUnknown extends true ? K : never]-?: unknown; + } & Pick< + T, + { + // Pick all _meta and non-unknown properties with original modifiers + [K in keyof T]: '_meta' extends K ? K : IsUnknown extends true ? never : K; + }[keyof T] + > & { + // Recurse on the picked properties + [K in keyof Pick< + T, + { [K in keyof T]: '_meta' extends K ? K : IsUnknown extends true ? never : K }[keyof T] + >]: MakeUnknownsNotOptional; + } + : T; -function checkCancelledNotification( - sdk: WithJSONRPC, - spec: SpecTypes.CancelledNotification -) { - sdk = spec; - spec = sdk; +function checkCancelledNotification(sdk: WithJSONRPC, spec: SpecTypes.CancelledNotification) { + sdk = spec; + spec = sdk; } -function checkBaseMetadata( - sdk: RemovePassthrough, - spec: SpecTypes.BaseMetadata -) { - sdk = spec; - spec = sdk; +function checkBaseMetadata(sdk: RemovePassthrough, spec: SpecTypes.BaseMetadata) { + sdk = spec; + spec = sdk; } -function checkImplementation( - sdk: RemovePassthrough, - spec: SpecTypes.Implementation -) { - sdk = spec; - spec = sdk; +function checkImplementation(sdk: RemovePassthrough, spec: SpecTypes.Implementation) { + sdk = spec; + spec = sdk; } -function checkProgressNotification( - sdk: WithJSONRPC, - spec: SpecTypes.ProgressNotification -) { - sdk = spec; - spec = sdk; +function checkProgressNotification(sdk: WithJSONRPC, spec: SpecTypes.ProgressNotification) { + sdk = spec; + spec = sdk; } -function checkSubscribeRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.SubscribeRequest -) { - sdk = spec; - spec = sdk; +function checkSubscribeRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.SubscribeRequest) { + sdk = spec; + spec = sdk; } -function checkUnsubscribeRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.UnsubscribeRequest -) { - sdk = spec; - spec = sdk; +function checkUnsubscribeRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.UnsubscribeRequest) { + sdk = spec; + spec = sdk; } -function checkPaginatedRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.PaginatedRequest -) { - sdk = spec; - spec = sdk; +function checkPaginatedRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.PaginatedRequest) { + sdk = spec; + spec = sdk; } -function checkPaginatedResult( - sdk: SDKTypes.PaginatedResult, - spec: SpecTypes.PaginatedResult -) { - sdk = spec; - spec = sdk; +function checkPaginatedResult(sdk: SDKTypes.PaginatedResult, spec: SpecTypes.PaginatedResult) { + sdk = spec; + spec = sdk; } -function checkListRootsRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.ListRootsRequest -) { - sdk = spec; - spec = sdk; +function checkListRootsRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.ListRootsRequest) { + sdk = spec; + spec = sdk; } -function checkListRootsResult( - sdk: RemovePassthrough, - spec: SpecTypes.ListRootsResult -) { - sdk = spec; - spec = sdk; +function checkListRootsResult(sdk: RemovePassthrough, spec: SpecTypes.ListRootsResult) { + sdk = spec; + spec = sdk; } -function checkRoot( - sdk: RemovePassthrough, - spec: SpecTypes.Root -) { - sdk = spec; - spec = sdk; +function checkRoot(sdk: RemovePassthrough, spec: SpecTypes.Root) { + sdk = spec; + spec = sdk; } -function checkElicitRequest( - sdk: WithJSONRPCRequest>, - spec: SpecTypes.ElicitRequest -) { - sdk = spec; - spec = sdk; +function checkElicitRequest(sdk: WithJSONRPCRequest>, spec: SpecTypes.ElicitRequest) { + sdk = spec; + spec = sdk; } -function checkElicitResult( - sdk: RemovePassthrough, - spec: SpecTypes.ElicitResult -) { - sdk = spec; - spec = sdk; +function checkElicitResult(sdk: RemovePassthrough, spec: SpecTypes.ElicitResult) { + sdk = spec; + spec = sdk; } -function checkCompleteRequest( - sdk: WithJSONRPCRequest>, - spec: SpecTypes.CompleteRequest -) { - sdk = spec; - spec = sdk; +function checkCompleteRequest(sdk: WithJSONRPCRequest>, spec: SpecTypes.CompleteRequest) { + sdk = spec; + spec = sdk; } -function checkCompleteResult( - sdk: SDKTypes.CompleteResult, - spec: SpecTypes.CompleteResult -) { - sdk = spec; - spec = sdk; +function checkCompleteResult(sdk: SDKTypes.CompleteResult, spec: SpecTypes.CompleteResult) { + sdk = spec; + spec = sdk; } -function checkProgressToken( - sdk: SDKTypes.ProgressToken, - spec: SpecTypes.ProgressToken -) { - sdk = spec; - spec = sdk; +function checkProgressToken(sdk: SDKTypes.ProgressToken, spec: SpecTypes.ProgressToken) { + sdk = spec; + spec = sdk; } -function checkCursor( - sdk: SDKTypes.Cursor, - spec: SpecTypes.Cursor -) { - sdk = spec; - spec = sdk; +function checkCursor(sdk: SDKTypes.Cursor, spec: SpecTypes.Cursor) { + sdk = spec; + spec = sdk; } -function checkRequest( - sdk: SDKTypes.Request, - spec: SpecTypes.Request -) { - sdk = spec; - spec = sdk; +function checkRequest(sdk: SDKTypes.Request, spec: SpecTypes.Request) { + sdk = spec; + spec = sdk; } -function checkResult( - sdk: SDKTypes.Result, - spec: SpecTypes.Result -) { - sdk = spec; - spec = sdk; +function checkResult(sdk: SDKTypes.Result, spec: SpecTypes.Result) { + sdk = spec; + spec = sdk; } -function checkRequestId( - sdk: SDKTypes.RequestId, - spec: SpecTypes.RequestId -) { - sdk = spec; - spec = sdk; +function checkRequestId(sdk: SDKTypes.RequestId, spec: SpecTypes.RequestId) { + sdk = spec; + spec = sdk; } -function checkJSONRPCRequest( - sdk: SDKTypes.JSONRPCRequest, - spec: SpecTypes.JSONRPCRequest -) { - sdk = spec; - spec = sdk; +function checkJSONRPCRequest(sdk: SDKTypes.JSONRPCRequest, spec: SpecTypes.JSONRPCRequest) { + sdk = spec; + spec = sdk; } -function checkJSONRPCNotification( - sdk: SDKTypes.JSONRPCNotification, - spec: SpecTypes.JSONRPCNotification -) { - sdk = spec; - spec = sdk; +function checkJSONRPCNotification(sdk: SDKTypes.JSONRPCNotification, spec: SpecTypes.JSONRPCNotification) { + sdk = spec; + spec = sdk; } -function checkJSONRPCResponse( - sdk: SDKTypes.JSONRPCResponse, - spec: SpecTypes.JSONRPCResponse -) { - sdk = spec; - spec = sdk; +function checkJSONRPCResponse(sdk: SDKTypes.JSONRPCResponse, spec: SpecTypes.JSONRPCResponse) { + sdk = spec; + spec = sdk; } -function checkEmptyResult( - sdk: SDKTypes.EmptyResult, - spec: SpecTypes.EmptyResult -) { - sdk = spec; - spec = sdk; +function checkEmptyResult(sdk: SDKTypes.EmptyResult, spec: SpecTypes.EmptyResult) { + sdk = spec; + spec = sdk; } -function checkNotification( - sdk: SDKTypes.Notification, - spec: SpecTypes.Notification -) { - sdk = spec; - spec = sdk; +function checkNotification(sdk: SDKTypes.Notification, spec: SpecTypes.Notification) { + sdk = spec; + spec = sdk; } -function checkClientResult( - sdk: SDKTypes.ClientResult, - spec: SpecTypes.ClientResult -) { - sdk = spec; - spec = sdk; +function checkClientResult(sdk: SDKTypes.ClientResult, spec: SpecTypes.ClientResult) { + sdk = spec; + spec = sdk; } -function checkClientNotification( - sdk: WithJSONRPC, - spec: SpecTypes.ClientNotification -) { - sdk = spec; - spec = sdk; +function checkClientNotification(sdk: WithJSONRPC, spec: SpecTypes.ClientNotification) { + sdk = spec; + spec = sdk; } -function checkServerResult( - sdk: SDKTypes.ServerResult, - spec: SpecTypes.ServerResult -) { - sdk = spec; - spec = sdk; +function checkServerResult(sdk: SDKTypes.ServerResult, spec: SpecTypes.ServerResult) { + sdk = spec; + spec = sdk; } function checkResourceTemplateReference( - sdk: RemovePassthrough, - spec: SpecTypes.ResourceTemplateReference + sdk: RemovePassthrough, + spec: SpecTypes.ResourceTemplateReference ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } -function checkPromptReference( - sdk: RemovePassthrough, - spec: SpecTypes.PromptReference -) { - sdk = spec; - spec = sdk; +function checkPromptReference(sdk: RemovePassthrough, spec: SpecTypes.PromptReference) { + sdk = spec; + spec = sdk; } -function checkToolAnnotations( - sdk: RemovePassthrough, - spec: SpecTypes.ToolAnnotations -) { - sdk = spec; - spec = sdk; +function checkToolAnnotations(sdk: RemovePassthrough, spec: SpecTypes.ToolAnnotations) { + sdk = spec; + spec = sdk; } -function checkTool( - sdk: RemovePassthrough, - spec: SpecTypes.Tool -) { - sdk = spec; - spec = sdk; +function checkTool(sdk: RemovePassthrough, spec: SpecTypes.Tool) { + sdk = spec; + spec = sdk; } -function checkListToolsRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.ListToolsRequest -) { - sdk = spec; - spec = sdk; +function checkListToolsRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.ListToolsRequest) { + sdk = spec; + spec = sdk; } -function checkListToolsResult( - sdk: RemovePassthrough, - spec: SpecTypes.ListToolsResult -) { - sdk = spec; - spec = sdk; +function checkListToolsResult(sdk: RemovePassthrough, spec: SpecTypes.ListToolsResult) { + sdk = spec; + spec = sdk; } -function checkCallToolResult( - sdk: RemovePassthrough, - spec: SpecTypes.CallToolResult -) { - sdk = spec; - spec = sdk; +function checkCallToolResult(sdk: RemovePassthrough, spec: SpecTypes.CallToolResult) { + sdk = spec; + spec = sdk; } -function checkCallToolRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.CallToolRequest -) { - sdk = spec; - spec = sdk; +function checkCallToolRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.CallToolRequest) { + sdk = spec; + spec = sdk; } function checkToolListChangedNotification( - sdk: WithJSONRPC, - spec: SpecTypes.ToolListChangedNotification + sdk: WithJSONRPC, + spec: SpecTypes.ToolListChangedNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } function checkResourceListChangedNotification( - sdk: WithJSONRPC, - spec: SpecTypes.ResourceListChangedNotification + sdk: WithJSONRPC, + spec: SpecTypes.ResourceListChangedNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } function checkPromptListChangedNotification( - sdk: WithJSONRPC, - spec: SpecTypes.PromptListChangedNotification + sdk: WithJSONRPC, + spec: SpecTypes.PromptListChangedNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } function checkRootsListChangedNotification( - sdk: WithJSONRPC, - spec: SpecTypes.RootsListChangedNotification + sdk: WithJSONRPC, + spec: SpecTypes.RootsListChangedNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } function checkResourceUpdatedNotification( - sdk: WithJSONRPC, - spec: SpecTypes.ResourceUpdatedNotification + sdk: WithJSONRPC, + spec: SpecTypes.ResourceUpdatedNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } -function checkSamplingMessage( - sdk: RemovePassthrough, - spec: SpecTypes.SamplingMessage -) { - sdk = spec; - spec = sdk; +function checkSamplingMessage(sdk: RemovePassthrough, spec: SpecTypes.SamplingMessage) { + sdk = spec; + spec = sdk; } -function checkCreateMessageResult( - sdk: RemovePassthrough, - spec: SpecTypes.CreateMessageResult -) { - sdk = spec; - spec = sdk; +function checkCreateMessageResult(sdk: RemovePassthrough, spec: SpecTypes.CreateMessageResult) { + sdk = spec; + spec = sdk; } -function checkSetLevelRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.SetLevelRequest -) { - sdk = spec; - spec = sdk; +function checkSetLevelRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.SetLevelRequest) { + sdk = spec; + spec = sdk; } -function checkPingRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.PingRequest -) { - sdk = spec; - spec = sdk; +function checkPingRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.PingRequest) { + sdk = spec; + spec = sdk; } -function checkInitializedNotification( - sdk: WithJSONRPC, - spec: SpecTypes.InitializedNotification -) { - sdk = spec; - spec = sdk; +function checkInitializedNotification(sdk: WithJSONRPC, spec: SpecTypes.InitializedNotification) { + sdk = spec; + spec = sdk; } -function checkListResourcesRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.ListResourcesRequest -) { - sdk = spec; - spec = sdk; +function checkListResourcesRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.ListResourcesRequest) { + sdk = spec; + spec = sdk; } -function checkListResourcesResult( - sdk: RemovePassthrough, - spec: SpecTypes.ListResourcesResult -) { - sdk = spec; - spec = sdk; +function checkListResourcesResult(sdk: RemovePassthrough, spec: SpecTypes.ListResourcesResult) { + sdk = spec; + spec = sdk; } function checkListResourceTemplatesRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.ListResourceTemplatesRequest + sdk: WithJSONRPCRequest, + spec: SpecTypes.ListResourceTemplatesRequest ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } function checkListResourceTemplatesResult( - sdk: RemovePassthrough, - spec: SpecTypes.ListResourceTemplatesResult + sdk: RemovePassthrough, + spec: SpecTypes.ListResourceTemplatesResult ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } -function checkReadResourceRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.ReadResourceRequest -) { - sdk = spec; - spec = sdk; +function checkReadResourceRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.ReadResourceRequest) { + sdk = spec; + spec = sdk; } -function checkReadResourceResult( - sdk: RemovePassthrough, - spec: SpecTypes.ReadResourceResult -) { - sdk = spec; - spec = sdk; +function checkReadResourceResult(sdk: RemovePassthrough, spec: SpecTypes.ReadResourceResult) { + sdk = spec; + spec = sdk; } -function checkResourceContents( - sdk: RemovePassthrough, - spec: SpecTypes.ResourceContents -) { - sdk = spec; - spec = sdk; +function checkResourceContents(sdk: RemovePassthrough, spec: SpecTypes.ResourceContents) { + sdk = spec; + spec = sdk; } -function checkTextResourceContents( - sdk: RemovePassthrough, - spec: SpecTypes.TextResourceContents -) { - sdk = spec; - spec = sdk; +function checkTextResourceContents(sdk: RemovePassthrough, spec: SpecTypes.TextResourceContents) { + sdk = spec; + spec = sdk; } -function checkBlobResourceContents( - sdk: RemovePassthrough, - spec: SpecTypes.BlobResourceContents -) { - sdk = spec; - spec = sdk; +function checkBlobResourceContents(sdk: RemovePassthrough, spec: SpecTypes.BlobResourceContents) { + sdk = spec; + spec = sdk; } -function checkResource( - sdk: RemovePassthrough, - spec: SpecTypes.Resource -) { - sdk = spec; - spec = sdk; +function checkResource(sdk: RemovePassthrough, spec: SpecTypes.Resource) { + sdk = spec; + spec = sdk; } -function checkResourceTemplate( - sdk: RemovePassthrough, - spec: SpecTypes.ResourceTemplate -) { - sdk = spec; - spec = sdk; +function checkResourceTemplate(sdk: RemovePassthrough, spec: SpecTypes.ResourceTemplate) { + sdk = spec; + spec = sdk; } -function checkPromptArgument( - sdk: RemovePassthrough, - spec: SpecTypes.PromptArgument -) { - sdk = spec; - spec = sdk; +function checkPromptArgument(sdk: RemovePassthrough, spec: SpecTypes.PromptArgument) { + sdk = spec; + spec = sdk; } -function checkPrompt( - sdk: RemovePassthrough, - spec: SpecTypes.Prompt -) { - sdk = spec; - spec = sdk; +function checkPrompt(sdk: RemovePassthrough, spec: SpecTypes.Prompt) { + sdk = spec; + spec = sdk; } -function checkListPromptsRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.ListPromptsRequest -) { - sdk = spec; - spec = sdk; +function checkListPromptsRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.ListPromptsRequest) { + sdk = spec; + spec = sdk; } -function checkListPromptsResult( - sdk: RemovePassthrough, - spec: SpecTypes.ListPromptsResult -) { - sdk = spec; - spec = sdk; +function checkListPromptsResult(sdk: RemovePassthrough, spec: SpecTypes.ListPromptsResult) { + sdk = spec; + spec = sdk; } -function checkGetPromptRequest( - sdk: WithJSONRPCRequest, - spec: SpecTypes.GetPromptRequest -) { - sdk = spec; - spec = sdk; +function checkGetPromptRequest(sdk: WithJSONRPCRequest, spec: SpecTypes.GetPromptRequest) { + sdk = spec; + spec = sdk; } -function checkTextContent( - sdk: RemovePassthrough, - spec: SpecTypes.TextContent -) { - sdk = spec; - spec = sdk; +function checkTextContent(sdk: RemovePassthrough, spec: SpecTypes.TextContent) { + sdk = spec; + spec = sdk; } -function checkImageContent( - sdk: RemovePassthrough, - spec: SpecTypes.ImageContent -) { - sdk = spec; - spec = sdk; +function checkImageContent(sdk: RemovePassthrough, spec: SpecTypes.ImageContent) { + sdk = spec; + spec = sdk; } -function checkAudioContent( - sdk: RemovePassthrough, - spec: SpecTypes.AudioContent -) { - sdk = spec; - spec = sdk; +function checkAudioContent(sdk: RemovePassthrough, spec: SpecTypes.AudioContent) { + sdk = spec; + spec = sdk; } -function checkEmbeddedResource( - sdk: RemovePassthrough, - spec: SpecTypes.EmbeddedResource -) { - sdk = spec; - spec = sdk; +function checkEmbeddedResource(sdk: RemovePassthrough, spec: SpecTypes.EmbeddedResource) { + sdk = spec; + spec = sdk; } -function checkResourceLink( - sdk: RemovePassthrough, - spec: SpecTypes.ResourceLink -) { - sdk = spec; - spec = sdk; +function checkResourceLink(sdk: RemovePassthrough, spec: SpecTypes.ResourceLink) { + sdk = spec; + spec = sdk; } -function checkContentBlock( - sdk: RemovePassthrough, - spec: SpecTypes.ContentBlock -) { - sdk = spec; - spec = sdk; +function checkContentBlock(sdk: RemovePassthrough, spec: SpecTypes.ContentBlock) { + sdk = spec; + spec = sdk; } -function checkPromptMessage( - sdk: RemovePassthrough, - spec: SpecTypes.PromptMessage -) { - sdk = spec; - spec = sdk; +function checkPromptMessage(sdk: RemovePassthrough, spec: SpecTypes.PromptMessage) { + sdk = spec; + spec = sdk; } -function checkGetPromptResult( - sdk: RemovePassthrough, - spec: SpecTypes.GetPromptResult -) { - sdk = spec; - spec = sdk; +function checkGetPromptResult(sdk: RemovePassthrough, spec: SpecTypes.GetPromptResult) { + sdk = spec; + spec = sdk; } -function checkBooleanSchema( - sdk: RemovePassthrough, - spec: SpecTypes.BooleanSchema -) { - sdk = spec; - spec = sdk; +function checkBooleanSchema(sdk: RemovePassthrough, spec: SpecTypes.BooleanSchema) { + sdk = spec; + spec = sdk; } -function checkStringSchema( - sdk: RemovePassthrough, - spec: SpecTypes.StringSchema -) { - sdk = spec; - spec = sdk; +function checkStringSchema(sdk: RemovePassthrough, spec: SpecTypes.StringSchema) { + sdk = spec; + spec = sdk; } -function checkNumberSchema( - sdk: RemovePassthrough, - spec: SpecTypes.NumberSchema -) { - sdk = spec; - spec = sdk; +function checkNumberSchema(sdk: RemovePassthrough, spec: SpecTypes.NumberSchema) { + sdk = spec; + spec = sdk; } -function checkEnumSchema( - sdk: RemovePassthrough, - spec: SpecTypes.EnumSchema -) { - sdk = spec; - spec = sdk; +function checkEnumSchema(sdk: RemovePassthrough, spec: SpecTypes.EnumSchema) { + sdk = spec; + spec = sdk; } function checkPrimitiveSchemaDefinition( - sdk: RemovePassthrough, - spec: SpecTypes.PrimitiveSchemaDefinition + sdk: RemovePassthrough, + spec: SpecTypes.PrimitiveSchemaDefinition ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } -function checkJSONRPCError( - sdk: SDKTypes.JSONRPCError, - spec: SpecTypes.JSONRPCError -) { - sdk = spec; - spec = sdk; +function checkJSONRPCError(sdk: SDKTypes.JSONRPCError, spec: SpecTypes.JSONRPCError) { + sdk = spec; + spec = sdk; } -function checkJSONRPCMessage( - sdk: SDKTypes.JSONRPCMessage, - spec: SpecTypes.JSONRPCMessage -) { - sdk = spec; - spec = sdk; +function checkJSONRPCMessage(sdk: SDKTypes.JSONRPCMessage, spec: SpecTypes.JSONRPCMessage) { + sdk = spec; + spec = sdk; } function checkCreateMessageRequest( - sdk: WithJSONRPCRequest>, - spec: SpecTypes.CreateMessageRequest + sdk: WithJSONRPCRequest>, + spec: SpecTypes.CreateMessageRequest ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } -function checkInitializeRequest( - sdk: WithJSONRPCRequest>, - spec: SpecTypes.InitializeRequest -) { - sdk = spec; - spec = sdk; +function checkInitializeRequest(sdk: WithJSONRPCRequest>, spec: SpecTypes.InitializeRequest) { + sdk = spec; + spec = sdk; } -function checkInitializeResult( - sdk: RemovePassthrough, - spec: SpecTypes.InitializeResult -) { - sdk = spec; - spec = sdk; +function checkInitializeResult(sdk: RemovePassthrough, spec: SpecTypes.InitializeResult) { + sdk = spec; + spec = sdk; } -function checkClientCapabilities( - sdk: RemovePassthrough, - spec: SpecTypes.ClientCapabilities -) { - sdk = spec; - spec = sdk; +function checkClientCapabilities(sdk: RemovePassthrough, spec: SpecTypes.ClientCapabilities) { + sdk = spec; + spec = sdk; } -function checkServerCapabilities( - sdk: RemovePassthrough, - spec: SpecTypes.ServerCapabilities -) { - sdk = spec; - spec = sdk; +function checkServerCapabilities(sdk: RemovePassthrough, spec: SpecTypes.ServerCapabilities) { + sdk = spec; + spec = sdk; } -function checkClientRequest( - sdk: WithJSONRPCRequest>, - spec: SpecTypes.ClientRequest -) { - sdk = spec; - spec = sdk; +function checkClientRequest(sdk: WithJSONRPCRequest>, spec: SpecTypes.ClientRequest) { + sdk = spec; + spec = sdk; } -function checkServerRequest( - sdk: WithJSONRPCRequest>, - spec: SpecTypes.ServerRequest -) { - sdk = spec; - spec = sdk; +function checkServerRequest(sdk: WithJSONRPCRequest>, spec: SpecTypes.ServerRequest) { + sdk = spec; + spec = sdk; } function checkLoggingMessageNotification( - sdk: MakeUnknownsNotOptional>, - spec: SpecTypes.LoggingMessageNotification + sdk: MakeUnknownsNotOptional>, + spec: SpecTypes.LoggingMessageNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } function checkServerNotification( - sdk: MakeUnknownsNotOptional>, - spec: SpecTypes.ServerNotification + sdk: MakeUnknownsNotOptional>, + spec: SpecTypes.ServerNotification ) { - sdk = spec; - spec = sdk; + sdk = spec; + spec = sdk; } -function checkLoggingLevel( - sdk: SDKTypes.LoggingLevel, - spec: SpecTypes.LoggingLevel -) { - sdk = spec; - spec = sdk; +function checkLoggingLevel(sdk: SDKTypes.LoggingLevel, spec: SpecTypes.LoggingLevel) { + sdk = spec; + spec = sdk; } -function checkIcon( - sdk: RemovePassthrough, - spec: SpecTypes.Icon -) { - sdk = spec; - spec = sdk; +function checkIcon(sdk: RemovePassthrough, spec: SpecTypes.Icon) { + sdk = spec; + spec = sdk; } // This file is .gitignore'd, and fetched by `npm run fetch:spec-types` (called by `npm run test`) -const SPEC_TYPES_FILE = 'spec.types.ts'; -const SDK_TYPES_FILE = 'src/types.ts'; +const SPEC_TYPES_FILE = 'spec.types.ts'; +const SDK_TYPES_FILE = 'src/types.ts'; const MISSING_SDK_TYPES = [ - // These are inlined in the SDK: - 'Role', - 'Error', // The inner error object of a JSONRPCError + // These are inlined in the SDK: + 'Role', + 'Error', // The inner error object of a JSONRPCError - // These aren't supported by the SDK yet: - // TODO: Add definitions to the SDK - 'Annotations', - 'ModelHint', - 'ModelPreferences', - 'Icons', -] + // These aren't supported by the SDK yet: + // TODO: Add definitions to the SDK + 'Annotations', + 'ModelHint', + 'ModelPreferences', + 'Icons' +]; function extractExportedTypes(source: string): string[] { - return [...source.matchAll(/export\s+(?:interface|class|type)\s+(\w+)\b/g)].map(m => m[1]); + return [...source.matchAll(/export\s+(?:interface|class|type)\s+(\w+)\b/g)].map(m => m[1]); } describe('Spec Types', () => { - const specTypes = extractExportedTypes(fs.readFileSync(SPEC_TYPES_FILE, 'utf-8')); - const sdkTypes = extractExportedTypes(fs.readFileSync(SDK_TYPES_FILE, 'utf-8')); - const testSource = fs.readFileSync(__filename, 'utf-8'); + const specTypes = extractExportedTypes(fs.readFileSync(SPEC_TYPES_FILE, 'utf-8')); + const sdkTypes = extractExportedTypes(fs.readFileSync(SDK_TYPES_FILE, 'utf-8')); + const testSource = fs.readFileSync(__filename, 'utf-8'); - it('should define some expected types', () => { - expect(specTypes).toContain('JSONRPCNotification'); - expect(specTypes).toContain('ElicitResult'); - expect(specTypes).toHaveLength(94); - }); + it('should define some expected types', () => { + expect(specTypes).toContain('JSONRPCNotification'); + expect(specTypes).toContain('ElicitResult'); + expect(specTypes).toHaveLength(94); + }); - it('should have up to date list of missing sdk types', () => { - for (const typeName of MISSING_SDK_TYPES) { - expect(sdkTypes).not.toContain(typeName); - } - }); + it('should have up to date list of missing sdk types', () => { + for (const typeName of MISSING_SDK_TYPES) { + expect(sdkTypes).not.toContain(typeName); + } + }); - for (const type of specTypes) { - if (MISSING_SDK_TYPES.includes(type)) { - continue; // Skip missing SDK types + for (const type of specTypes) { + if (MISSING_SDK_TYPES.includes(type)) { + continue; // Skip missing SDK types + } + it(`${type} should have a compatibility test`, () => { + expect(testSource).toContain(`function check${type}(`); + }); } - it(`${type} should have a compatibility test`, () => { - expect(testSource).toContain(`function check${type}(`); - }); - } }); diff --git a/src/types.test.ts b/src/types.test.ts index 0aee62a93..cd8cc0711 100644 --- a/src/types.test.ts +++ b/src/types.test.ts @@ -6,76 +6,75 @@ import { PromptMessageSchema, CallToolResultSchema, CompleteRequestSchema -} from "./types.js"; +} from './types.js'; -describe("Types", () => { - - test("should have correct latest protocol version", () => { +describe('Types', () => { + test('should have correct latest protocol version', () => { expect(LATEST_PROTOCOL_VERSION).toBeDefined(); - expect(LATEST_PROTOCOL_VERSION).toBe("2025-06-18"); + expect(LATEST_PROTOCOL_VERSION).toBe('2025-06-18'); }); - test("should have correct supported protocol versions", () => { + test('should have correct supported protocol versions', () => { expect(SUPPORTED_PROTOCOL_VERSIONS).toBeDefined(); expect(SUPPORTED_PROTOCOL_VERSIONS).toBeInstanceOf(Array); expect(SUPPORTED_PROTOCOL_VERSIONS).toContain(LATEST_PROTOCOL_VERSION); - expect(SUPPORTED_PROTOCOL_VERSIONS).toContain("2024-11-05"); - expect(SUPPORTED_PROTOCOL_VERSIONS).toContain("2024-10-07"); - expect(SUPPORTED_PROTOCOL_VERSIONS).toContain("2025-03-26"); + expect(SUPPORTED_PROTOCOL_VERSIONS).toContain('2024-11-05'); + expect(SUPPORTED_PROTOCOL_VERSIONS).toContain('2024-10-07'); + expect(SUPPORTED_PROTOCOL_VERSIONS).toContain('2025-03-26'); }); - describe("ResourceLink", () => { - test("should validate a minimal ResourceLink", () => { + describe('ResourceLink', () => { + test('should validate a minimal ResourceLink', () => { const resourceLink = { - type: "resource_link", - uri: "file:///path/to/file.txt", - name: "file.txt" + type: 'resource_link', + uri: 'file:///path/to/file.txt', + name: 'file.txt' }; const result = ResourceLinkSchema.safeParse(resourceLink); expect(result.success).toBe(true); if (result.success) { - expect(result.data.type).toBe("resource_link"); - expect(result.data.uri).toBe("file:///path/to/file.txt"); - expect(result.data.name).toBe("file.txt"); + expect(result.data.type).toBe('resource_link'); + expect(result.data.uri).toBe('file:///path/to/file.txt'); + expect(result.data.name).toBe('file.txt'); } }); - test("should validate a ResourceLink with all optional fields", () => { + test('should validate a ResourceLink with all optional fields', () => { const resourceLink = { - type: "resource_link", - uri: "https://example.com/resource", - name: "Example Resource", - title: "A comprehensive example resource", - description: "This resource demonstrates all fields", - mimeType: "text/plain", - _meta: { custom: "metadata" } + type: 'resource_link', + uri: 'https://example.com/resource', + name: 'Example Resource', + title: 'A comprehensive example resource', + description: 'This resource demonstrates all fields', + mimeType: 'text/plain', + _meta: { custom: 'metadata' } }; const result = ResourceLinkSchema.safeParse(resourceLink); expect(result.success).toBe(true); if (result.success) { - expect(result.data.title).toBe("A comprehensive example resource"); - expect(result.data.description).toBe("This resource demonstrates all fields"); - expect(result.data.mimeType).toBe("text/plain"); - expect(result.data._meta).toEqual({ custom: "metadata" }); + expect(result.data.title).toBe('A comprehensive example resource'); + expect(result.data.description).toBe('This resource demonstrates all fields'); + expect(result.data.mimeType).toBe('text/plain'); + expect(result.data._meta).toEqual({ custom: 'metadata' }); } }); - test("should fail validation for invalid type", () => { + test('should fail validation for invalid type', () => { const invalidResourceLink = { - type: "invalid_type", - uri: "file:///path/to/file.txt", - name: "file.txt" + type: 'invalid_type', + uri: 'file:///path/to/file.txt', + name: 'file.txt' }; const result = ResourceLinkSchema.safeParse(invalidResourceLink); expect(result.success).toBe(false); }); - test("should fail validation for missing required fields", () => { + test('should fail validation for missing required fields', () => { const invalidResourceLink = { - type: "resource_link", - uri: "file:///path/to/file.txt" + type: 'resource_link', + uri: 'file:///path/to/file.txt' // missing name }; @@ -84,123 +83,123 @@ describe("Types", () => { }); }); - describe("ContentBlock", () => { - test("should validate text content", () => { + describe('ContentBlock', () => { + test('should validate text content', () => { const textContent = { - type: "text", - text: "Hello, world!" + type: 'text', + text: 'Hello, world!' }; const result = ContentBlockSchema.safeParse(textContent); expect(result.success).toBe(true); if (result.success) { - expect(result.data.type).toBe("text"); + expect(result.data.type).toBe('text'); } }); - test("should validate image content", () => { + test('should validate image content', () => { const imageContent = { - type: "image", - data: "aGVsbG8=", // base64 encoded "hello" - mimeType: "image/png" + type: 'image', + data: 'aGVsbG8=', // base64 encoded "hello" + mimeType: 'image/png' }; const result = ContentBlockSchema.safeParse(imageContent); expect(result.success).toBe(true); if (result.success) { - expect(result.data.type).toBe("image"); + expect(result.data.type).toBe('image'); } }); - test("should validate audio content", () => { + test('should validate audio content', () => { const audioContent = { - type: "audio", - data: "aGVsbG8=", // base64 encoded "hello" - mimeType: "audio/mp3" + type: 'audio', + data: 'aGVsbG8=', // base64 encoded "hello" + mimeType: 'audio/mp3' }; const result = ContentBlockSchema.safeParse(audioContent); expect(result.success).toBe(true); if (result.success) { - expect(result.data.type).toBe("audio"); + expect(result.data.type).toBe('audio'); } }); - test("should validate resource link content", () => { + test('should validate resource link content', () => { const resourceLink = { - type: "resource_link", - uri: "file:///path/to/file.txt", - name: "file.txt", - mimeType: "text/plain" + type: 'resource_link', + uri: 'file:///path/to/file.txt', + name: 'file.txt', + mimeType: 'text/plain' }; const result = ContentBlockSchema.safeParse(resourceLink); expect(result.success).toBe(true); if (result.success) { - expect(result.data.type).toBe("resource_link"); + expect(result.data.type).toBe('resource_link'); } }); - test("should validate embedded resource content", () => { + test('should validate embedded resource content', () => { const embeddedResource = { - type: "resource", + type: 'resource', resource: { - uri: "file:///path/to/file.txt", - mimeType: "text/plain", - text: "File contents" + uri: 'file:///path/to/file.txt', + mimeType: 'text/plain', + text: 'File contents' } }; const result = ContentBlockSchema.safeParse(embeddedResource); expect(result.success).toBe(true); if (result.success) { - expect(result.data.type).toBe("resource"); + expect(result.data.type).toBe('resource'); } }); }); - describe("PromptMessage with ContentBlock", () => { - test("should validate prompt message with resource link", () => { + describe('PromptMessage with ContentBlock', () => { + test('should validate prompt message with resource link', () => { const promptMessage = { - role: "assistant", + role: 'assistant', content: { - type: "resource_link", - uri: "file:///project/src/main.rs", - name: "main.rs", - description: "Primary application entry point", - mimeType: "text/x-rust" + type: 'resource_link', + uri: 'file:///project/src/main.rs', + name: 'main.rs', + description: 'Primary application entry point', + mimeType: 'text/x-rust' } }; const result = PromptMessageSchema.safeParse(promptMessage); expect(result.success).toBe(true); if (result.success) { - expect(result.data.content.type).toBe("resource_link"); + expect(result.data.content.type).toBe('resource_link'); } }); }); - describe("CallToolResult with ContentBlock", () => { - test("should validate tool result with resource links", () => { + describe('CallToolResult with ContentBlock', () => { + test('should validate tool result with resource links', () => { const toolResult = { content: [ { - type: "text", - text: "Found the following files:" + type: 'text', + text: 'Found the following files:' }, { - type: "resource_link", - uri: "file:///project/src/main.rs", - name: "main.rs", - description: "Primary application entry point", - mimeType: "text/x-rust" + type: 'resource_link', + uri: 'file:///project/src/main.rs', + name: 'main.rs', + description: 'Primary application entry point', + mimeType: 'text/x-rust' }, { - type: "resource_link", - uri: "file:///project/src/lib.rs", - name: "lib.rs", - description: "Library exports", - mimeType: "text/x-rust" + type: 'resource_link', + uri: 'file:///project/src/lib.rs', + name: 'lib.rs', + description: 'Library exports', + mimeType: 'text/x-rust' } ] }; @@ -209,13 +208,13 @@ describe("Types", () => { expect(result.success).toBe(true); if (result.success) { expect(result.data.content).toHaveLength(3); - expect(result.data.content[0].type).toBe("text"); - expect(result.data.content[1].type).toBe("resource_link"); - expect(result.data.content[2].type).toBe("resource_link"); + expect(result.data.content[0].type).toBe('text'); + expect(result.data.content[1].type).toBe('resource_link'); + expect(result.data.content[2].type).toBe('resource_link'); } }); - test("should validate empty content array with default", () => { + test('should validate empty content array with default', () => { const toolResult = {}; const result = CallToolResultSchema.safeParse(toolResult); @@ -226,34 +225,34 @@ describe("Types", () => { }); }); - describe("CompleteRequest", () => { - test("should validate a CompleteRequest without resolved field", () => { + describe('CompleteRequest', () => { + test('should validate a CompleteRequest without resolved field', () => { const request = { - method: "completion/complete", + method: 'completion/complete', params: { - ref: { type: "ref/prompt", name: "greeting" }, - argument: { name: "name", value: "A" } + ref: { type: 'ref/prompt', name: 'greeting' }, + argument: { name: 'name', value: 'A' } } }; const result = CompleteRequestSchema.safeParse(request); expect(result.success).toBe(true); if (result.success) { - expect(result.data.method).toBe("completion/complete"); - expect(result.data.params.ref.type).toBe("ref/prompt"); + expect(result.data.method).toBe('completion/complete'); + expect(result.data.params.ref.type).toBe('ref/prompt'); expect(result.data.params.context).toBeUndefined(); } }); - test("should validate a CompleteRequest with resolved field", () => { + test('should validate a CompleteRequest with resolved field', () => { const request = { - method: "completion/complete", + method: 'completion/complete', params: { - ref: { type: "ref/resource", uri: "github://repos/{owner}/{repo}" }, - argument: { name: "repo", value: "t" }, + ref: { type: 'ref/resource', uri: 'github://repos/{owner}/{repo}' }, + argument: { name: 'repo', value: 't' }, context: { arguments: { - "{owner}": "microsoft" + '{owner}': 'microsoft' } } } @@ -263,17 +262,17 @@ describe("Types", () => { expect(result.success).toBe(true); if (result.success) { expect(result.data.params.context?.arguments).toEqual({ - "{owner}": "microsoft" + '{owner}': 'microsoft' }); } }); - test("should validate a CompleteRequest with empty resolved field", () => { + test('should validate a CompleteRequest with empty resolved field', () => { const request = { - method: "completion/complete", + method: 'completion/complete', params: { - ref: { type: "ref/prompt", name: "test" }, - argument: { name: "arg", value: "" }, + ref: { type: 'ref/prompt', name: 'test' }, + argument: { name: 'arg', value: '' }, context: { arguments: {} } @@ -287,16 +286,16 @@ describe("Types", () => { } }); - test("should validate a CompleteRequest with multiple resolved variables", () => { + test('should validate a CompleteRequest with multiple resolved variables', () => { const request = { - method: "completion/complete", + method: 'completion/complete', params: { - ref: { type: "ref/resource", uri: "api://v1/{tenant}/{resource}/{id}" }, - argument: { name: "id", value: "123" }, + ref: { type: 'ref/resource', uri: 'api://v1/{tenant}/{resource}/{id}' }, + argument: { name: 'id', value: '123' }, context: { arguments: { - "{tenant}": "acme-corp", - "{resource}": "users" + '{tenant}': 'acme-corp', + '{resource}': 'users' } } } @@ -306,8 +305,8 @@ describe("Types", () => { expect(result.success).toBe(true); if (result.success) { expect(result.data.params.context?.arguments).toEqual({ - "{tenant}": "acme-corp", - "{resource}": "users" + '{tenant}': 'acme-corp', + '{resource}': 'users' }); } }); diff --git a/src/types.ts b/src/types.ts index 262e3b623..67466ed7f 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,17 +1,12 @@ -import { z, ZodTypeAny } from "zod"; -import { AuthInfo } from "./server/auth/types.js"; - -export const LATEST_PROTOCOL_VERSION = "2025-06-18"; -export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = "2025-03-26"; -export const SUPPORTED_PROTOCOL_VERSIONS = [ - LATEST_PROTOCOL_VERSION, - "2025-03-26", - "2024-11-05", - "2024-10-07", -]; +import { z, ZodTypeAny } from 'zod'; +import { AuthInfo } from './server/auth/types.js'; + +export const LATEST_PROTOCOL_VERSION = '2025-06-18'; +export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = '2025-03-26'; +export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, '2025-03-26', '2024-11-05', '2024-10-07']; /* JSON-RPC types */ -export const JSONRPC_VERSION = "2.0"; +export const JSONRPC_VERSION = '2.0'; /** * A progress token, used to associate progress notifications with the original request. @@ -24,49 +19,49 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); export const CursorSchema = z.string(); const RequestMetaSchema = z - .object({ - /** - * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. - */ - progressToken: z.optional(ProgressTokenSchema), - }) - .passthrough(); + .object({ + /** + * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. + */ + progressToken: z.optional(ProgressTokenSchema) + }) + .passthrough(); const BaseRequestParamsSchema = z - .object({ - _meta: z.optional(RequestMetaSchema), - }) - .passthrough(); + .object({ + _meta: z.optional(RequestMetaSchema) + }) + .passthrough(); export const RequestSchema = z.object({ - method: z.string(), - params: z.optional(BaseRequestParamsSchema), + method: z.string(), + params: z.optional(BaseRequestParamsSchema) }); const BaseNotificationParamsSchema = z - .object({ - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + .object({ + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); export const NotificationSchema = z.object({ - method: z.string(), - params: z.optional(BaseNotificationParamsSchema), + method: z.string(), + params: z.optional(BaseNotificationParamsSchema) }); export const ResultSchema = z - .object({ - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + .object({ + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); /** * A uniquely identifying ID for a request in JSON-RPC. @@ -77,94 +72,83 @@ export const RequestIdSchema = z.union([z.string(), z.number().int()]); * A request that expects a response. */ export const JSONRPCRequestSchema = z - .object({ - jsonrpc: z.literal(JSONRPC_VERSION), - id: RequestIdSchema, - }) - .merge(RequestSchema) - .strict(); + .object({ + jsonrpc: z.literal(JSONRPC_VERSION), + id: RequestIdSchema + }) + .merge(RequestSchema) + .strict(); -export const isJSONRPCRequest = (value: unknown): value is JSONRPCRequest => - JSONRPCRequestSchema.safeParse(value).success; +export const isJSONRPCRequest = (value: unknown): value is JSONRPCRequest => JSONRPCRequestSchema.safeParse(value).success; /** * A notification which does not expect a response. */ export const JSONRPCNotificationSchema = z - .object({ - jsonrpc: z.literal(JSONRPC_VERSION), - }) - .merge(NotificationSchema) - .strict(); + .object({ + jsonrpc: z.literal(JSONRPC_VERSION) + }) + .merge(NotificationSchema) + .strict(); -export const isJSONRPCNotification = ( - value: unknown -): value is JSONRPCNotification => - JSONRPCNotificationSchema.safeParse(value).success; +export const isJSONRPCNotification = (value: unknown): value is JSONRPCNotification => JSONRPCNotificationSchema.safeParse(value).success; /** * A successful (non-error) response to a request. */ export const JSONRPCResponseSchema = z - .object({ - jsonrpc: z.literal(JSONRPC_VERSION), - id: RequestIdSchema, - result: ResultSchema, - }) - .strict(); + .object({ + jsonrpc: z.literal(JSONRPC_VERSION), + id: RequestIdSchema, + result: ResultSchema + }) + .strict(); -export const isJSONRPCResponse = (value: unknown): value is JSONRPCResponse => - JSONRPCResponseSchema.safeParse(value).success; +export const isJSONRPCResponse = (value: unknown): value is JSONRPCResponse => JSONRPCResponseSchema.safeParse(value).success; /** * Error codes defined by the JSON-RPC specification. */ export enum ErrorCode { - // SDK error codes - ConnectionClosed = -32000, - RequestTimeout = -32001, - - // Standard JSON-RPC error codes - ParseError = -32700, - InvalidRequest = -32600, - MethodNotFound = -32601, - InvalidParams = -32602, - InternalError = -32603, + // SDK error codes + ConnectionClosed = -32000, + RequestTimeout = -32001, + + // Standard JSON-RPC error codes + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603 } /** * A response to a request that indicates an error occurred. */ export const JSONRPCErrorSchema = z - .object({ - jsonrpc: z.literal(JSONRPC_VERSION), - id: RequestIdSchema, - error: z.object({ - /** - * The error type that occurred. - */ - code: z.number().int(), - /** - * A short description of the error. The message SHOULD be limited to a concise single sentence. - */ - message: z.string(), - /** - * Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.). - */ - data: z.optional(z.unknown()), - }), - }) - .strict(); - -export const isJSONRPCError = (value: unknown): value is JSONRPCError => - JSONRPCErrorSchema.safeParse(value).success; - -export const JSONRPCMessageSchema = z.union([ - JSONRPCRequestSchema, - JSONRPCNotificationSchema, - JSONRPCResponseSchema, - JSONRPCErrorSchema, -]); + .object({ + jsonrpc: z.literal(JSONRPC_VERSION), + id: RequestIdSchema, + error: z.object({ + /** + * The error type that occurred. + */ + code: z.number().int(), + /** + * A short description of the error. The message SHOULD be limited to a concise single sentence. + */ + message: z.string(), + /** + * Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.). + */ + data: z.optional(z.unknown()) + }) + }) + .strict(); + +export const isJSONRPCError = (value: unknown): value is JSONRPCError => JSONRPCErrorSchema.safeParse(value).success; + +export const JSONRPCMessageSchema = z.union([JSONRPCRequestSchema, JSONRPCNotificationSchema, JSONRPCResponseSchema, JSONRPCErrorSchema]); /* Empty result */ /** @@ -183,20 +167,20 @@ export const EmptyResultSchema = ResultSchema.strict(); * A client MUST NOT attempt to cancel its `initialize` request. */ export const CancelledNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/cancelled"), - params: BaseNotificationParamsSchema.extend({ - /** - * The ID of the request to cancel. - * - * This MUST correspond to the ID of a request previously issued in the same direction. - */ - requestId: RequestIdSchema, + method: z.literal('notifications/cancelled'), + params: BaseNotificationParamsSchema.extend({ + /** + * The ID of the request to cancel. + * + * This MUST correspond to the ID of a request previously issued in the same direction. + */ + requestId: RequestIdSchema, - /** - * An optional string describing the reason for the cancellation. This MAY be logged or presented to the user. - */ - reason: z.string().optional(), - }), + /** + * An optional string describing the reason for the cancellation. This MAY be logged or presented to the user. + */ + reason: z.string().optional() + }) }); /* Base Metadata */ @@ -204,262 +188,260 @@ export const CancelledNotificationSchema = NotificationSchema.extend({ * Icon schema for use in tools, prompts, resources, and implementations. */ export const IconSchema = z - .object({ - /** - * URL or data URI for the icon. - */ - src: z.string(), - /** - * Optional MIME type for the icon. - */ - mimeType: z.optional(z.string()), - /** - * Optional string specifying icon dimensions (e.g., "48x48 96x96"). - */ - sizes: z.optional(z.string()), - }) - .passthrough(); + .object({ + /** + * URL or data URI for the icon. + */ + src: z.string(), + /** + * Optional MIME type for the icon. + */ + mimeType: z.optional(z.string()), + /** + * Optional string specifying icon dimensions (e.g., "48x48 96x96"). + */ + sizes: z.optional(z.string()) + }) + .passthrough(); /** * Base metadata interface for common properties across resources, tools, prompts, and implementations. */ export const BaseMetadataSchema = z - .object({ - /** Intended for programmatic or logical use, but used as a display name in past specs or fallback */ - name: z.string(), - /** - * Intended for UI and end-user contexts — optimized to be human-readable and easily understood, - * even by those unfamiliar with domain-specific terminology. - * - * If not provided, the name should be used for display (except for Tool, - * where `annotations.title` should be given precedence over using `name`, - * if present). - */ - title: z.optional(z.string()), - }) - .passthrough(); + .object({ + /** Intended for programmatic or logical use, but used as a display name in past specs or fallback */ + name: z.string(), + /** + * Intended for UI and end-user contexts — optimized to be human-readable and easily understood, + * even by those unfamiliar with domain-specific terminology. + * + * If not provided, the name should be used for display (except for Tool, + * where `annotations.title` should be given precedence over using `name`, + * if present). + */ + title: z.optional(z.string()) + }) + .passthrough(); /* Initialization */ /** * Describes the name and version of an MCP implementation. */ export const ImplementationSchema = BaseMetadataSchema.extend({ - version: z.string(), - /** - * An optional URL of the website for this implementation. - */ - websiteUrl: z.optional(z.string()), - /** - * An optional list of icons for this implementation. - * This can be used by clients to display the implementation in a user interface. - * Each icon should have a `kind` property that specifies whether it is a data representation or a URL source, a `src` property that points to the icon file or data representation, and may also include a `mimeType` and `sizes` property. - * The `mimeType` property should be a valid MIME type for the icon file, such as "image/png" or "image/svg+xml". - * The `sizes` property should be a string that specifies one or more sizes at which the icon file can be used, such as "48x48" or "any" for scalable formats like SVG. - * The `sizes` property is optional, and if not provided, the client should assume that the icon can be used at any size. - */ - icons: z.optional(z.array(IconSchema)), + version: z.string(), + /** + * An optional URL of the website for this implementation. + */ + websiteUrl: z.optional(z.string()), + /** + * An optional list of icons for this implementation. + * This can be used by clients to display the implementation in a user interface. + * Each icon should have a `kind` property that specifies whether it is a data representation or a URL source, a `src` property that points to the icon file or data representation, and may also include a `mimeType` and `sizes` property. + * The `mimeType` property should be a valid MIME type for the icon file, such as "image/png" or "image/svg+xml". + * The `sizes` property should be a string that specifies one or more sizes at which the icon file can be used, such as "48x48" or "any" for scalable formats like SVG. + * The `sizes` property is optional, and if not provided, the client should assume that the icon can be used at any size. + */ + icons: z.optional(z.array(IconSchema)) }); /** * Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities. */ export const ClientCapabilitiesSchema = z - .object({ - /** - * Experimental, non-standard capabilities that the client supports. - */ - experimental: z.optional(z.object({}).passthrough()), - /** - * Present if the client supports sampling from an LLM. - */ - sampling: z.optional(z.object({}).passthrough()), - /** - * Present if the client supports eliciting user input. - */ - elicitation: z.optional(z.object({}).passthrough()), - /** - * Present if the client supports listing roots. - */ - roots: z.optional( - z - .object({ - /** - * Whether the client supports issuing notifications for changes to the roots list. - */ - listChanged: z.optional(z.boolean()), - }) - .passthrough(), - ), - }) - .passthrough(); + .object({ + /** + * Experimental, non-standard capabilities that the client supports. + */ + experimental: z.optional(z.object({}).passthrough()), + /** + * Present if the client supports sampling from an LLM. + */ + sampling: z.optional(z.object({}).passthrough()), + /** + * Present if the client supports eliciting user input. + */ + elicitation: z.optional(z.object({}).passthrough()), + /** + * Present if the client supports listing roots. + */ + roots: z.optional( + z + .object({ + /** + * Whether the client supports issuing notifications for changes to the roots list. + */ + listChanged: z.optional(z.boolean()) + }) + .passthrough() + ) + }) + .passthrough(); /** * This request is sent from the client to the server when it first connects, asking it to begin initialization. */ export const InitializeRequestSchema = RequestSchema.extend({ - method: z.literal("initialize"), - params: BaseRequestParamsSchema.extend({ - /** - * The latest version of the Model Context Protocol that the client supports. The client MAY decide to support older versions as well. - */ - protocolVersion: z.string(), - capabilities: ClientCapabilitiesSchema, - clientInfo: ImplementationSchema, - }), + method: z.literal('initialize'), + params: BaseRequestParamsSchema.extend({ + /** + * The latest version of the Model Context Protocol that the client supports. The client MAY decide to support older versions as well. + */ + protocolVersion: z.string(), + capabilities: ClientCapabilitiesSchema, + clientInfo: ImplementationSchema + }) }); -export const isInitializeRequest = (value: unknown): value is InitializeRequest => - InitializeRequestSchema.safeParse(value).success; - +export const isInitializeRequest = (value: unknown): value is InitializeRequest => InitializeRequestSchema.safeParse(value).success; /** * Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities. */ export const ServerCapabilitiesSchema = z - .object({ - /** - * Experimental, non-standard capabilities that the server supports. - */ - experimental: z.optional(z.object({}).passthrough()), - /** - * Present if the server supports sending log messages to the client. - */ - logging: z.optional(z.object({}).passthrough()), - /** - * Present if the server supports sending completions to the client. - */ - completions: z.optional(z.object({}).passthrough()), - /** - * Present if the server offers any prompt templates. - */ - prompts: z.optional( - z - .object({ - /** - * Whether this server supports issuing notifications for changes to the prompt list. - */ - listChanged: z.optional(z.boolean()), - }) - .passthrough(), - ), - /** - * Present if the server offers any resources to read. - */ - resources: z.optional( - z - .object({ - /** - * Whether this server supports clients subscribing to resource updates. - */ - subscribe: z.optional(z.boolean()), - - /** - * Whether this server supports issuing notifications for changes to the resource list. - */ - listChanged: z.optional(z.boolean()), - }) - .passthrough(), - ), - /** - * Present if the server offers any tools to call. - */ - tools: z.optional( - z - .object({ - /** - * Whether this server supports issuing notifications for changes to the tool list. - */ - listChanged: z.optional(z.boolean()), - }) - .passthrough(), - ), - }) - .passthrough(); + .object({ + /** + * Experimental, non-standard capabilities that the server supports. + */ + experimental: z.optional(z.object({}).passthrough()), + /** + * Present if the server supports sending log messages to the client. + */ + logging: z.optional(z.object({}).passthrough()), + /** + * Present if the server supports sending completions to the client. + */ + completions: z.optional(z.object({}).passthrough()), + /** + * Present if the server offers any prompt templates. + */ + prompts: z.optional( + z + .object({ + /** + * Whether this server supports issuing notifications for changes to the prompt list. + */ + listChanged: z.optional(z.boolean()) + }) + .passthrough() + ), + /** + * Present if the server offers any resources to read. + */ + resources: z.optional( + z + .object({ + /** + * Whether this server supports clients subscribing to resource updates. + */ + subscribe: z.optional(z.boolean()), + + /** + * Whether this server supports issuing notifications for changes to the resource list. + */ + listChanged: z.optional(z.boolean()) + }) + .passthrough() + ), + /** + * Present if the server offers any tools to call. + */ + tools: z.optional( + z + .object({ + /** + * Whether this server supports issuing notifications for changes to the tool list. + */ + listChanged: z.optional(z.boolean()) + }) + .passthrough() + ) + }) + .passthrough(); /** * After receiving an initialize request from the client, the server sends this response. */ export const InitializeResultSchema = ResultSchema.extend({ - /** - * The version of the Model Context Protocol that the server wants to use. This may not match the version that the client requested. If the client cannot support this version, it MUST disconnect. - */ - protocolVersion: z.string(), - capabilities: ServerCapabilitiesSchema, - serverInfo: ImplementationSchema, - /** - * Instructions describing how to use the server and its features. - * - * This can be used by clients to improve the LLM's understanding of available tools, resources, etc. It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt. - */ - instructions: z.optional(z.string()), + /** + * The version of the Model Context Protocol that the server wants to use. This may not match the version that the client requested. If the client cannot support this version, it MUST disconnect. + */ + protocolVersion: z.string(), + capabilities: ServerCapabilitiesSchema, + serverInfo: ImplementationSchema, + /** + * Instructions describing how to use the server and its features. + * + * This can be used by clients to improve the LLM's understanding of available tools, resources, etc. It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt. + */ + instructions: z.optional(z.string()) }); /** * This notification is sent from the client to the server after initialization has finished. */ export const InitializedNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/initialized"), + method: z.literal('notifications/initialized') }); export const isInitializedNotification = (value: unknown): value is InitializedNotification => - InitializedNotificationSchema.safeParse(value).success; + InitializedNotificationSchema.safeParse(value).success; /* Ping */ /** * A ping, issued by either the server or the client, to check that the other party is still alive. The receiver must promptly respond, or else may be disconnected. */ export const PingRequestSchema = RequestSchema.extend({ - method: z.literal("ping"), + method: z.literal('ping') }); /* Progress notifications */ export const ProgressSchema = z - .object({ - /** - * The progress thus far. This should increase every time progress is made, even if the total is unknown. - */ - progress: z.number(), - /** - * Total number of items to process (or total progress required), if known. - */ - total: z.optional(z.number()), - /** - * An optional message describing the current progress. - */ - message: z.optional(z.string()), - }) - .passthrough(); + .object({ + /** + * The progress thus far. This should increase every time progress is made, even if the total is unknown. + */ + progress: z.number(), + /** + * Total number of items to process (or total progress required), if known. + */ + total: z.optional(z.number()), + /** + * An optional message describing the current progress. + */ + message: z.optional(z.string()) + }) + .passthrough(); /** * An out-of-band notification used to inform the receiver of a progress update for a long-running request. */ export const ProgressNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/progress"), - params: BaseNotificationParamsSchema.merge(ProgressSchema).extend({ - /** - * The progress token which was given in the initial request, used to associate this notification with the request that is proceeding. - */ - progressToken: ProgressTokenSchema, - }), + method: z.literal('notifications/progress'), + params: BaseNotificationParamsSchema.merge(ProgressSchema).extend({ + /** + * The progress token which was given in the initial request, used to associate this notification with the request that is proceeding. + */ + progressToken: ProgressTokenSchema + }) }); /* Pagination */ export const PaginatedRequestSchema = RequestSchema.extend({ - params: BaseRequestParamsSchema.extend({ - /** - * An opaque token representing the current pagination position. - * If provided, the server should return results starting after this cursor. - */ - cursor: z.optional(CursorSchema), - }).optional(), + params: BaseRequestParamsSchema.extend({ + /** + * An opaque token representing the current pagination position. + * If provided, the server should return results starting after this cursor. + */ + cursor: z.optional(CursorSchema) + }).optional() }); export const PaginatedResultSchema = ResultSchema.extend({ - /** - * An opaque token representing the pagination position after the last returned result. - * If present, there may be more results available. - */ - nextCursor: z.optional(CursorSchema), + /** + * An opaque token representing the pagination position after the last returned result. + * If present, there may be more results available. + */ + nextCursor: z.optional(CursorSchema) }); /* Resources */ @@ -467,38 +449,37 @@ export const PaginatedResultSchema = ResultSchema.extend({ * The contents of a specific resource or sub-resource. */ export const ResourceContentsSchema = z - .object({ - /** - * The URI of this resource. - */ - uri: z.string(), - /** - * The MIME type of this resource, if known. - */ - mimeType: z.optional(z.string()), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + .object({ + /** + * The URI of this resource. + */ + uri: z.string(), + /** + * The MIME type of this resource, if known. + */ + mimeType: z.optional(z.string()), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); export const TextResourceContentsSchema = ResourceContentsSchema.extend({ - /** - * The text of the item. This must only be set if the item can actually be represented as text (not binary data). - */ - text: z.string(), + /** + * The text of the item. This must only be set if the item can actually be represented as text (not binary data). + */ + text: z.string() }); - /** * A Zod schema for validating Base64 strings that is more performant and * robust for very large inputs than the default regex-based check. It avoids * stack overflows by using the native `atob` function for validation. */ const Base64Schema = z.string().refine( - (val) => { + val => { try { // atob throws a DOMException if the string contains characters // that are not part of the Base64 character set. @@ -508,173 +489,169 @@ const Base64Schema = z.string().refine( return false; } }, - { message: "Invalid Base64 string" }, + { message: 'Invalid Base64 string' } ); export const BlobResourceContentsSchema = ResourceContentsSchema.extend({ - /** - * A base64-encoded string representing the binary data of the item. - */ - blob: Base64Schema, + /** + * A base64-encoded string representing the binary data of the item. + */ + blob: Base64Schema }); /** * A known resource that the server is capable of reading. */ export const ResourceSchema = BaseMetadataSchema.extend({ - /** - * The URI of this resource. - */ - uri: z.string(), - - /** - * A description of what this resource represents. - * - * This can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a "hint" to the model. - */ - description: z.optional(z.string()), - - /** - * The MIME type of this resource, if known. - */ - mimeType: z.optional(z.string()), - - /** - * An optional list of icons for this resource. - */ - icons: z.optional(z.array(IconSchema)), - - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), + /** + * The URI of this resource. + */ + uri: z.string(), + + /** + * A description of what this resource represents. + * + * This can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a "hint" to the model. + */ + description: z.optional(z.string()), + + /** + * The MIME type of this resource, if known. + */ + mimeType: z.optional(z.string()), + + /** + * An optional list of icons for this resource. + */ + icons: z.optional(z.array(IconSchema)), + + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) }); /** * A template description for resources available on the server. */ export const ResourceTemplateSchema = BaseMetadataSchema.extend({ - /** - * A URI template (according to RFC 6570) that can be used to construct resource URIs. - */ - uriTemplate: z.string(), - - /** - * A description of what this template is for. - * - * This can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a "hint" to the model. - */ - description: z.optional(z.string()), - - /** - * The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type. - */ - mimeType: z.optional(z.string()), - - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), + /** + * A URI template (according to RFC 6570) that can be used to construct resource URIs. + */ + uriTemplate: z.string(), + + /** + * A description of what this template is for. + * + * This can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a "hint" to the model. + */ + description: z.optional(z.string()), + + /** + * The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type. + */ + mimeType: z.optional(z.string()), + + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) }); /** * Sent from the client to request a list of resources the server has. */ export const ListResourcesRequestSchema = PaginatedRequestSchema.extend({ - method: z.literal("resources/list"), + method: z.literal('resources/list') }); /** * The server's response to a resources/list request from the client. */ export const ListResourcesResultSchema = PaginatedResultSchema.extend({ - resources: z.array(ResourceSchema), + resources: z.array(ResourceSchema) }); /** * Sent from the client to request a list of resource templates the server has. */ -export const ListResourceTemplatesRequestSchema = PaginatedRequestSchema.extend( - { - method: z.literal("resources/templates/list"), - }, -); +export const ListResourceTemplatesRequestSchema = PaginatedRequestSchema.extend({ + method: z.literal('resources/templates/list') +}); /** * The server's response to a resources/templates/list request from the client. */ export const ListResourceTemplatesResultSchema = PaginatedResultSchema.extend({ - resourceTemplates: z.array(ResourceTemplateSchema), + resourceTemplates: z.array(ResourceTemplateSchema) }); /** * Sent from the client to the server, to read a specific resource URI. */ export const ReadResourceRequestSchema = RequestSchema.extend({ - method: z.literal("resources/read"), - params: BaseRequestParamsSchema.extend({ - /** - * The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it. - */ - uri: z.string(), - }), + method: z.literal('resources/read'), + params: BaseRequestParamsSchema.extend({ + /** + * The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it. + */ + uri: z.string() + }) }); /** * The server's response to a resources/read request from the client. */ export const ReadResourceResultSchema = ResultSchema.extend({ - contents: z.array( - z.union([TextResourceContentsSchema, BlobResourceContentsSchema]), - ), + contents: z.array(z.union([TextResourceContentsSchema, BlobResourceContentsSchema])) }); /** * An optional notification from the server to the client, informing it that the list of resources it can read from has changed. This may be issued by servers without any previous subscription from the client. */ export const ResourceListChangedNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/resources/list_changed"), + method: z.literal('notifications/resources/list_changed') }); /** * Sent from the client to request resources/updated notifications from the server whenever a particular resource changes. */ export const SubscribeRequestSchema = RequestSchema.extend({ - method: z.literal("resources/subscribe"), - params: BaseRequestParamsSchema.extend({ - /** - * The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it. - */ - uri: z.string(), - }), + method: z.literal('resources/subscribe'), + params: BaseRequestParamsSchema.extend({ + /** + * The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it. + */ + uri: z.string() + }) }); /** * Sent from the client to request cancellation of resources/updated notifications from the server. This should follow a previous resources/subscribe request. */ export const UnsubscribeRequestSchema = RequestSchema.extend({ - method: z.literal("resources/unsubscribe"), - params: BaseRequestParamsSchema.extend({ - /** - * The URI of the resource to unsubscribe from. - */ - uri: z.string(), - }), + method: z.literal('resources/unsubscribe'), + params: BaseRequestParamsSchema.extend({ + /** + * The URI of the resource to unsubscribe from. + */ + uri: z.string() + }) }); /** * A notification from the server to the client, informing it that a resource has changed and may need to be read again. This should only be sent if the client previously sent a resources/subscribe request. */ export const ResourceUpdatedNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/resources/updated"), - params: BaseNotificationParamsSchema.extend({ - /** - * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. - */ - uri: z.string(), - }), + method: z.literal('notifications/resources/updated'), + params: BaseNotificationParamsSchema.extend({ + /** + * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + */ + uri: z.string() + }) }); /* Prompts */ @@ -682,155 +659,155 @@ export const ResourceUpdatedNotificationSchema = NotificationSchema.extend({ * Describes an argument that a prompt can accept. */ export const PromptArgumentSchema = z - .object({ + .object({ + /** + * The name of the argument. + */ + name: z.string(), + /** + * A human-readable description of the argument. + */ + description: z.optional(z.string()), + /** + * Whether this argument must be provided. + */ + required: z.optional(z.boolean()) + }) + .passthrough(); + +/** + * A prompt or prompt template that the server offers. + */ +export const PromptSchema = BaseMetadataSchema.extend({ + /** + * An optional description of what this prompt provides + */ + description: z.optional(z.string()), /** - * The name of the argument. + * A list of arguments to use for templating the prompt. */ - name: z.string(), + arguments: z.optional(z.array(PromptArgumentSchema)), /** - * A human-readable description of the argument. + * An optional list of icons for this prompt. */ - description: z.optional(z.string()), + icons: z.optional(z.array(IconSchema)), /** - * Whether this argument must be provided. + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. */ - required: z.optional(z.boolean()), - }) - .passthrough(); - -/** - * A prompt or prompt template that the server offers. - */ -export const PromptSchema = BaseMetadataSchema.extend({ - /** - * An optional description of what this prompt provides - */ - description: z.optional(z.string()), - /** - * A list of arguments to use for templating the prompt. - */ - arguments: z.optional(z.array(PromptArgumentSchema)), - /** - * An optional list of icons for this prompt. - */ - icons: z.optional(z.array(IconSchema)), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), + _meta: z.optional(z.object({}).passthrough()) }); /** * Sent from the client to request a list of prompts and prompt templates the server has. */ export const ListPromptsRequestSchema = PaginatedRequestSchema.extend({ - method: z.literal("prompts/list"), + method: z.literal('prompts/list') }); /** * The server's response to a prompts/list request from the client. */ export const ListPromptsResultSchema = PaginatedResultSchema.extend({ - prompts: z.array(PromptSchema), + prompts: z.array(PromptSchema) }); /** * Used by the client to get a prompt provided by the server. */ export const GetPromptRequestSchema = RequestSchema.extend({ - method: z.literal("prompts/get"), - params: BaseRequestParamsSchema.extend({ - /** - * The name of the prompt or prompt template. - */ - name: z.string(), - /** - * Arguments to use for templating the prompt. - */ - arguments: z.optional(z.record(z.string())), - }), + method: z.literal('prompts/get'), + params: BaseRequestParamsSchema.extend({ + /** + * The name of the prompt or prompt template. + */ + name: z.string(), + /** + * Arguments to use for templating the prompt. + */ + arguments: z.optional(z.record(z.string())) + }) }); /** * Text provided to or from an LLM. */ export const TextContentSchema = z - .object({ - type: z.literal("text"), - /** - * The text content of the message. - */ - text: z.string(), + .object({ + type: z.literal('text'), + /** + * The text content of the message. + */ + text: z.string(), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); /** - * An image provided to or from an LLM. - */ -export const ImageContentSchema = z - .object({ - type: z.literal("image"), - /** - * The base64-encoded image data. - */ - data: Base64Schema, - /** - * The MIME type of the image. Different providers may support different image types. - */ - mimeType: z.string(), + * An image provided to or from an LLM. + */ +export const ImageContentSchema = z + .object({ + type: z.literal('image'), + /** + * The base64-encoded image data. + */ + data: Base64Schema, + /** + * The MIME type of the image. Different providers may support different image types. + */ + mimeType: z.string(), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); /** * An Audio provided to or from an LLM. */ export const AudioContentSchema = z - .object({ - type: z.literal("audio"), - /** - * The base64-encoded audio data. - */ - data: Base64Schema, - /** - * The MIME type of the audio. Different providers may support different audio types. - */ - mimeType: z.string(), + .object({ + type: z.literal('audio'), + /** + * The base64-encoded audio data. + */ + data: Base64Schema, + /** + * The MIME type of the audio. Different providers may support different audio types. + */ + mimeType: z.string(), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); /** * The contents of a resource, embedded into a prompt or tool call result. */ export const EmbeddedResourceSchema = z - .object({ - type: z.literal("resource"), - resource: z.union([TextResourceContentsSchema, BlobResourceContentsSchema]), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + .object({ + type: z.literal('resource'), + resource: z.union([TextResourceContentsSchema, BlobResourceContentsSchema]), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); /** * A resource that the server is capable of reading, included in a prompt or tool call result. @@ -838,46 +815,46 @@ export const EmbeddedResourceSchema = z * Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. */ export const ResourceLinkSchema = ResourceSchema.extend({ - type: z.literal("resource_link"), + type: z.literal('resource_link') }); /** * A content block that can be used in prompts and tool results. */ export const ContentBlockSchema = z.union([ - TextContentSchema, - ImageContentSchema, - AudioContentSchema, - ResourceLinkSchema, - EmbeddedResourceSchema, + TextContentSchema, + ImageContentSchema, + AudioContentSchema, + ResourceLinkSchema, + EmbeddedResourceSchema ]); /** * Describes a message returned as part of a prompt. */ export const PromptMessageSchema = z - .object({ - role: z.enum(["user", "assistant"]), - content: ContentBlockSchema, - }) - .passthrough(); + .object({ + role: z.enum(['user', 'assistant']), + content: ContentBlockSchema + }) + .passthrough(); /** * The server's response to a prompts/get request from the client. */ export const GetPromptResultSchema = ResultSchema.extend({ - /** - * An optional description for the prompt. - */ - description: z.optional(z.string()), - messages: z.array(PromptMessageSchema), + /** + * An optional description for the prompt. + */ + description: z.optional(z.string()), + messages: z.array(PromptMessageSchema) }); /** * An optional notification from the server to the client, informing it that the list of prompts it offers has changed. This may be issued by servers without any previous subscription from the client. */ export const PromptListChangedNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/prompts/list_changed"), + method: z.literal('notifications/prompts/list_changed') }); /* Tools */ @@ -892,222 +869,214 @@ export const PromptListChangedNotificationSchema = NotificationSchema.extend({ * received from untrusted servers. */ export const ToolAnnotationsSchema = z - .object({ + .object({ + /** + * A human-readable title for the tool. + */ + title: z.optional(z.string()), + + /** + * If true, the tool does not modify its environment. + * + * Default: false + */ + readOnlyHint: z.optional(z.boolean()), + + /** + * If true, the tool may perform destructive updates to its environment. + * If false, the tool performs only additive updates. + * + * (This property is meaningful only when `readOnlyHint == false`) + * + * Default: true + */ + destructiveHint: z.optional(z.boolean()), + + /** + * If true, calling the tool repeatedly with the same arguments + * will have no additional effect on the its environment. + * + * (This property is meaningful only when `readOnlyHint == false`) + * + * Default: false + */ + idempotentHint: z.optional(z.boolean()), + + /** + * If true, this tool may interact with an "open world" of external + * entities. If false, the tool's domain of interaction is closed. + * For example, the world of a web search tool is open, whereas that + * of a memory tool is not. + * + * Default: true + */ + openWorldHint: z.optional(z.boolean()) + }) + .passthrough(); + +/** + * Definition for a tool the client can call. + */ +export const ToolSchema = BaseMetadataSchema.extend({ /** - * A human-readable title for the tool. + * A human-readable description of the tool. */ - title: z.optional(z.string()), - + description: z.optional(z.string()), /** - * If true, the tool does not modify its environment. - * - * Default: false + * A JSON Schema object defining the expected parameters for the tool. */ - readOnlyHint: z.optional(z.boolean()), - + inputSchema: z + .object({ + type: z.literal('object'), + properties: z.optional(z.object({}).passthrough()), + required: z.optional(z.array(z.string())) + }) + .passthrough(), /** - * If true, the tool may perform destructive updates to its environment. - * If false, the tool performs only additive updates. - * - * (This property is meaningful only when `readOnlyHint == false`) - * - * Default: true + * An optional JSON Schema object defining the structure of the tool's output returned in + * the structuredContent field of a CallToolResult. + */ + outputSchema: z.optional( + z + .object({ + type: z.literal('object'), + properties: z.optional(z.object({}).passthrough()), + required: z.optional(z.array(z.string())) + }) + .passthrough() + ), + /** + * Optional additional tool information. */ - destructiveHint: z.optional(z.boolean()), + annotations: z.optional(ToolAnnotationsSchema), /** - * If true, calling the tool repeatedly with the same arguments - * will have no additional effect on the its environment. - * - * (This property is meaningful only when `readOnlyHint == false`) - * - * Default: false + * An optional list of icons for this tool. */ - idempotentHint: z.optional(z.boolean()), + icons: z.optional(z.array(IconSchema)), /** - * If true, this tool may interact with an "open world" of external - * entities. If false, the tool's domain of interaction is closed. - * For example, the world of a web search tool is open, whereas that - * of a memory tool is not. - * - * Default: true + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. */ - openWorldHint: z.optional(z.boolean()), - }) - .passthrough(); - -/** - * Definition for a tool the client can call. - */ -export const ToolSchema = BaseMetadataSchema.extend({ - /** - * A human-readable description of the tool. - */ - description: z.optional(z.string()), - /** - * A JSON Schema object defining the expected parameters for the tool. - */ - inputSchema: z - .object({ - type: z.literal("object"), - properties: z.optional(z.object({}).passthrough()), - required: z.optional(z.array(z.string())), - }) - .passthrough(), - /** - * An optional JSON Schema object defining the structure of the tool's output returned in - * the structuredContent field of a CallToolResult. - */ - outputSchema: z.optional( - z.object({ - type: z.literal("object"), - properties: z.optional(z.object({}).passthrough()), - required: z.optional(z.array(z.string())), - }) - .passthrough() - ), - /** - * Optional additional tool information. - */ - annotations: z.optional(ToolAnnotationsSchema), - - /** - * An optional list of icons for this tool. - */ - icons: z.optional(z.array(IconSchema)), - - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), + _meta: z.optional(z.object({}).passthrough()) }); /** * Sent from the client to request a list of tools the server has. */ export const ListToolsRequestSchema = PaginatedRequestSchema.extend({ - method: z.literal("tools/list"), + method: z.literal('tools/list') }); /** * The server's response to a tools/list request from the client. */ export const ListToolsResultSchema = PaginatedResultSchema.extend({ - tools: z.array(ToolSchema), + tools: z.array(ToolSchema) }); /** * The server's response to a tool call. */ export const CallToolResultSchema = ResultSchema.extend({ - /** - * A list of content objects that represent the result of the tool call. - * - * If the Tool does not define an outputSchema, this field MUST be present in the result. - * For backwards compatibility, this field is always present, but it may be empty. - */ - content: z.array(ContentBlockSchema).default([]), - - /** - * An object containing structured tool output. - * - * If the Tool defines an outputSchema, this field MUST be present in the result, and contain a JSON object that matches the schema. - */ - structuredContent: z.object({}).passthrough().optional(), - - /** - * Whether the tool call ended in an error. - * - * If not set, this is assumed to be false (the call was successful). - * - * Any errors that originate from the tool SHOULD be reported inside the result - * object, with `isError` set to true, _not_ as an MCP protocol-level error - * response. Otherwise, the LLM would not be able to see that an error occurred - * and self-correct. - * - * However, any errors in _finding_ the tool, an error indicating that the - * server does not support tool calls, or any other exceptional conditions, - * should be reported as an MCP error response. - */ - isError: z.optional(z.boolean()), + /** + * A list of content objects that represent the result of the tool call. + * + * If the Tool does not define an outputSchema, this field MUST be present in the result. + * For backwards compatibility, this field is always present, but it may be empty. + */ + content: z.array(ContentBlockSchema).default([]), + + /** + * An object containing structured tool output. + * + * If the Tool defines an outputSchema, this field MUST be present in the result, and contain a JSON object that matches the schema. + */ + structuredContent: z.object({}).passthrough().optional(), + + /** + * Whether the tool call ended in an error. + * + * If not set, this is assumed to be false (the call was successful). + * + * Any errors that originate from the tool SHOULD be reported inside the result + * object, with `isError` set to true, _not_ as an MCP protocol-level error + * response. Otherwise, the LLM would not be able to see that an error occurred + * and self-correct. + * + * However, any errors in _finding_ the tool, an error indicating that the + * server does not support tool calls, or any other exceptional conditions, + * should be reported as an MCP error response. + */ + isError: z.optional(z.boolean()) }); /** * CallToolResultSchema extended with backwards compatibility to protocol version 2024-10-07. */ export const CompatibilityCallToolResultSchema = CallToolResultSchema.or( - ResultSchema.extend({ - toolResult: z.unknown(), - }), + ResultSchema.extend({ + toolResult: z.unknown() + }) ); /** * Used by the client to invoke a tool provided by the server. */ export const CallToolRequestSchema = RequestSchema.extend({ - method: z.literal("tools/call"), - params: BaseRequestParamsSchema.extend({ - name: z.string(), - arguments: z.optional(z.record(z.unknown())), - }), + method: z.literal('tools/call'), + params: BaseRequestParamsSchema.extend({ + name: z.string(), + arguments: z.optional(z.record(z.unknown())) + }) }); /** * An optional notification from the server to the client, informing it that the list of tools it offers has changed. This may be issued by servers without any previous subscription from the client. */ export const ToolListChangedNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/tools/list_changed"), + method: z.literal('notifications/tools/list_changed') }); /* Logging */ /** * The severity of a log message. */ -export const LoggingLevelSchema = z.enum([ - "debug", - "info", - "notice", - "warning", - "error", - "critical", - "alert", - "emergency", -]); +export const LoggingLevelSchema = z.enum(['debug', 'info', 'notice', 'warning', 'error', 'critical', 'alert', 'emergency']); /** * A request from the client to the server, to enable or adjust logging. */ export const SetLevelRequestSchema = RequestSchema.extend({ - method: z.literal("logging/setLevel"), - params: BaseRequestParamsSchema.extend({ - /** - * The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/logging/message. - */ - level: LoggingLevelSchema, - }), + method: z.literal('logging/setLevel'), + params: BaseRequestParamsSchema.extend({ + /** + * The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/logging/message. + */ + level: LoggingLevelSchema + }) }); /** * Notification of a log message passed from server to client. If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically. */ export const LoggingMessageNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/message"), - params: BaseNotificationParamsSchema.extend({ - /** - * The severity of this log message. - */ - level: LoggingLevelSchema, - /** - * An optional name of the logger issuing this message. - */ - logger: z.optional(z.string()), - /** - * The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. - */ - data: z.unknown(), - }), + method: z.literal('notifications/message'), + params: BaseNotificationParamsSchema.extend({ + /** + * The severity of this log message. + */ + level: LoggingLevelSchema, + /** + * An optional name of the logger issuing this message. + */ + logger: z.optional(z.string()), + /** + * The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. + */ + data: z.unknown() + }) }); /* Sampling */ @@ -1115,100 +1084,94 @@ export const LoggingMessageNotificationSchema = NotificationSchema.extend({ * Hints to use for model selection. */ export const ModelHintSchema = z - .object({ - /** - * A hint for a model name. - */ - name: z.string().optional(), - }) - .passthrough(); + .object({ + /** + * A hint for a model name. + */ + name: z.string().optional() + }) + .passthrough(); /** * The server's preferences for model selection, requested of the client during sampling. */ export const ModelPreferencesSchema = z - .object({ - /** - * Optional hints to use for model selection. - */ - hints: z.optional(z.array(ModelHintSchema)), - /** - * How much to prioritize cost when selecting a model. - */ - costPriority: z.optional(z.number().min(0).max(1)), - /** - * How much to prioritize sampling speed (latency) when selecting a model. - */ - speedPriority: z.optional(z.number().min(0).max(1)), - /** - * How much to prioritize intelligence and capabilities when selecting a model. - */ - intelligencePriority: z.optional(z.number().min(0).max(1)), - }) - .passthrough(); + .object({ + /** + * Optional hints to use for model selection. + */ + hints: z.optional(z.array(ModelHintSchema)), + /** + * How much to prioritize cost when selecting a model. + */ + costPriority: z.optional(z.number().min(0).max(1)), + /** + * How much to prioritize sampling speed (latency) when selecting a model. + */ + speedPriority: z.optional(z.number().min(0).max(1)), + /** + * How much to prioritize intelligence and capabilities when selecting a model. + */ + intelligencePriority: z.optional(z.number().min(0).max(1)) + }) + .passthrough(); /** * Describes a message issued to or received from an LLM API. */ export const SamplingMessageSchema = z - .object({ - role: z.enum(["user", "assistant"]), - content: z.union([TextContentSchema, ImageContentSchema, AudioContentSchema]), - }) - .passthrough(); + .object({ + role: z.enum(['user', 'assistant']), + content: z.union([TextContentSchema, ImageContentSchema, AudioContentSchema]) + }) + .passthrough(); /** * A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it. */ export const CreateMessageRequestSchema = RequestSchema.extend({ - method: z.literal("sampling/createMessage"), - params: BaseRequestParamsSchema.extend({ - messages: z.array(SamplingMessageSchema), - /** - * An optional system prompt the server wants to use for sampling. The client MAY modify or omit this prompt. - */ - systemPrompt: z.optional(z.string()), - /** - * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request. - */ - includeContext: z.optional(z.enum(["none", "thisServer", "allServers"])), - temperature: z.optional(z.number()), - /** - * The maximum number of tokens to sample, as requested by the server. The client MAY choose to sample fewer tokens than requested. - */ - maxTokens: z.number().int(), - stopSequences: z.optional(z.array(z.string())), - /** - * Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific. - */ - metadata: z.optional(z.object({}).passthrough()), - /** - * The server's preferences for which model to select. - */ - modelPreferences: z.optional(ModelPreferencesSchema), - }), + method: z.literal('sampling/createMessage'), + params: BaseRequestParamsSchema.extend({ + messages: z.array(SamplingMessageSchema), + /** + * An optional system prompt the server wants to use for sampling. The client MAY modify or omit this prompt. + */ + systemPrompt: z.optional(z.string()), + /** + * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request. + */ + includeContext: z.optional(z.enum(['none', 'thisServer', 'allServers'])), + temperature: z.optional(z.number()), + /** + * The maximum number of tokens to sample, as requested by the server. The client MAY choose to sample fewer tokens than requested. + */ + maxTokens: z.number().int(), + stopSequences: z.optional(z.array(z.string())), + /** + * Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific. + */ + metadata: z.optional(z.object({}).passthrough()), + /** + * The server's preferences for which model to select. + */ + modelPreferences: z.optional(ModelPreferencesSchema) + }) }); /** * The client's response to a sampling/create_message request from the server. The client should inform the user before returning the sampled message, to allow them to inspect the response (human in the loop) and decide whether to allow the server to see it. */ export const CreateMessageResultSchema = ResultSchema.extend({ - /** - * The name of the model that generated the message. - */ - model: z.string(), - /** - * The reason why sampling stopped. - */ - stopReason: z.optional( - z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string()), - ), - role: z.enum(["user", "assistant"]), - content: z.discriminatedUnion("type", [ - TextContentSchema, - ImageContentSchema, - AudioContentSchema - ]), + /** + * The name of the model that generated the message. + */ + model: z.string(), + /** + * The reason why sampling stopped. + */ + stopReason: z.optional(z.enum(['endTurn', 'stopSequence', 'maxTokens']).or(z.string())), + role: z.enum(['user', 'assistant']), + content: z.discriminatedUnion('type', [TextContentSchema, ImageContentSchema, AudioContentSchema]) }); /* Elicitation */ @@ -1216,100 +1179,95 @@ export const CreateMessageResultSchema = ResultSchema.extend({ * Primitive schema definition for boolean fields. */ export const BooleanSchemaSchema = z - .object({ - type: z.literal("boolean"), - title: z.optional(z.string()), - description: z.optional(z.string()), - default: z.optional(z.boolean()), - }) - .passthrough(); + .object({ + type: z.literal('boolean'), + title: z.optional(z.string()), + description: z.optional(z.string()), + default: z.optional(z.boolean()) + }) + .passthrough(); /** * Primitive schema definition for string fields. */ export const StringSchemaSchema = z - .object({ - type: z.literal("string"), - title: z.optional(z.string()), - description: z.optional(z.string()), - minLength: z.optional(z.number()), - maxLength: z.optional(z.number()), - format: z.optional(z.enum(["email", "uri", "date", "date-time"])), - }) - .passthrough(); + .object({ + type: z.literal('string'), + title: z.optional(z.string()), + description: z.optional(z.string()), + minLength: z.optional(z.number()), + maxLength: z.optional(z.number()), + format: z.optional(z.enum(['email', 'uri', 'date', 'date-time'])) + }) + .passthrough(); /** * Primitive schema definition for number fields. */ export const NumberSchemaSchema = z - .object({ - type: z.enum(["number", "integer"]), - title: z.optional(z.string()), - description: z.optional(z.string()), - minimum: z.optional(z.number()), - maximum: z.optional(z.number()), - }) - .passthrough(); + .object({ + type: z.enum(['number', 'integer']), + title: z.optional(z.string()), + description: z.optional(z.string()), + minimum: z.optional(z.number()), + maximum: z.optional(z.number()) + }) + .passthrough(); /** * Primitive schema definition for enum fields. */ export const EnumSchemaSchema = z - .object({ - type: z.literal("string"), - title: z.optional(z.string()), - description: z.optional(z.string()), - enum: z.array(z.string()), - enumNames: z.optional(z.array(z.string())), - }) - .passthrough(); + .object({ + type: z.literal('string'), + title: z.optional(z.string()), + description: z.optional(z.string()), + enum: z.array(z.string()), + enumNames: z.optional(z.array(z.string())) + }) + .passthrough(); /** * Union of all primitive schema definitions. */ -export const PrimitiveSchemaDefinitionSchema = z.union([ - BooleanSchemaSchema, - StringSchemaSchema, - NumberSchemaSchema, - EnumSchemaSchema, -]); +export const PrimitiveSchemaDefinitionSchema = z.union([BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema, EnumSchemaSchema]); /** * A request from the server to elicit user input via the client. * The client should present the message and form fields to the user. */ export const ElicitRequestSchema = RequestSchema.extend({ - method: z.literal("elicitation/create"), - params: BaseRequestParamsSchema.extend({ - /** - * The message to present to the user. - */ - message: z.string(), - /** - * The schema for the requested user input. - */ - requestedSchema: z - .object({ - type: z.literal("object"), - properties: z.record(z.string(), PrimitiveSchemaDefinitionSchema), - required: z.optional(z.array(z.string())), - }) - .passthrough(), - }), + method: z.literal('elicitation/create'), + params: BaseRequestParamsSchema.extend({ + /** + * The message to present to the user. + */ + message: z.string(), + /** + * The schema for the requested user input. + */ + requestedSchema: z + .object({ + type: z.literal('object'), + properties: z.record(z.string(), PrimitiveSchemaDefinitionSchema), + required: z.optional(z.array(z.string())) + }) + .passthrough() + }) }); /** * The client's response to an elicitation/create request from the server. */ export const ElicitResultSchema = ResultSchema.extend({ - /** - * The user's response action. - */ - action: z.enum(["accept", "decline", "cancel"]), - /** - * The collected user input content (only present if action is "accept"). - */ - content: z.optional(z.record(z.string(), z.unknown())), + /** + * The user's response action. + */ + action: z.enum(['accept', 'decline', 'cancel']), + /** + * The collected user input content (only present if action is "accept"). + */ + content: z.optional(z.record(z.string(), z.unknown())) }); /* Autocomplete */ @@ -1317,14 +1275,14 @@ export const ElicitResultSchema = ResultSchema.extend({ * A reference to a resource or resource template definition. */ export const ResourceTemplateReferenceSchema = z - .object({ - type: z.literal("ref/resource"), - /** - * The URI or URI template of the resource. - */ - uri: z.string(), - }) - .passthrough(); + .object({ + type: z.literal('ref/resource'), + /** + * The URI or URI template of the resource. + */ + uri: z.string() + }) + .passthrough(); /** * @deprecated Use ResourceTemplateReferenceSchema instead @@ -1335,68 +1293,68 @@ export const ResourceReferenceSchema = ResourceTemplateReferenceSchema; * Identifies a prompt. */ export const PromptReferenceSchema = z - .object({ - type: z.literal("ref/prompt"), - /** - * The name of the prompt or prompt template - */ - name: z.string(), - }) - .passthrough(); + .object({ + type: z.literal('ref/prompt'), + /** + * The name of the prompt or prompt template + */ + name: z.string() + }) + .passthrough(); /** * A request from the client to the server, to ask for completion options. */ export const CompleteRequestSchema = RequestSchema.extend({ - method: z.literal("completion/complete"), - params: BaseRequestParamsSchema.extend({ - ref: z.union([PromptReferenceSchema, ResourceTemplateReferenceSchema]), - /** - * The argument's information - */ - argument: z - .object({ - /** - * The name of the argument - */ - name: z.string(), - /** - * The value of the argument to use for completion matching. - */ - value: z.string(), - }) - .passthrough(), - context: z.optional( - z.object({ + method: z.literal('completion/complete'), + params: BaseRequestParamsSchema.extend({ + ref: z.union([PromptReferenceSchema, ResourceTemplateReferenceSchema]), /** - * Previously-resolved variables in a URI template or prompt. + * The argument's information */ - arguments: z.optional(z.record(z.string(), z.string())), - }) - ), - }), + argument: z + .object({ + /** + * The name of the argument + */ + name: z.string(), + /** + * The value of the argument to use for completion matching. + */ + value: z.string() + }) + .passthrough(), + context: z.optional( + z.object({ + /** + * Previously-resolved variables in a URI template or prompt. + */ + arguments: z.optional(z.record(z.string(), z.string())) + }) + ) + }) }); /** * The server's response to a completion/complete request */ export const CompleteResultSchema = ResultSchema.extend({ - completion: z - .object({ - /** - * An array of completion values. Must not exceed 100 items. - */ - values: z.array(z.string()).max(100), - /** - * The total number of completion options available. This can exceed the number of values actually sent in the response. - */ - total: z.optional(z.number().int()), - /** - * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. - */ - hasMore: z.optional(z.boolean()), - }) - .passthrough(), + completion: z + .object({ + /** + * An array of completion values. Must not exceed 100 items. + */ + values: z.array(z.string()).max(100), + /** + * The total number of completion options available. This can exceed the number of values actually sent in the response. + */ + total: z.optional(z.number().int()), + /** + * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. + */ + hasMore: z.optional(z.boolean()) + }) + .passthrough() }); /* Roots */ @@ -1404,130 +1362,120 @@ export const CompleteResultSchema = ResultSchema.extend({ * Represents a root directory or file that the server can operate on. */ export const RootSchema = z - .object({ - /** - * The URI identifying the root. This *must* start with file:// for now. - */ - uri: z.string().startsWith("file://"), - /** - * An optional name for the root. - */ - name: z.optional(z.string()), + .object({ + /** + * The URI identifying the root. This *must* start with file:// for now. + */ + uri: z.string().startsWith('file://'), + /** + * An optional name for the root. + */ + name: z.optional(z.string()), - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(); + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); /** * Sent from the server to request a list of root URIs from the client. */ export const ListRootsRequestSchema = RequestSchema.extend({ - method: z.literal("roots/list"), + method: z.literal('roots/list') }); /** * The client's response to a roots/list request from the server. */ export const ListRootsResultSchema = ResultSchema.extend({ - roots: z.array(RootSchema), + roots: z.array(RootSchema) }); /** * A notification from the client to the server, informing it that the list of roots has changed. */ export const RootsListChangedNotificationSchema = NotificationSchema.extend({ - method: z.literal("notifications/roots/list_changed"), + method: z.literal('notifications/roots/list_changed') }); /* Client messages */ export const ClientRequestSchema = z.union([ - PingRequestSchema, - InitializeRequestSchema, - CompleteRequestSchema, - SetLevelRequestSchema, - GetPromptRequestSchema, - ListPromptsRequestSchema, - ListResourcesRequestSchema, - ListResourceTemplatesRequestSchema, - ReadResourceRequestSchema, - SubscribeRequestSchema, - UnsubscribeRequestSchema, - CallToolRequestSchema, - ListToolsRequestSchema, + PingRequestSchema, + InitializeRequestSchema, + CompleteRequestSchema, + SetLevelRequestSchema, + GetPromptRequestSchema, + ListPromptsRequestSchema, + ListResourcesRequestSchema, + ListResourceTemplatesRequestSchema, + ReadResourceRequestSchema, + SubscribeRequestSchema, + UnsubscribeRequestSchema, + CallToolRequestSchema, + ListToolsRequestSchema ]); export const ClientNotificationSchema = z.union([ - CancelledNotificationSchema, - ProgressNotificationSchema, - InitializedNotificationSchema, - RootsListChangedNotificationSchema, + CancelledNotificationSchema, + ProgressNotificationSchema, + InitializedNotificationSchema, + RootsListChangedNotificationSchema ]); -export const ClientResultSchema = z.union([ - EmptyResultSchema, - CreateMessageResultSchema, - ElicitResultSchema, - ListRootsResultSchema, -]); +export const ClientResultSchema = z.union([EmptyResultSchema, CreateMessageResultSchema, ElicitResultSchema, ListRootsResultSchema]); /* Server messages */ -export const ServerRequestSchema = z.union([ - PingRequestSchema, - CreateMessageRequestSchema, - ElicitRequestSchema, - ListRootsRequestSchema, -]); +export const ServerRequestSchema = z.union([PingRequestSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema]); export const ServerNotificationSchema = z.union([ - CancelledNotificationSchema, - ProgressNotificationSchema, - LoggingMessageNotificationSchema, - ResourceUpdatedNotificationSchema, - ResourceListChangedNotificationSchema, - ToolListChangedNotificationSchema, - PromptListChangedNotificationSchema, + CancelledNotificationSchema, + ProgressNotificationSchema, + LoggingMessageNotificationSchema, + ResourceUpdatedNotificationSchema, + ResourceListChangedNotificationSchema, + ToolListChangedNotificationSchema, + PromptListChangedNotificationSchema ]); export const ServerResultSchema = z.union([ - EmptyResultSchema, - InitializeResultSchema, - CompleteResultSchema, - GetPromptResultSchema, - ListPromptsResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ReadResourceResultSchema, - CallToolResultSchema, - ListToolsResultSchema, + EmptyResultSchema, + InitializeResultSchema, + CompleteResultSchema, + GetPromptResultSchema, + ListPromptsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + CallToolResultSchema, + ListToolsResultSchema ]); export class McpError extends Error { - constructor( - public readonly code: number, - message: string, - public readonly data?: unknown, - ) { - super(`MCP error ${code}: ${message}`); - this.name = "McpError"; - } + constructor( + public readonly code: number, + message: string, + public readonly data?: unknown + ) { + super(`MCP error ${code}: ${message}`); + this.name = 'McpError'; + } } type Primitive = string | number | boolean | bigint | null | undefined; type Flatten = T extends Primitive - ? T - : T extends Array - ? Array> - : T extends Set - ? Set> - : T extends Map - ? Map, Flatten> - : T extends object - ? { [K in keyof T]: Flatten } - : T; + ? T + : T extends Array + ? Array> + : T extends Set + ? Set> + : T extends Map + ? Map, Flatten> + : T extends object + ? { [K in keyof T]: Flatten } + : T; type Infer = Flatten>; @@ -1540,25 +1488,25 @@ export type IsomorphicHeaders = Record; * Information about the incoming request. */ export interface RequestInfo { - /** - * The headers of the request. - */ - headers: IsomorphicHeaders; + /** + * The headers of the request. + */ + headers: IsomorphicHeaders; } /** * Extra information about a message. */ export interface MessageExtraInfo { - /** - * The request information. - */ - requestInfo?: RequestInfo; - - /** - * The authentication information. - */ - authInfo?: AuthInfo; + /** + * The request information. + */ + requestInfo?: RequestInfo; + + /** + * The authentication information. + */ + authInfo?: AuthInfo; } /* JSON-RPC types */ diff --git a/tsconfig.cjs.json b/tsconfig.cjs.json index b2f344a81..3b46f11c4 100644 --- a/tsconfig.cjs.json +++ b/tsconfig.cjs.json @@ -1,9 +1,9 @@ { - "extends": "./tsconfig.json", - "compilerOptions": { - "module": "commonjs", - "moduleResolution": "node", - "outDir": "./dist/cjs" - }, - "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] + "extends": "./tsconfig.json", + "compilerOptions": { + "module": "commonjs", + "moduleResolution": "node", + "outDir": "./dist/cjs" + }, + "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] } diff --git a/tsconfig.json b/tsconfig.json index cedbaaaeb..4cc22bf1b 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,23 +1,23 @@ { - "compilerOptions": { - "target": "es2018", - "module": "Node16", - "moduleResolution": "Node16", - "declaration": true, - "declarationMap": true, - "sourceMap": true, - "outDir": "./dist", - "strict": true, - "esModuleInterop": true, - "forceConsistentCasingInFileNames": true, - "resolveJsonModule": true, - "isolatedModules": true, - "skipLibCheck": true, - "baseUrl": ".", - "paths": { - "pkce-challenge": ["node_modules/pkce-challenge/dist/index.node"] - } - }, - "include": ["src/**/*"], - "exclude": ["node_modules", "dist"] + "compilerOptions": { + "target": "es2018", + "module": "Node16", + "moduleResolution": "Node16", + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "strict": true, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "isolatedModules": true, + "skipLibCheck": true, + "baseUrl": ".", + "paths": { + "pkce-challenge": ["node_modules/pkce-challenge/dist/index.node"] + } + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist"] } diff --git a/tsconfig.prod.json b/tsconfig.prod.json index 2302dd844..8fc00fcd4 100644 --- a/tsconfig.prod.json +++ b/tsconfig.prod.json @@ -1,7 +1,7 @@ { - "extends": "./tsconfig.json", - "compilerOptions": { - "outDir": "./dist/esm" - }, - "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] + "extends": "./tsconfig.json", + "compilerOptions": { + "outDir": "./dist/esm" + }, + "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] }