diff --git a/src/cli/connector-loader.ts b/src/cli/connector-loader.ts index b1bc25a..5ba08c6 100644 --- a/src/cli/connector-loader.ts +++ b/src/cli/connector-loader.ts @@ -1,4 +1,4 @@ -import { MvmtConfig } from '../config/schema.js'; +import { MvmtConfig, OBSIDIAN_SOURCE_ID, resolveProxySourceId } from '../config/schema.js'; import { Connector } from '../connectors/types.js'; import { ObsidianConnector } from '../connectors/obsidian.js'; import { createProxyConnector } from '../connectors/factory.js'; @@ -6,6 +6,7 @@ import { Logger } from '../utils/logger.js'; export type LoadedConnector = { connector: Connector; + sourceId: string; toolCount: number; }; @@ -28,7 +29,7 @@ export async function initializeConnectors( try { await connector.initialize(); const toolCount = (await connector.listTools()).length; - loaded.push({ connector, toolCount }); + loaded.push({ connector, sourceId: resolveProxySourceId(proxyConfig), toolCount }); emit(`Loaded proxy:${proxyConfig.name} (${toolCount} tools)`, stdioMode, logger); } catch (err) { emit( @@ -46,7 +47,7 @@ export async function initializeConnectors( try { await connector.initialize(); const toolCount = (await connector.listTools()).length; - loaded.push({ connector, toolCount }); + loaded.push({ connector, sourceId: OBSIDIAN_SOURCE_ID, toolCount }); emit(`Loaded obsidian (${toolCount} tools)`, stdioMode, logger); } catch (err) { emit( diff --git a/src/cli/start.ts b/src/cli/start.ts index 9b529fe..04e5368 100644 --- a/src/cli/start.ts +++ b/src/cli/start.ts @@ -91,7 +91,7 @@ export async function start(options: StartOptions = {}): Promise { const audit = interactiveMode ? new InteractiveAuditLogger(createAuditLogger()) : createAuditLogger(); - const router = new ToolRouter(loaded.map((entry) => entry.connector), audit, plugins); + const router = new ToolRouter(loaded.map((entry) => ({ connector: entry.connector, sourceId: entry.sourceId })), audit, plugins); await router.initialize(); // Cleanup tasks run on SIGINT/SIGTERM and on startup failure. diff --git a/src/server/index.ts b/src/server/index.ts index c47b68a..df55567 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -19,7 +19,7 @@ import { } from './oauth.js'; import { rateLimit } from './rate-limit.js'; import { ClientConfig } from '../config/schema.js'; -import { attachClientIdentity, isQuarantined, resolveClientIdentity } from './client-identity.js'; +import { attachClientIdentity, ClientIdentity, isQuarantined, readClientIdentity, resolveClientIdentity } from './client-identity.js'; // Rate limits are defense-in-depth against brute-force and DoS, // primarily meaningful when mvmt is exposed via a tunnel. Auth-gated @@ -34,6 +34,7 @@ const DEFAULT_HEALTH_RATE_LIMIT = { windowMs: 60_000, max: 120 }; type McpSession = { transport: StreamableHTTPServerTransport; server: Server; + clientIdentity?: ClientIdentity; lastActivity: number; }; @@ -79,14 +80,14 @@ export interface HttpRequestLogEntry { clientId?: string; } -export function createMcpServer(router: ToolRouter): Server { +export function createMcpServer(router: ToolRouter, clientIdentity?: ClientIdentity): Server { const server = new Server( { name: 'mvmt', version: '0.1.0' }, { capabilities: { tools: {} } }, ); server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: router.getAllTools().map((tool) => ({ + tools: router.getAllTools(clientIdentity).map((tool) => ({ name: tool.namespacedName, description: tool.description, inputSchema: tool.inputSchema as { type: 'object'; properties?: Record; required?: string[] }, @@ -95,7 +96,7 @@ export function createMcpServer(router: ToolRouter): Server { server.setRequestHandler(CallToolRequestSchema, async (request) => { try { - const result = await router.callTool(request.params.name, request.params.arguments ?? {}); + const result = await router.callTool(request.params.name, request.params.arguments ?? {}, clientIdentity); return result as any; } catch (err) { return { @@ -571,7 +572,7 @@ export async function startHttpServer(router: ToolRouter, options: HttpServerOpt res.json({ status: 'ok', uptime: Math.floor(process.uptime()), - tools: router.getAllTools().length, + tools: router.getAllTools(readClientIdentity(_req)).length, sessions: sessions.size, }); }); @@ -621,6 +622,10 @@ async function handleMcpRequest( if (sessionId && sessions.has(sessionId)) { const session = sessions.get(sessionId)!; + if (!sameClientIdentity(session.clientIdentity, readClientIdentity(req))) { + res.status(403).json({ error: 'mcp_session_client_mismatch' }); + return; + } session.lastActivity = Date.now(); if (isStandaloneSseRequest(req)) { session.transport.closeStandaloneSSEStream(); @@ -637,7 +642,8 @@ async function handleMcpRequest( const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), }); - const server = createMcpServer(router); + const clientIdentity = readClientIdentity(req); + const server = createMcpServer(router, clientIdentity); transport.onerror = (error) => { if (isBenignDuplicateSseConflict(error)) { @@ -656,7 +662,7 @@ async function handleMcpRequest( await transport.handleRequest(req, res, req.body); if (transport.sessionId) { - sessions.set(transport.sessionId, { transport, server, lastActivity: Date.now() }); + sessions.set(transport.sessionId, { transport, server, clientIdentity, lastActivity: Date.now() }); } } @@ -664,7 +670,7 @@ async function handleStatelessMcpRequest(req: Request, res: Response, router: To const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: undefined, }); - const server = createMcpServer(router); + const server = createMcpServer(router, readClientIdentity(req)); transport.onerror = (error) => { if (isBenignDuplicateSseConflict(error)) { @@ -678,6 +684,10 @@ async function handleStatelessMcpRequest(req: Request, res: Response, router: To await transport.handleRequest(req, res, req.body); } +function sameClientIdentity(left: ClientIdentity | undefined, right: ClientIdentity | undefined): boolean { + return left?.id === right?.id; +} + function authLogKind(req: Request): string { if (req.path === '/mcp') return 'mcp.auth'; if (req.path === '/health') return 'health.auth'; diff --git a/src/server/router.ts b/src/server/router.ts index 21ff790..7fb4e7b 100644 --- a/src/server/router.ts +++ b/src/server/router.ts @@ -1,31 +1,58 @@ import { CallToolResult, Connector, ToolDefinition } from '../connectors/types.js'; +import { isLikelyWriteTool, isMemPalaceWriteTool } from '../connectors/write-policy.js'; +import { PermissionConfig } from '../config/schema.js'; import { PatternRedactorAuditEvent, ToolResultPlugin } from '../plugins/types.js'; import { AuditLogger, summarizeArgs } from '../utils/audit.js'; +import { ClientIdentity } from './client-identity.js'; + +type PermissionAction = PermissionConfig['actions'][number]; + +export interface RouterConnector { + connector: Connector; + sourceId?: string; +} export interface NamespacedTool { namespacedName: string; originalName: string; connectorId: string; + sourceId: string; + requiredAction: PermissionAction; description: string; inputSchema: Record; } export class ToolRouter { - private readonly toolMap = new Map(); + private readonly toolMap = new Map(); private readonly allTools: NamespacedTool[] = []; + private readonly connectors: RouterConnector[]; constructor( - private readonly connectors: Connector[], + connectors: Array, private readonly audit?: AuditLogger, private readonly plugins: ToolResultPlugin[] = [], - ) {} + ) { + this.connectors = connectors.map((entry) => ( + 'connector' in entry + ? entry + : { connector: entry, sourceId: entry.id } + )); + } async initialize(): Promise { - for (const connector of this.connectors) { + for (const entry of this.connectors) { + const { connector } = entry; + const sourceId = entry.sourceId ?? connector.id; const tools = await connector.listTools(); for (const tool of tools) { const namespacedName = `${connector.id}__${tool.name}`; + const requiredAction = inferRequiredAction(tool.name); if (this.toolMap.has(namespacedName)) { throw new Error(`Duplicate tool name after namespacing: ${namespacedName}`); } @@ -33,12 +60,16 @@ export class ToolRouter { this.toolMap.set(namespacedName, { connector, originalName: tool.name, + sourceId, + requiredAction, }); this.allTools.push({ namespacedName, originalName: tool.name, connectorId: connector.id, + sourceId, + requiredAction, description: `[${connector.displayName}] ${tool.description}`, inputSchema: normalizeInputSchema(tool), }); @@ -46,11 +77,15 @@ export class ToolRouter { } } - getAllTools(): NamespacedTool[] { - return [...this.allTools]; + getAllTools(identity?: ClientIdentity): NamespacedTool[] { + return this.allTools.filter((tool) => isToolAllowed(tool, identity)).map((tool) => ({ ...tool })); } - async callTool(namespacedName: string, args: Record): Promise { + async callTool( + namespacedName: string, + args: Record, + identity?: ClientIdentity, + ): Promise { const entry = this.toolMap.get(namespacedName); if (!entry) { throw new Error(`Unknown tool: ${namespacedName}`); @@ -60,6 +95,16 @@ export class ToolRouter { let result: CallToolResult; const redactions: NonNullable = []; let threw = false; + const deniedReason = toolDeniedReason(entry, identity); + if (deniedReason) { + result = { + content: [{ type: 'text', text: `Error: access denied for tool "${namespacedName}" (${deniedReason}).` }], + isError: true, + }; + this.recordAudit(entry, namespacedName, args, start, result, false, redactions, identity, deniedReason); + return result; + } + try { result = await entry.connector.callTool(entry.originalName, args); for (const plugin of this.plugins) { @@ -78,21 +123,61 @@ export class ToolRouter { threw = true; throw err; } finally { - if (this.audit) { - const { argKeys, argPreview } = summarizeArgs(args); - this.audit.record({ - ts: new Date().toISOString(), - connectorId: entry.connector.id, - tool: namespacedName, - argKeys, - argPreview, - ...(redactions.length > 0 ? { redactions } : {}), - isError: threw || Boolean(result! && (result as CallToolResult).isError), - durationMs: Date.now() - start, - }); - } + this.recordAudit(entry, namespacedName, args, start, result!, threw, redactions, identity); } } + + private recordAudit( + entry: { connector: Connector }, + namespacedName: string, + args: Record, + start: number, + result: CallToolResult | undefined, + threw: boolean, + redactions: NonNullable, + identity?: ClientIdentity, + deniedReason?: string, + ): void { + if (!this.audit) return; + const { argKeys, argPreview } = summarizeArgs(args); + this.audit.record({ + ts: new Date().toISOString(), + connectorId: entry.connector.id, + tool: namespacedName, + ...(identity ? { clientId: identity.id } : {}), + argKeys, + argPreview, + ...(redactions.length > 0 ? { redactions } : {}), + isError: threw || Boolean(result?.isError), + ...(deniedReason ? { deniedReason } : {}), + durationMs: Date.now() - start, + }); + } +} + +function isToolAllowed(tool: NamespacedTool, identity?: ClientIdentity): boolean { + return !toolDeniedReason(tool, identity); +} + +function toolDeniedReason( + tool: { sourceId: string; requiredAction: PermissionAction }, + identity?: ClientIdentity, +): string | undefined { + if (!identity || identity.isLegacyDefault) return undefined; + if (!identity.rawToolsEnabled) return 'raw_tools_disabled'; + const allowedActions = identity.permissions + .filter((permission) => permission.sourceId === tool.sourceId) + .flatMap((permission) => permission.actions); + if (allowedActions.includes(tool.requiredAction)) return undefined; + return `missing_permission source=${tool.sourceId} action=${tool.requiredAction}`; +} + +function inferRequiredAction(toolName: string): PermissionAction { + const lower = toolName.toLowerCase(); + if (isMemPalaceWriteTool(lower)) return 'memory_write'; + if (isLikelyWriteTool(lower)) return 'write'; + if (lower.includes('search') || lower.startsWith('find_') || lower.startsWith('query_')) return 'search'; + return 'read'; } function flattenRedactionEvents( diff --git a/src/utils/audit.ts b/src/utils/audit.ts index 7066535..c003a8b 100644 --- a/src/utils/audit.ts +++ b/src/utils/audit.ts @@ -8,6 +8,7 @@ export interface AuditEntry { ts: string; connectorId: string; tool: string; + clientId?: string; argKeys: string[]; argPreview: string; redactions?: Array<{ @@ -18,6 +19,7 @@ export interface AuditEntry { truncated?: boolean; }>; isError: boolean; + deniedReason?: string; durationMs: number; } diff --git a/tests/router.test.ts b/tests/router.test.ts index b8ca67a..2868e5b 100644 --- a/tests/router.test.ts +++ b/tests/router.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it, vi } from 'vitest'; import { Connector } from '../src/connectors/types.js'; import { ToolRouter } from '../src/server/router.js'; +import { ClientIdentity } from '../src/server/client-identity.js'; describe('ToolRouter', () => { it('namespaces tools and routes calls to the owning connector', async () => { @@ -30,6 +31,8 @@ describe('ToolRouter', () => { namespacedName: 'proxy_github__create_issue', originalName: 'create_issue', connectorId: 'proxy_github', + sourceId: 'proxy_github', + requiredAction: 'write', description: '[github] Create an issue', inputSchema: { type: 'object', properties: { title: { type: 'string' } } }, }, @@ -154,4 +157,128 @@ describe('ToolRouter', () => { }), ); }); + + it('filters raw tools by client source/action permissions', async () => { + const connector: Connector = { + id: 'proxy_filesystem', + displayName: 'filesystem', + initialize: vi.fn(), + shutdown: vi.fn(), + listTools: vi.fn(async () => [ + { name: 'search_files', description: 'Search files', inputSchema: { type: 'object', properties: {} } }, + { name: 'read_file', description: 'Read a file', inputSchema: { type: 'object', properties: {} } }, + { name: 'write_file', description: 'Write a file', inputSchema: { type: 'object', properties: {} } }, + ]), + callTool: vi.fn(), + }; + + const router = new ToolRouter([{ connector, sourceId: 'workspace' }]); + await router.initialize(); + + expect(router.getAllTools(client('codex', true, [{ sourceId: 'workspace', actions: ['search'] }]))).toEqual([ + expect.objectContaining({ namespacedName: 'proxy_filesystem__search_files', requiredAction: 'search' }), + ]); + }); + + it('denies raw tools when rawToolsEnabled is false', async () => { + const audit = { record: vi.fn() }; + const callTool = vi.fn(); + const connector: Connector = { + id: 'obsidian', + displayName: 'obsidian', + initialize: vi.fn(), + shutdown: vi.fn(), + listTools: vi.fn(async () => [ + { name: 'read_note', description: 'Read a note', inputSchema: { type: 'object', properties: {} } }, + ]), + callTool, + }; + + const router = new ToolRouter([connector], audit); + await router.initialize(); + + expect(router.getAllTools(client('chatgpt', false, [{ sourceId: 'obsidian', actions: ['read'] }]))).toEqual([]); + await expect( + router.callTool('obsidian__read_note', { notePath: 'Project' }, client('chatgpt', false, [ + { sourceId: 'obsidian', actions: ['read'] }, + ])), + ).resolves.toMatchObject({ isError: true }); + expect(callTool).not.toHaveBeenCalled(); + expect(audit.record).toHaveBeenCalledWith( + expect.objectContaining({ + clientId: 'chatgpt', + deniedReason: 'raw_tools_disabled', + isError: true, + }), + ); + }); + + it('denies raw tool calls without the required source/action permission', async () => { + const audit = { record: vi.fn() }; + const callTool = vi.fn(); + const connector: Connector = { + id: 'proxy_filesystem', + displayName: 'filesystem', + initialize: vi.fn(), + shutdown: vi.fn(), + listTools: vi.fn(async () => [ + { name: 'read_file', description: 'Read a file', inputSchema: { type: 'object', properties: {} } }, + ]), + callTool, + }; + + const router = new ToolRouter([{ connector, sourceId: 'workspace' }], audit); + await router.initialize(); + + await expect( + router.callTool('proxy_filesystem__read_file', { path: '/tmp/a' }, client('codex', true, [ + { sourceId: 'workspace', actions: ['search'] }, + ])), + ).resolves.toMatchObject({ isError: true }); + expect(callTool).not.toHaveBeenCalled(); + expect(audit.record).toHaveBeenCalledWith( + expect.objectContaining({ + clientId: 'codex', + deniedReason: 'missing_permission source=workspace action=read', + isError: true, + }), + ); + }); + + it('records clientId for authorized raw tool calls', async () => { + const audit = { record: vi.fn() }; + const connector: Connector = { + id: 'obsidian', + displayName: 'obsidian', + initialize: vi.fn(), + shutdown: vi.fn(), + listTools: vi.fn(async () => [ + { name: 'read_note', description: 'Read a note', inputSchema: { type: 'object', properties: {} } }, + ]), + callTool: vi.fn(async () => ({ content: [{ type: 'text' as const, text: 'ok' }] })), + }; + + const router = new ToolRouter([connector], audit); + await router.initialize(); + + await router.callTool('obsidian__read_note', {}, client('codex', true, [ + { sourceId: 'obsidian', actions: ['read'] }, + ])); + + expect(audit.record).toHaveBeenCalledWith(expect.objectContaining({ clientId: 'codex', isError: false })); + }); }); + +function client( + id: string, + rawToolsEnabled: boolean, + permissions: ClientIdentity['permissions'], +): ClientIdentity { + return { + id, + name: id, + source: 'token', + rawToolsEnabled, + permissions, + }; +} diff --git a/tests/server.test.ts b/tests/server.test.ts index 500eb5c..f5f870f 100644 --- a/tests/server.test.ts +++ b/tests/server.test.ts @@ -1281,6 +1281,46 @@ describe('startHttpServer lifecycle', () => { } }); + it('filters MCP tools/list and denies tool calls by resolved client permissions', async () => { + const connector = new PolicyConnector(); + const router = new ToolRouter([{ connector, sourceId: 'workspace' }]); + await router.initialize(); + const tmp = fs.mkdtempSync(path.join(os.tmpdir(), 'mvmt-server-test-')); + const tokenPath = path.join(tmp, '.mvmt', '.session-token'); + const server = await startHttpServer(router, { + port: 0, + tokenPath, + clients: [ + { + id: 'searcher', + name: 'Search-only client', + auth: { type: 'token', tokenHash: sha256Hex('search-token') }, + rawToolsEnabled: true, + permissions: [{ sourceId: 'workspace', actions: ['search'] }], + }, + ], + }); + + try { + const sessionId = await initializeMcpSession(server.port, 'search-token'); + const listTools = await mcpJsonRequest(server.port, 'search-token', sessionId, 2, 'tools/list', {}); + expect(listTools.result.tools.map((tool: { name: string }) => tool.name)).toEqual([ + 'proxy_filesystem__search_files', + ]); + + const denied = await mcpJsonRequest(server.port, 'search-token', sessionId, 3, 'tools/call', { + name: 'proxy_filesystem__read_file', + arguments: { path: '/tmp/a' }, + }); + expect(denied.result.isError).toBe(true); + expect(denied.result.content[0].text).toContain('missing_permission source=workspace action=read'); + expect(connector.calls).toEqual([]); + } finally { + await server.close(); + fs.rmSync(tmp, { recursive: true, force: true }); + } + }); + it('revokes outstanding OAuth access tokens the moment the signing key file is rewritten', async () => { const router = new ToolRouter([new EmptyConnector()]); await router.initialize(); @@ -1413,6 +1453,28 @@ class EmptyConnector implements Connector { async shutdown(): Promise {} } +class PolicyConnector implements Connector { + readonly id = 'proxy_filesystem'; + readonly displayName = 'filesystem'; + readonly calls: Array<{ name: string; args: Record }> = []; + + async initialize(): Promise {} + + async listTools() { + return [ + { name: 'search_files', description: 'Search files', inputSchema: { type: 'object', properties: {} } }, + { name: 'read_file', description: 'Read a file', inputSchema: { type: 'object', properties: {} } }, + ]; + } + + async callTool(name: string, args: Record) { + this.calls.push({ name, args }); + return { content: [{ type: 'text' as const, text: 'ok' }] }; + } + + async shutdown(): Promise {} +} + function canListenOn(port: number): Promise { return new Promise((resolve) => { const server = createServer(); @@ -1518,3 +1580,59 @@ function expectAccessTokenAudience(accessToken: string, signingKeyPath: string, function sha256Hex(value: string): string { return createHash('sha256').update(value, 'utf8').digest('hex'); } + +async function initializeMcpSession(port: number, token: string): Promise { + const response = await fetch(`http://127.0.0.1:${port}/mcp`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + Accept: 'application/json, text/event-stream', + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { name: 'mvmt-policy-test', version: '0.0.0' }, + }, + }), + }); + expect(response.status).toBe(200); + const sessionId = response.headers.get('mcp-session-id'); + expect(sessionId).toBeTruthy(); + await response.text(); + return sessionId!; +} + +async function mcpJsonRequest( + port: number, + token: string, + sessionId: string, + id: number, + method: string, + params: Record, +): Promise { + const response = await fetch(`http://127.0.0.1:${port}/mcp`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + Accept: 'application/json, text/event-stream', + 'Content-Type': 'application/json', + 'Mcp-Protocol-Version': '2025-03-26', + 'Mcp-Session-Id': sessionId, + }, + body: JSON.stringify({ jsonrpc: '2.0', id, method, params }), + }); + expect(response.status).toBe(200); + return parseMcpResponse(await response.text()); +} + +function parseMcpResponse(text: string): any { + if (text.trimStart().startsWith('{')) return JSON.parse(text); + const dataLine = text.split('\n').find((line) => line.startsWith('data: ')); + if (!dataLine) throw new Error(`Could not parse MCP response: ${text}`); + return JSON.parse(dataLine.slice('data: '.length)); +}