From c5bd0a369d2a82aad63820028c6283a623a700f1 Mon Sep 17 00:00:00 2001 From: Skiipy Date: Fri, 27 Mar 2026 12:41:59 -0400 Subject: [PATCH] Align rate limiting with auth key sources --- .env.example | 4 ++ README.md | 1 + api/src/middleware/auth.js | 28 +++++++- api/src/middleware/ratelimit.js | 2 +- api/src/routes/export.js | 64 ++++++++--------- api/src/routes/memory.js | 36 ++++++---- api/src/services/memory-write-utils.js | 57 +++++++++++++++ api/src/services/stores/postgres.js | 1 + api/src/services/stores/sqlite.js | 19 +---- api/tests/auth.test.js | 98 ++++++++++++++++++++++++++ api/tests/memory-write-utils.test.js | 79 +++++++++++++++++++++ api/tests/ratelimit.test.js | 52 ++++++++++++++ 12 files changed, 374 insertions(+), 67 deletions(-) create mode 100644 api/src/services/memory-write-utils.js create mode 100644 api/tests/auth.test.js create mode 100644 api/tests/memory-write-utils.test.js create mode 100644 api/tests/ratelimit.test.js diff --git a/.env.example b/.env.example index b6aab37..da817dd 100644 --- a/.env.example +++ b/.env.example @@ -15,6 +15,10 @@ PORT=8084 # AGENT_KEY_n8n=your-key-here # AGENT_KEY_morpheus=your-key-here +# Optional compatibility toggle: allow ?key=... query auth for browser-only tools (e.g. graph HTML pages) +# Strongly recommended to keep disabled in production. +# ALLOW_QUERY_API_KEY=false + # --- Rate Limiting --- # RATE_LIMIT_WRITES=60 # Max write requests per minute per key (default: 60) # RATE_LIMIT_READS=120 # Max read requests per minute per key (default: 120) diff --git a/README.md b/README.md index 5f9354b..94694ac 100644 --- a/README.md +++ b/README.md @@ -629,6 +629,7 @@ All configuration is via environment variables. Copy `.env.example` to `.env` an | `QDRANT_API_KEY` | — | Qdrant API key | | `PORT` | `8084` | API server port | | `HOST` | `127.0.0.1` | Bind address. Set to `0.0.0.0` for LAN/Docker access. 
| +| `ALLOW_QUERY_API_KEY` | `false` | Compatibility flag to accept `?key=...` query auth (not recommended; prefer `x-api-key`/Bearer headers). | ### Embedding Provider diff --git a/api/src/middleware/auth.js b/api/src/middleware/auth.js index 844f43b..98c50c7 100644 --- a/api/src/middleware/auth.js +++ b/api/src/middleware/auth.js @@ -1,6 +1,7 @@ import crypto from 'crypto'; const ADMIN_KEY = process.env.BRAIN_API_KEY; +const ALLOW_QUERY_API_KEY = process.env.ALLOW_QUERY_API_KEY === 'true'; // Build agent registry from env vars: AGENT_KEY_= // e.g. AGENT_KEY_claude_code=abc123 → { key: 'abc123', agent: 'claude-code' } @@ -52,6 +53,23 @@ function safeEqual(a, b) { return crypto.timingSafeEqual(Buffer.from(a), Buffer.from(b)); } +function extractKey(req) { + const headerKey = req.headers['x-api-key']; + if (headerKey) return { key: headerKey, source: 'header' }; + + const authHeader = req.headers.authorization || req.headers.Authorization; + if (typeof authHeader === 'string' && /^Bearer /i.test(authHeader)) { // HTTP auth schemes are case-insensitive (RFC 9110 section 11.1) + const bearerToken = authHeader.slice('Bearer '.length).trim(); + if (bearerToken) return { key: bearerToken, source: 'bearer' }; + } + + if (ALLOW_QUERY_API_KEY && typeof req.query?.key === 'string' && req.query.key) { // query values are untrusted and may be arrays/objects; only a plain string can be a key + return { key: req.query.key, source: 'query' }; + } + + return { key: null, source: null }; +} + export function authMiddleware(req, res, next) { const ip = req.ip || req.socket.remoteAddress; @@ -59,9 +77,13 @@ export function authMiddleware(req, res, next) { return res.status(429).json({ error: 'Too many failed attempts. Try again later.' }); } - const key = req.headers['x-api-key'] || req.query.key; + const { key, source } = extractKey(req); if (!key) { recordFailure(ip); + const queryKeyProvided = req.query?.key && !ALLOW_QUERY_API_KEY; + if (queryKeyProvided) { + return res.status(401).json({ error: 'Query-string API keys are disabled. Use x-api-key header.' 
}); + } return res.status(401).json({ error: 'Missing API key' }); } @@ -69,12 +91,16 @@ export function authMiddleware(req, res, next) { const agentName = agentRegistry.get(key); if (agentName) { req.authenticatedAgent = agentName; + req.authSource = source; + req.rateLimitKey = key; return next(); } // Fall back to admin key (no agent binding — full access) if (safeEqual(key, ADMIN_KEY)) { req.authenticatedAgent = null; // admin — no agent identity enforced + req.authSource = source; + req.rateLimitKey = key; return next(); } diff --git a/api/src/middleware/ratelimit.js b/api/src/middleware/ratelimit.js index ac7f63a..9492cb6 100644 --- a/api/src/middleware/ratelimit.js +++ b/api/src/middleware/ratelimit.js @@ -38,7 +38,7 @@ function classifyRequest(method, path) { } export function rateLimitMiddleware(req, res, next) { - const apiKey = req.headers['x-api-key'] || 'unknown'; + const apiKey = req.rateLimitKey || req.headers['x-api-key'] || req.headers.authorization?.replace(/^Bearer\s+/i, '').trim() || req.query?.key || 'unknown'; // strip the Bearer scheme so the fallback bucket uses the same bare token authMiddleware stores in req.rateLimitKey const type = classifyRequest(req.method, req.path); const { limited, retryAfter } = checkLimit(apiKey, type); diff --git a/api/src/routes/export.js b/api/src/routes/export.js index 00b49fe..ad80e61 100644 --- a/api/src/routes/export.js +++ b/api/src/routes/export.js @@ -3,6 +3,7 @@ import crypto from 'crypto'; import { scrollPoints, upsertPoint, findByPayload } from '../services/qdrant.js'; import { embed } from '../services/embedders/interface.js'; import { isStoreAvailable, createEvent, upsertFact, upsertStatus } from '../services/stores/interface.js'; +import { buildDedupExtraFilter, normalizeImportRecord } from '../services/memory-write-utils.js'; export const exportRouter = Router(); @@ -96,49 +97,46 @@ exportRouter.post('/import', async (req, res) => { // Process each record in the batch sequentially for (const record of batch) { try { - const content = record.content || record.text || ''; - if (!content) { + const { normalized, contentHash, error } = 
normalizeImportRecord(record); + if (error) { errors++; continue; } - // Compute content hash (SHA-256, first 16 hex chars — matches memory.js pattern) - const contentHash = crypto.createHash('sha256').update(content).digest('hex').slice(0, 16); - - // Check for existing memory with same content hash - const existing = await findByPayload('content_hash', contentHash); + // Check for existing memory with same hash in the same tenant/type scope + const existing = await findByPayload('content_hash', contentHash, buildDedupExtraFilter(normalized.client_id, normalized.type)); if (existing.length > 0) { skipped++; continue; } // Embed and generate ID - const vector = await embed(content); - const pointId = record.id || crypto.randomUUID(); + const vector = await embed(normalized.content, 'store'); + const pointId = normalized.id || crypto.randomUUID(); const now = new Date().toISOString(); // Build full payload const payload = { - text: content, - type: record.type || 'event', - key: record.key || null, - subject: record.subject || null, - client_id: record.client_id || 'global', - knowledge_category: record.knowledge_category || null, - category: record.category || 'episodic', - source_agent: record.source_agent || 'import', - importance: record.importance || 'medium', - confidence: record.confidence !== undefined ? record.confidence : 1.0, - access_count: record.access_count || 0, - active: record.active !== undefined ? record.active : true, - superseded_by: record.superseded_by || null, - entities: record.entities || [], + text: normalized.content, + type: normalized.type, + key: normalized.key || null, + subject: normalized.subject || null, + client_id: normalized.client_id, + knowledge_category: normalized.knowledge_category, + category: normalized.category, + source_agent: normalized.source_agent, + importance: normalized.importance, + confidence: normalized.confidence !== undefined ? 
normalized.confidence : 1.0, + access_count: normalized.access_count || 0, + active: normalized.active !== undefined ? normalized.active : true, + superseded_by: normalized.superseded_by || null, + entities: normalized.entities || [], content_hash: contentHash, - created_at: record.created_at || now, - last_accessed_at: record.last_accessed_at || now, - observed_by: record.observed_by || [record.source_agent || 'import'], - observation_count: record.observation_count || 1, - consolidated: record.consolidated || false, + created_at: normalized.created_at || now, + last_accessed_at: normalized.last_accessed_at || now, + observed_by: normalized.observed_by || [normalized.source_agent], + observation_count: normalized.observation_count || 1, + consolidated: normalized.consolidated || false, }; // Upsert to Qdrant @@ -148,7 +146,7 @@ exportRouter.post('/import', async (req, res) => { if (isStoreAvailable()) { try { const storeData = { - content, + content: normalized.content, source_agent: payload.source_agent, client_id: payload.client_id, category: payload.category, @@ -162,12 +160,12 @@ exportRouter.post('/import', async (req, res) => { storeData.type = type; await createEvent(storeData); } else if (type === 'fact') { - storeData.key = record.key || contentHash; - storeData.value = content; + storeData.key = normalized.key || contentHash; + storeData.value = normalized.content; await upsertFact(storeData); } else if (type === 'status') { - storeData.subject = record.subject || 'unknown'; - storeData.status = record.status_value || content; + storeData.subject = normalized.subject || 'unknown'; + storeData.status = normalized.status_value || normalized.content; await upsertStatus(storeData); } } catch (storeErr) { diff --git a/api/src/routes/memory.js b/api/src/routes/memory.js index 9c6cb1e..aa1bd61 100644 --- a/api/src/routes/memory.js +++ b/api/src/routes/memory.js @@ -9,16 +9,17 @@ import { createEvent, upsertFact, upsertStatus, listEvents, listFacts, 
listStatuses, isStoreAvailable, isEntityStoreAvailable, createEntity, findEntity, linkEntityToMemory, createRelationship, } from '../services/stores/interface.js'; -import { scrubCredentials, scrubObject } from '../services/scrub.js'; +import { scrubObject } from '../services/scrub.js'; import { extractEntities, linkExtractedEntities } from '../services/entities.js'; -import { validateMemoryInput, MAX_OBSERVED_BY } from '../middleware/validate.js'; +import { MAX_OBSERVED_BY } from '../middleware/validate.js'; import { dispatchNotification } from '../services/notifications.js'; import { isKeywordSearchAvailable, indexMemory, deactivateMemory, keywordSearch } from '../services/keyword-search.js'; import { isGraphSearchAvailable, graphSearch } from '../services/graph-search.js'; import { reciprocalRankFusion } from '../services/rrf.js'; +import { buildDedupExtraFilter, normalizeMemoryRecord } from '../services/memory-write-utils.js'; +import { getClientResolver } from '../services/client-resolver.js'; const MULTI_PATH_SEARCH = process.env.MULTI_PATH_SEARCH !== 'false'; // default: true -import { getClientResolver } from '../services/client-resolver.js'; export const memoryRouter = Router(); @@ -27,12 +28,6 @@ memoryRouter.post('/', async (req, res) => { try { let { type, content, source_agent, client_id, category, importance, knowledge_category, metadata } = req.body; - // Validate all input fields - const validationError = validateMemoryInput(req.body); - if (validationError) { - return res.status(400).json({ error: validationError }); - } - // Enforce agent identity: if authenticated with an agent key, source_agent must match if (req.authenticatedAgent && source_agent !== req.authenticatedAgent) { return res.status(403).json({ @@ -49,14 +44,27 @@ memoryRouter.post('/', async (req, res) => { } } - // Scrub credentials - const cleanContent = scrubCredentials(content); + const { normalized, contentHash, error: normalizationError } = normalizeMemoryRecord({ + 
...req.body, + type, + content, + source_agent, + client_id, + category, + importance, + knowledge_category, + metadata, + }); + + if (normalizationError) { + return res.status(400).json({ error: normalizationError }); + } - // Generate content hash for dedup - const contentHash = crypto.createHash('sha256').update(cleanContent).digest('hex').slice(0, 16); + ({ type, content, source_agent, client_id, category, importance, knowledge_category, metadata } = normalized); + const cleanContent = normalized.content; // --- Deduplication check --- - const duplicates = await findByPayload('content_hash', contentHash, { active: true }); + const duplicates = await findByPayload('content_hash', contentHash, buildDedupExtraFilter(client_id, type)); if (duplicates.length > 0) { const existing = duplicates[0]; const existingObservedBy = existing.payload.observed_by || [existing.payload.source_agent]; diff --git a/api/src/services/memory-write-utils.js b/api/src/services/memory-write-utils.js new file mode 100644 index 0000000..45cbe4d --- /dev/null +++ b/api/src/services/memory-write-utils.js @@ -0,0 +1,57 @@ +import crypto from 'crypto'; +import { scrubCredentials } from './scrub.js'; +import { validateMemoryInput } from '../middleware/validate.js'; + +export function buildDedupExtraFilter(clientId, type) { + return { + active: true, + client_id: clientId || 'global', + type, + }; +} + +export function normalizeMemoryRecord(record = {}, options = {}) { + const { + defaultType, + defaultSourceAgent, + } = options; + + const rawContent = record.content || record.text || ''; + const normalized = { + ...record, + type: record.type || defaultType, + content: typeof rawContent === 'string' ? scrubCredentials(rawContent) : rawContent, // scrubbing now runs before validation, so only scrub strings and leave other types for validateMemoryInput below + source_agent: record.source_agent || defaultSourceAgent, + client_id: record.client_id || 'global', + category: record.category || 'episodic', + importance: record.importance || 'medium', + knowledge_category: record.knowledge_category || 'general', + }; + + const validationError = 
validateMemoryInput({ + type: normalized.type, + content: normalized.content, + source_agent: normalized.source_agent, + importance: normalized.importance, + client_id: normalized.client_id, + key: normalized.key, + subject: normalized.subject, + status_value: normalized.status_value, + }); + + if (validationError) { + return { error: validationError }; + } + + return { + normalized, + contentHash: crypto.createHash('sha256').update(normalized.content).digest('hex').slice(0, 16), + }; +} + +export function normalizeImportRecord(record = {}) { + return normalizeMemoryRecord(record, { + defaultType: 'event', + defaultSourceAgent: 'import', + }); +} diff --git a/api/src/services/stores/postgres.js b/api/src/services/stores/postgres.js index ba8011f..9a87c44 100644 --- a/api/src/services/stores/postgres.js +++ b/api/src/services/stores/postgres.js @@ -241,6 +241,7 @@ export class PostgresStore { if (filters.source_agent) { sql += ` AND source_agent = $${i++}`; params.push(filters.source_agent); } if (filters.category) { sql += ` AND category = $${i++}`; params.push(filters.category); } + if (filters.client_id) { sql += ` AND client_id = $${i++}`; params.push(filters.client_id); } if (filters.subject) { sql += ` AND subject ILIKE $${i++}`; params.push(`%${filters.subject}%`); } sql += ' ORDER BY updated_at DESC LIMIT 50'; diff --git a/api/src/services/stores/sqlite.js b/api/src/services/stores/sqlite.js index b20b80b..1b5adcb 100644 --- a/api/src/services/stores/sqlite.js +++ b/api/src/services/stores/sqlite.js @@ -125,24 +125,6 @@ export class SQLiteStore { console.warn('[sqlite] idx_eml_unique creation failed:', e.message); } } - - // Entity relationships table - this.db.exec(` - CREATE TABLE IF NOT EXISTS entity_relationships ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source_entity_id INTEGER REFERENCES entities(id), - target_entity_id INTEGER REFERENCES entities(id), - relationship_type TEXT NOT NULL DEFAULT 'co_occurrence', - strength INTEGER DEFAULT 1, - 
created_at TEXT DEFAULT (datetime('now')), - updated_at TEXT DEFAULT (datetime('now')), - UNIQUE(source_entity_id, target_entity_id, relationship_type) - ); - - CREATE INDEX IF NOT EXISTS idx_er_source ON entity_relationships(source_entity_id); - CREATE INDEX IF NOT EXISTS idx_er_target ON entity_relationships(target_entity_id); - `); - // FTS5 virtual table for keyword search (BM25) try { this.db.exec(` @@ -294,6 +276,7 @@ export class SQLiteStore { if (filters.source_agent) { sql += ' AND source_agent = @source_agent'; params.source_agent = filters.source_agent; } if (filters.category) { sql += ' AND category = @category'; params.category = filters.category; } + if (filters.client_id) { sql += ' AND client_id = @client_id'; params.client_id = filters.client_id; } if (filters.subject) { sql += ' AND subject LIKE @subject'; params.subject = `%${filters.subject}%`; } sql += ' ORDER BY updated_at DESC LIMIT 50'; diff --git a/api/tests/auth.test.js b/api/tests/auth.test.js new file mode 100644 index 0000000..fb01271 --- /dev/null +++ b/api/tests/auth.test.js @@ -0,0 +1,98 @@ +import { describe, test } from 'node:test'; +import assert from 'node:assert/strict'; + +function createRes() { + return { + statusCode: 200, + body: null, + status(code) { + this.statusCode = code; + return this; + }, + json(payload) { + this.body = payload; + return this; + }, + }; +} + +async function loadAuth({ allowQuery = 'false' } = {}) { + process.env.BRAIN_API_KEY = 'admin-secret-key'; + process.env.AGENT_KEY_test_agent = 'agent-secret-key'; + process.env.ALLOW_QUERY_API_KEY = allowQuery; + const mod = await import(`../src/middleware/auth.js?case=${Date.now()}-${Math.random()}`); + return mod.authMiddleware; +} + +describe('authMiddleware', () => { + test('accepts x-api-key header for admin key', async () => { + const authMiddleware = await loadAuth({ allowQuery: 'false' }); + const req = { + headers: { 'x-api-key': 'admin-secret-key' }, + query: {}, + ip: '10.0.0.1', + socket: { 
remoteAddress: '10.0.0.1' }, + }; + const res = createRes(); + let called = false; + + authMiddleware(req, res, () => { called = true; }); + + assert.equal(called, true); + assert.equal(req.authenticatedAgent, null); + assert.equal(req.authSource, 'header'); + assert.equal(req.rateLimitKey, 'admin-secret-key'); + }); + + test('accepts bearer token auth', async () => { + const authMiddleware = await loadAuth({ allowQuery: 'false' }); + const req = { + headers: { authorization: 'Bearer admin-secret-key' }, + query: {}, + ip: '10.0.0.11', + socket: { remoteAddress: '10.0.0.11' }, + }; + const res = createRes(); + let called = false; + + authMiddleware(req, res, () => { called = true; }); + + assert.equal(called, true); + assert.equal(req.authSource, 'bearer'); + assert.equal(req.rateLimitKey, 'admin-secret-key'); + }); + + test('rejects query key when ALLOW_QUERY_API_KEY=false', async () => { + const authMiddleware = await loadAuth({ allowQuery: 'false' }); + const req = { + headers: {}, + query: { key: 'admin-secret-key' }, + ip: '10.0.0.2', + socket: { remoteAddress: '10.0.0.2' }, + }; + const res = createRes(); + + authMiddleware(req, res, () => {}); + + assert.equal(res.statusCode, 401); + assert.deepEqual(res.body, { error: 'Query-string API keys are disabled. Use x-api-key header.' 
}); + }); + + test('accepts query key when ALLOW_QUERY_API_KEY=true', async () => { + const authMiddleware = await loadAuth({ allowQuery: 'true' }); + const req = { + headers: {}, + query: { key: 'admin-secret-key' }, + ip: '10.0.0.3', + socket: { remoteAddress: '10.0.0.3' }, + }; + const res = createRes(); + let called = false; + + authMiddleware(req, res, () => { called = true; }); + + assert.equal(called, true); + assert.equal(req.authSource, 'query'); + assert.equal(req.rateLimitKey, 'admin-secret-key'); + }); +}); diff --git a/api/tests/memory-write-utils.test.js b/api/tests/memory-write-utils.test.js new file mode 100644 index 0000000..4a24e0a --- /dev/null +++ b/api/tests/memory-write-utils.test.js @@ -0,0 +1,79 @@ +import { describe, test } from 'node:test'; +import assert from 'node:assert/strict'; +import { buildDedupExtraFilter, normalizeImportRecord, normalizeMemoryRecord } from '../src/services/memory-write-utils.js'; + +describe('buildDedupExtraFilter', () => { + test('includes active + tenant scope + type', () => { + assert.deepEqual(buildDedupExtraFilter('acme', 'event'), { + active: true, + client_id: 'acme', + type: 'event', + }); + }); + + test('defaults client_id to global', () => { + assert.deepEqual(buildDedupExtraFilter(undefined, 'fact'), { + active: true, + client_id: 'global', + type: 'fact', + }); + }); +}); + + +describe('normalizeMemoryRecord', () => { + test('requires type/source_agent when no defaults are supplied', () => { + const { error } = normalizeMemoryRecord({ content: 'hello' }); + assert.match(error, /type is required/); + }); + + test('normalizes and hashes valid store payloads', () => { + const { normalized, contentHash, error } = normalizeMemoryRecord({ + type: 'event', + content: 'api_key=sk_live_abc123def456ghi789', + source_agent: 'agent_1', + }); + + assert.equal(error, undefined); + assert.equal(normalized.client_id, 'global'); + assert.equal(typeof contentHash, 'string'); + assert.equal(contentHash.length, 16); + 
assert.match(normalized.content, /\[CREDENTIAL_REDACTED\]/); + }); +}); + +describe('normalizeImportRecord', () => { + test('normalizes defaults and validates record', () => { + const { normalized, error } = normalizeImportRecord({ + content: 'Deployment completed', + source_agent: 'import_agent', + }); + + assert.equal(error, undefined); + assert.equal(normalized.type, 'event'); + assert.equal(normalized.client_id, 'global'); + assert.equal(normalized.importance, 'medium'); + assert.equal(normalized.category, 'episodic'); + }); + + test('scrubs credentials from imported content', () => { + const { normalized, error } = normalizeImportRecord({ + type: 'event', + source_agent: 'import_agent', + content: 'token=sk_live_abc123def456ghi789', + }); + + assert.equal(error, undefined); + assert.match(normalized.content, /\[CREDENTIAL_REDACTED\]/); + }); + + test('rejects invalid source_agent', () => { + const { error } = normalizeImportRecord({ + type: 'event', + source_agent: 'bad agent with spaces', + content: 'hello', + }); + + assert.match(error, /source_agent/); + }); +}); diff --git a/api/tests/ratelimit.test.js b/api/tests/ratelimit.test.js new file mode 100644 index 0000000..940cadd --- /dev/null +++ b/api/tests/ratelimit.test.js @@ -0,0 +1,52 @@ +import { describe, test } from 'node:test'; +import assert from 'node:assert/strict'; + +function createRes() { + return { + statusCode: 200, + body: null, + headers: {}, + status(code) { + this.statusCode = code; + return this; + }, + json(payload) { + this.body = payload; + return this; + }, + set(name, value) { + this.headers[name] = value; + return this; + }, + }; +} + +async function loadRateLimit() { + process.env.RATE_LIMIT_READS = '1'; + process.env.RATE_LIMIT_WRITES = '1'; + const mod = await import(`../src/middleware/ratelimit.js?case=${Date.now()}-${Math.random()}`); + return mod.rateLimitMiddleware; +} + +describe('rateLimitMiddleware', () => { + test('uses req.rateLimitKey for bucketing', async () => { + 
const rateLimitMiddleware = await loadRateLimit(); + + const req1 = { method: 'GET', path: '/memory/search', headers: {}, query: {}, rateLimitKey: 'agent-a' }; + const res1 = createRes(); + let next1 = false; + rateLimitMiddleware(req1, res1, () => { next1 = true; }); + assert.equal(next1, true); + + const req2 = { method: 'GET', path: '/memory/search', headers: {}, query: {}, rateLimitKey: 'agent-a' }; + const res2 = createRes(); + rateLimitMiddleware(req2, res2, () => {}); + assert.equal(res2.statusCode, 429); + + const req3 = { method: 'GET', path: '/memory/search', headers: {}, query: {}, rateLimitKey: 'agent-b' }; + const res3 = createRes(); + let next3 = false; + rateLimitMiddleware(req3, res3, () => { next3 = true; }); + assert.equal(next3, true); + }); +});