Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/cli/connector-loader.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { MvmtConfig } from '../config/schema.js';
import { MvmtConfig, OBSIDIAN_SOURCE_ID, resolveProxySourceId } from '../config/schema.js';
import { Connector } from '../connectors/types.js';
import { ObsidianConnector } from '../connectors/obsidian.js';
import { createProxyConnector } from '../connectors/factory.js';
import { Logger } from '../utils/logger.js';

export type LoadedConnector = {
connector: Connector;
sourceId: string;
toolCount: number;
};

Expand All @@ -28,7 +29,7 @@ export async function initializeConnectors(
try {
await connector.initialize();
const toolCount = (await connector.listTools()).length;
loaded.push({ connector, toolCount });
loaded.push({ connector, sourceId: resolveProxySourceId(proxyConfig), toolCount });
emit(`Loaded proxy:${proxyConfig.name} (${toolCount} tools)`, stdioMode, logger);
} catch (err) {
emit(
Expand All @@ -46,7 +47,7 @@ export async function initializeConnectors(
try {
await connector.initialize();
const toolCount = (await connector.listTools()).length;
loaded.push({ connector, toolCount });
loaded.push({ connector, sourceId: OBSIDIAN_SOURCE_ID, toolCount });
emit(`Loaded obsidian (${toolCount} tools)`, stdioMode, logger);
} catch (err) {
emit(
Expand Down
2 changes: 1 addition & 1 deletion src/cli/start.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export async function start(options: StartOptions = {}): Promise<void> {
const audit = interactiveMode
? new InteractiveAuditLogger(createAuditLogger())
: createAuditLogger();
const router = new ToolRouter(loaded.map((entry) => entry.connector), audit, plugins);
const router = new ToolRouter(loaded.map((entry) => ({ connector: entry.connector, sourceId: entry.sourceId })), audit, plugins);
await router.initialize();

// Cleanup tasks run on SIGINT/SIGTERM and on startup failure.
Expand Down
26 changes: 18 additions & 8 deletions src/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
} from './oauth.js';
import { rateLimit } from './rate-limit.js';
import { ClientConfig } from '../config/schema.js';
import { attachClientIdentity, isQuarantined, resolveClientIdentity } from './client-identity.js';
import { attachClientIdentity, ClientIdentity, isQuarantined, readClientIdentity, resolveClientIdentity } from './client-identity.js';

// Rate limits are defense-in-depth against brute-force and DoS,
// primarily meaningful when mvmt is exposed via a tunnel. Auth-gated
Expand All @@ -34,6 +34,7 @@ const DEFAULT_HEALTH_RATE_LIMIT = { windowMs: 60_000, max: 120 };
type McpSession = {
transport: StreamableHTTPServerTransport;
server: Server;
clientIdentity?: ClientIdentity;
lastActivity: number;
};

Expand Down Expand Up @@ -79,14 +80,14 @@ export interface HttpRequestLogEntry {
clientId?: string;
}

export function createMcpServer(router: ToolRouter): Server {
export function createMcpServer(router: ToolRouter, clientIdentity?: ClientIdentity): Server {
const server = new Server(
{ name: 'mvmt', version: '0.1.0' },
{ capabilities: { tools: {} } },
);

server.setRequestHandler(ListToolsRequestSchema, async () => ({
tools: router.getAllTools().map((tool) => ({
tools: router.getAllTools(clientIdentity).map((tool) => ({
name: tool.namespacedName,
description: tool.description,
inputSchema: tool.inputSchema as { type: 'object'; properties?: Record<string, object>; required?: string[] },
Expand All @@ -95,7 +96,7 @@ export function createMcpServer(router: ToolRouter): Server {

server.setRequestHandler(CallToolRequestSchema, async (request) => {
try {
const result = await router.callTool(request.params.name, request.params.arguments ?? {});
const result = await router.callTool(request.params.name, request.params.arguments ?? {}, clientIdentity);
return result as any;
} catch (err) {
return {
Expand Down Expand Up @@ -571,7 +572,7 @@ export async function startHttpServer(router: ToolRouter, options: HttpServerOpt
res.json({
status: 'ok',
uptime: Math.floor(process.uptime()),
tools: router.getAllTools().length,
tools: router.getAllTools(readClientIdentity(_req)).length,
sessions: sessions.size,
});
});
Expand Down Expand Up @@ -621,6 +622,10 @@ async function handleMcpRequest(

if (sessionId && sessions.has(sessionId)) {
const session = sessions.get(sessionId)!;
if (!sameClientIdentity(session.clientIdentity, readClientIdentity(req))) {
res.status(403).json({ error: 'mcp_session_client_mismatch' });
return;
}
session.lastActivity = Date.now();
if (isStandaloneSseRequest(req)) {
session.transport.closeStandaloneSSEStream();
Expand All @@ -637,7 +642,8 @@ async function handleMcpRequest(
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
});
const server = createMcpServer(router);
const clientIdentity = readClientIdentity(req);
const server = createMcpServer(router, clientIdentity);

transport.onerror = (error) => {
if (isBenignDuplicateSseConflict(error)) {
Expand All @@ -656,15 +662,15 @@ async function handleMcpRequest(
await transport.handleRequest(req, res, req.body);

if (transport.sessionId) {
sessions.set(transport.sessionId, { transport, server, lastActivity: Date.now() });
sessions.set(transport.sessionId, { transport, server, clientIdentity, lastActivity: Date.now() });
}
}

async function handleStatelessMcpRequest(req: Request, res: Response, router: ToolRouter): Promise<void> {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: undefined,
});
const server = createMcpServer(router);
const server = createMcpServer(router, readClientIdentity(req));

transport.onerror = (error) => {
if (isBenignDuplicateSseConflict(error)) {
Expand All @@ -678,6 +684,10 @@ async function handleStatelessMcpRequest(req: Request, res: Response, router: To
await transport.handleRequest(req, res, req.body);
}

function sameClientIdentity(left: ClientIdentity | undefined, right: ClientIdentity | undefined): boolean {
return left?.id === right?.id;
}

function authLogKind(req: Request): string {
if (req.path === '/mcp') return 'mcp.auth';
if (req.path === '/health') return 'health.auth';
Expand Down
125 changes: 105 additions & 20 deletions src/server/router.ts
Original file line number Diff line number Diff line change
@@ -1,56 +1,91 @@
import { CallToolResult, Connector, ToolDefinition } from '../connectors/types.js';
import { isLikelyWriteTool, isMemPalaceWriteTool } from '../connectors/write-policy.js';
import { PermissionConfig } from '../config/schema.js';
import { PatternRedactorAuditEvent, ToolResultPlugin } from '../plugins/types.js';
import { AuditLogger, summarizeArgs } from '../utils/audit.js';
import { ClientIdentity } from './client-identity.js';

type PermissionAction = PermissionConfig['actions'][number];

export interface RouterConnector {
connector: Connector;
sourceId?: string;
}

export interface NamespacedTool {
namespacedName: string;
originalName: string;
connectorId: string;
sourceId: string;
requiredAction: PermissionAction;
description: string;
inputSchema: Record<string, unknown>;
}

export class ToolRouter {
private readonly toolMap = new Map<string, { connector: Connector; originalName: string }>();
private readonly toolMap = new Map<string, {
connector: Connector;
originalName: string;
sourceId: string;
requiredAction: PermissionAction;
}>();
private readonly allTools: NamespacedTool[] = [];
private readonly connectors: RouterConnector[];

constructor(
private readonly connectors: Connector[],
connectors: Array<Connector | RouterConnector>,
private readonly audit?: AuditLogger,
private readonly plugins: ToolResultPlugin[] = [],
) {}
) {
this.connectors = connectors.map((entry) => (
'connector' in entry
? entry
: { connector: entry, sourceId: entry.id }
));
}

async initialize(): Promise<void> {
for (const connector of this.connectors) {
for (const entry of this.connectors) {
const { connector } = entry;
const sourceId = entry.sourceId ?? connector.id;
const tools = await connector.listTools();

for (const tool of tools) {
const namespacedName = `${connector.id}__${tool.name}`;
const requiredAction = inferRequiredAction(tool.name);
if (this.toolMap.has(namespacedName)) {
throw new Error(`Duplicate tool name after namespacing: ${namespacedName}`);
}

this.toolMap.set(namespacedName, {
connector,
originalName: tool.name,
sourceId,
requiredAction,
});

this.allTools.push({
namespacedName,
originalName: tool.name,
connectorId: connector.id,
sourceId,
requiredAction,
description: `[${connector.displayName}] ${tool.description}`,
inputSchema: normalizeInputSchema(tool),
});
}
}
}

getAllTools(): NamespacedTool[] {
return [...this.allTools];
getAllTools(identity?: ClientIdentity): NamespacedTool[] {
return this.allTools.filter((tool) => isToolAllowed(tool, identity)).map((tool) => ({ ...tool }));
}

async callTool(namespacedName: string, args: Record<string, unknown>): Promise<CallToolResult> {
async callTool(
namespacedName: string,
args: Record<string, unknown>,
identity?: ClientIdentity,
): Promise<CallToolResult> {
const entry = this.toolMap.get(namespacedName);
if (!entry) {
throw new Error(`Unknown tool: ${namespacedName}`);
Expand All @@ -60,6 +95,16 @@ export class ToolRouter {
let result: CallToolResult;
const redactions: NonNullable<import('../utils/audit.js').AuditEntry['redactions']> = [];
let threw = false;
const deniedReason = toolDeniedReason(entry, identity);
if (deniedReason) {
result = {
content: [{ type: 'text', text: `Error: access denied for tool "${namespacedName}" (${deniedReason}).` }],
isError: true,
};
this.recordAudit(entry, namespacedName, args, start, result, false, redactions, identity, deniedReason);
return result;
}

try {
result = await entry.connector.callTool(entry.originalName, args);
for (const plugin of this.plugins) {
Expand All @@ -78,21 +123,61 @@ export class ToolRouter {
threw = true;
throw err;
} finally {
if (this.audit) {
const { argKeys, argPreview } = summarizeArgs(args);
this.audit.record({
ts: new Date().toISOString(),
connectorId: entry.connector.id,
tool: namespacedName,
argKeys,
argPreview,
...(redactions.length > 0 ? { redactions } : {}),
isError: threw || Boolean(result! && (result as CallToolResult).isError),
durationMs: Date.now() - start,
});
}
this.recordAudit(entry, namespacedName, args, start, result!, threw, redactions, identity);
}
}

private recordAudit(
entry: { connector: Connector },
namespacedName: string,
args: Record<string, unknown>,
start: number,
result: CallToolResult | undefined,
threw: boolean,
redactions: NonNullable<import('../utils/audit.js').AuditEntry['redactions']>,
identity?: ClientIdentity,
deniedReason?: string,
): void {
if (!this.audit) return;
const { argKeys, argPreview } = summarizeArgs(args);
this.audit.record({
ts: new Date().toISOString(),
connectorId: entry.connector.id,
tool: namespacedName,
...(identity ? { clientId: identity.id } : {}),
argKeys,
argPreview,
...(redactions.length > 0 ? { redactions } : {}),
isError: threw || Boolean(result?.isError),
...(deniedReason ? { deniedReason } : {}),
durationMs: Date.now() - start,
});
}
}

function isToolAllowed(tool: NamespacedTool, identity?: ClientIdentity): boolean {
return !toolDeniedReason(tool, identity);
}

function toolDeniedReason(
tool: { sourceId: string; requiredAction: PermissionAction },
identity?: ClientIdentity,
): string | undefined {
if (!identity || identity.isLegacyDefault) return undefined;
if (!identity.rawToolsEnabled) return 'raw_tools_disabled';
const allowedActions = identity.permissions
.filter((permission) => permission.sourceId === tool.sourceId)
.flatMap((permission) => permission.actions);
if (allowedActions.includes(tool.requiredAction)) return undefined;
return `missing_permission source=${tool.sourceId} action=${tool.requiredAction}`;
}

function inferRequiredAction(toolName: string): PermissionAction {
const lower = toolName.toLowerCase();
if (isMemPalaceWriteTool(lower)) return 'memory_write';
if (isLikelyWriteTool(lower)) return 'write';
if (lower.includes('search') || lower.startsWith('find_') || lower.startsWith('query_')) return 'search';
return 'read';
}

function flattenRedactionEvents(
Expand Down
2 changes: 2 additions & 0 deletions src/utils/audit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface AuditEntry {
ts: string;
connectorId: string;
tool: string;
clientId?: string;
argKeys: string[];
argPreview: string;
redactions?: Array<{
Expand All @@ -18,6 +19,7 @@ export interface AuditEntry {
truncated?: boolean;
}>;
isError: boolean;
deniedReason?: string;
durationMs: number;
}

Expand Down
Loading
Loading