diff --git a/crates/ruvector-cli/src/mcp/transport.rs b/crates/ruvector-cli/src/mcp/transport.rs index ad68d83fa..b54000fee 100644 --- a/crates/ruvector-cli/src/mcp/transport.rs +++ b/crates/ruvector-cli/src/mcp/transport.rs @@ -13,7 +13,7 @@ use futures::stream::Stream; use serde_json; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tower_http::cors::CorsLayer; +use tower_http::cors::{AllowOrigin, CorsLayer}; /// STDIO transport for local MCP communication pub struct StdioTransport { @@ -97,11 +97,26 @@ impl SseTransport { /// Run SSE transport server pub async fn run(&self) -> Result<()> { + // Use restrictive CORS: only allow localhost origins by default + let cors = CorsLayer::new() + .allow_origin(AllowOrigin::predicate(|origin, _| { + if let Ok(origin_str) = origin.to_str() { + origin_str.starts_with("http://127.0.0.1") + || origin_str.starts_with("http://localhost") + || origin_str.starts_with("https://127.0.0.1") + || origin_str.starts_with("https://localhost") + } else { + false + } + })) + .allow_methods([axum::http::Method::GET, axum::http::Method::POST]) + .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]); + let app = Router::new() .route("/", get(root)) .route("/mcp", post(mcp_handler)) .route("/mcp/sse", get(mcp_sse_handler)) - .layer(CorsLayer::permissive()) + .layer(cors) .with_state(self.handler.clone()); let addr = format!("{}:{}", self.host, self.port); diff --git a/npm/packages/ruvector/bin/mcp-server.js b/npm/packages/ruvector/bin/mcp-server.js index 29fc6840b..9a9f5bb92 100644 --- a/npm/packages/ruvector/bin/mcp-server.js +++ b/npm/packages/ruvector/bin/mcp-server.js @@ -36,18 +36,38 @@ function validateRvfPath(filePath) { if (typeof filePath !== 'string' || filePath.length === 0) { throw new Error('Path must be a non-empty string'); } - const resolved = path.resolve(filePath); - // Block obvious path traversal - if (filePath.includes('..') || filePath.includes('\0')) { - throw new Error('Path traversal detected'); + // Block null bytes + if (filePath.includes('\0')) { + throw new Error('Path contains null bytes'); } - // Block sensitive system paths - const blocked = ['/etc', '/proc', '/sys', '/dev', '/boot', '/root', '/var/run']; - for (const prefix of blocked) { - if (resolved.startsWith(prefix)) { - throw new Error(`Access to ${prefix} is not allowed`); + // Resolve to absolute, then canonicalize via realpath if it exists + let resolved = path.resolve(filePath); + try { + // Resolve symlinks for existing paths to prevent symlink-based escapes + resolved = fs.realpathSync(resolved); + } catch { + // Path doesn't exist yet — resolve the parent directory + const parentDir = path.dirname(resolved); + try { + const realParent = fs.realpathSync(parentDir); + resolved = path.join(realParent, path.basename(resolved)); + } catch { + // Parent doesn't exist either — keep the resolved path for the block check } } + // Confine to the current working directory + const cwd = process.cwd(); + if (!resolved.startsWith(cwd + path.sep) && resolved !== cwd) { + // Also block sensitive system paths regardless + const blocked = ['/etc', '/proc', '/sys', '/dev', '/boot', '/root', '/var/run', '/var/log', '/tmp']; + for (const prefix of blocked) { + if (resolved.startsWith(prefix)) { + throw new Error(`Access denied: path resolves to '${resolved}' which is outside the working directory and in restricted area '${prefix}'`); + } + } + // Allow paths outside cwd only if they're not in blocked directories + // (for tools that reference project files by absolute path) + } return resolved; } @@ -57,14 +77,24 @@ function validateRvfPath(filePath) { */ function sanitizeShellArg(arg) { if (typeof arg !== 'string') return ''; - // Remove null bytes, backticks, $(), and other shell metacharacters + // Remove null bytes, backticks, $(), quotes, newlines, and other shell metacharacters return arg .replace(/\0/g, '') - .replace(/[`$(){}|;&<>!]/g, '') + .replace(/[\r\n]/g, '') + .replace(/[`$(){}|;&<>!'"\\]/g, '') .replace(/\.\./g, '') .slice(0, 4096); } +/** + * Validate a numeric argument (returns integer or default). + * Prevents injection via numeric-looking fields. + */ +function sanitizeNumericArg(arg, defaultVal) { + const n = parseInt(arg, 10); + return Number.isFinite(n) && n > 0 ? n : (defaultVal || 0); +} + // Try to load the full IntelligenceEngine let IntelligenceEngine = null; let engineAvailable = false; @@ -1319,7 +1349,7 @@ server.setRequestHandler(CallToolRequestSchema, async (request) => { let cmd = 'npx ruvector hooks init'; if (args.force) cmd += ' --force'; if (args.pretrain) cmd += ' --pretrain'; - if (args.build_agents) cmd += ` --build-agents ${args.build_agents}`; + if (args.build_agents) cmd += ` --build-agents ${sanitizeShellArg(args.build_agents)}`; try { const output = execSync(cmd, { encoding: 'utf-8', timeout: 60000 }); @@ -1341,7 +1371,7 @@ server.setRequestHandler(CallToolRequestSchema, async (request) => { case 'hooks_pretrain': { let cmd = 'npx ruvector hooks pretrain'; - if (args.depth) cmd += ` --depth ${args.depth}`; + if (args.depth) cmd += ` --depth ${sanitizeNumericArg(args.depth, 3)}`; if (args.skip_git) cmd += ' --skip-git'; if (args.verbose) cmd += ' --verbose'; @@ -1371,7 +1401,7 @@ server.setRequestHandler(CallToolRequestSchema, async (request) => { case 'hooks_build_agents': { let cmd = 'npx ruvector hooks build-agents'; - if (args.focus) cmd += ` --focus ${args.focus}`; + if (args.focus) cmd += ` --focus ${sanitizeShellArg(args.focus)}`; if (args.include_prompts) cmd += ' --include-prompts'; try { @@ -1484,21 +1514,44 @@ server.setRequestHandler(CallToolRequestSchema, async (request) => { const data = args.data; const merge = args.merge !== false; - if (data.patterns) { + // Validate imported data structure to prevent prototype pollution and injection + if (typeof data !== 'object' || data === null || Array.isArray(data)) { + throw new Error('Import data must be a non-null object'); + } + const allowedKeys = ['patterns', 'memories', 'errors', 'agents', 'edges', 'trajectories']; + for (const key of Object.keys(data)) { + if (!allowedKeys.includes(key)) { + throw new Error(`Unknown import key: '${key}'. Allowed: ${allowedKeys.join(', ')}`); + } + } + // Prevent prototype pollution via __proto__, constructor, prototype keys + const dangerousKeys = ['__proto__', 'constructor', 'prototype']; + function checkForProtoPollution(obj, path) { + if (typeof obj !== 'object' || obj === null) return; + for (const key of Object.keys(obj)) { + if (dangerousKeys.includes(key)) { + throw new Error(`Dangerous key '${key}' detected at ${path}.${key}`); + } + } + } + if (data.patterns) checkForProtoPollution(data.patterns, 'patterns'); + if (data.errors) checkForProtoPollution(data.errors, 'errors'); + + if (data.patterns && typeof data.patterns === 'object') { if (merge) { Object.assign(intel.data.patterns, data.patterns); } else { intel.data.patterns = data.patterns; } } - if (data.memories) { + if (data.memories && Array.isArray(data.memories)) { if (merge) { intel.data.memories = [...(intel.data.memories || []), ...data.memories]; } else { intel.data.memories = data.memories; } } - if (data.errors) { + if (data.errors && typeof data.errors === 'object') { if (merge) { Object.assign(intel.data.errors, data.errors); } else { @@ -2426,7 +2479,7 @@ server.setRequestHandler(CallToolRequestSchema, async (request) => { case 'workers_status': { try { - const cmdArgs = args.workerId ? `workers status ${args.workerId}` : 'workers status'; + const cmdArgs = args.workerId ? `workers status ${sanitizeShellArg(args.workerId)}` : 'workers status'; const result = execSync(`npx agentic-flow@alpha ${cmdArgs}`, { encoding: 'utf-8', timeout: 15000,