diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3176b4e..e804f18 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -3,12 +3,10 @@ name: Benchmark on: push: branches: - - main - - master + - "**" pull_request: branches: - - main - - master + - "**" workflow_dispatch: inputs: engines: diff --git a/bench/run.js b/bench/run.js index 9ef81da..3eeac5b 100644 --- a/bench/run.js +++ b/bench/run.js @@ -9,7 +9,7 @@ const port = Number(portArg ?? 3001); function printUsage() { console.log("Usage: bun bench/run.js "); - console.log("Engines: http-native | bun | old | xitca | monoio | zig"); + console.log("Engines: http-native | bun | xitca | monoio | zig"); console.log("Scenarios: static | dynamic | opt"); console.log(""); console.log("Example:"); @@ -30,7 +30,7 @@ function benchmarkPathForScenario(activeScenario) { } async function main() { - if (!["http-native", "bun", "old", "xitca", "monoio", "zig"].includes(engine)) { + if (!["http-native", "bun", "xitca", "monoio", "zig"].includes(engine)) { printUsage(); process.exit(1); } @@ -43,33 +43,33 @@ async function main() { const child = engine === "xitca" || engine === "monoio" ? spawn( - "cargo", - [ - "run", - "--release", - "--manifest-path", - engine === "xitca" - ? "bench/xitca-server/Cargo.toml" - : "bench/monoio-server/Cargo.toml", - "--", - scenario, - String(port), - ], + "cargo", + [ + "run", + "--release", + "--manifest-path", + engine === "xitca" + ? "bench/xitca-server/Cargo.toml" + : "bench/monoio-server/Cargo.toml", + "--", + scenario, + String(port), + ], + { + cwd: process.cwd(), + stdio: ["ignore", "pipe", "inherit"], + }, + ) + : engine === "zig" + ? spawn( + "zig", + ["build", "run", "-Doptimize=ReleaseFast", "--", scenario, String(port)], { - cwd: process.cwd(), + cwd: `${process.cwd()}/bench/zig-httpz`, stdio: ["ignore", "pipe", "inherit"], }, ) - : engine === "zig" - ? spawn( - "zig", - ["build", "run", "-Doptimize=ReleaseFast", "--", scenario, String(port)], - { - cwd: `${process.cwd()}/bench/zig-httpz`, - stdio: ["ignore", "pipe", "inherit"], - }, - ) - : spawn("bun", ["bench/target.js", engine, scenario, String(port)], { + : spawn("bun", ["bench/target.js", engine, scenario, String(port)], { cwd: process.cwd(), stdio: ["ignore", "pipe", "inherit"], }); diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..56010df --- /dev/null +++ b/examples/README.md @@ -0,0 +1,27 @@ +# http-native Examples + +Each folder contains a self-contained example you can run with `bun`: + +```bash +# Make sure to build the native module first +bun run build + +# Then run any example +bun examples/basic/server.js +bun examples/cors/server.js +bun examples/middleware/server.js +bun examples/rest-api/server.js +bun examples/validation/server.js +bun examples/error-handling/server.js +``` + +## Examples + +| Example | Description | +|---------|-------------| +| **[basic](./basic/)** | Routes, params, query strings, status codes, custom headers | +| **[cors](./cors/)** | CORS with wildcard, specific origins, dynamic origins, credentials | +| **[middleware](./middleware/)** | Logging, path-scoped auth, request IDs, `res.locals` data passing | +| **[rest-api](./rest-api/)** | Full CRUD Todo API — GET, POST, PUT, PATCH, DELETE with body parsing | +| **[validation](./validation/)** | Request body & query validation (Zod-compatible schema interface) | +| **[error-handling](./error-handling/)** | Custom error classes, global `onError()` handler, async error catching | diff --git a/examples/basic/server.js b/examples/basic/server.js new file mode 100644 index 0000000..0ea6f6e --- /dev/null +++ b/examples/basic/server.js @@ -0,0 +1,54 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// Simple GET route +app.get("/", (req, res) => { + res.json({ message: "Hello, World!", timestamp: Date.now() }); +}); + +// Route with params +app.get("/hello/:name", (req, res) => { + res.json({ greeting: `Hello, ${req.params.name}!` }); +}); + +// Multiple params +app.get("/users/:userId/posts/:postId", (req, res) => { + res.json({ + userId: req.params.userId, + postId: req.params.postId, + }); +}); + +// Query string parsing +app.get("/search", (req, res) => { + const { q, page, limit } = req.query; + res.json({ + query: q ?? "", + page: Number(page) || 1, + limit: Number(limit) || 10, + }); +}); + +// Different status codes +app.get("/not-found", (req, res) => { + res.status(404).json({ error: "Resource not found" }); +}); + +// Custom headers +app.get("/custom-headers", (req, res) => { + res.set("X-Custom-Header", "http-native") + .set("X-Request-Id", crypto.randomUUID()) + .json({ headers: "set!" }); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🚀 Basic server running at ${server.url}`); +console.log(` +Try these routes: + curl ${server.url}/ + curl ${server.url}/hello/world + curl ${server.url}/users/42/posts/7 + curl ${server.url}/search?q=http-native&page=2&limit=20 + curl -v ${server.url}/custom-headers +`); diff --git a/examples/cors/server.js b/examples/cors/server.js new file mode 100644 index 0000000..f1a200c --- /dev/null +++ b/examples/cors/server.js @@ -0,0 +1,72 @@ +import { createApp } from "http-native"; +import { cors } from "http-native/cors"; + +const app = createApp(); + +// ─── Example 1: Allow all origins ───────────────────────────────────────────── +// app.use(cors()); + +// ─── Example 2: Allow specific origin ───────────────────────────────────────── +// app.use(cors({ origin: "https://myapp.com" })); + +// ─── Example 3: Allow multiple origins ──────────────────────────────────────── +// app.use(cors({ origin: ["https://myapp.com", "https://admin.myapp.com"] })); + +// ─── Example 4: Dynamic origin with credentials ────────────────────────────── +const ALLOWED_ORIGINS = new Set([ + "http://localhost:3000", + "http://localhost:5173", + "https://myapp.com", +]); + +app.use( + cors({ + origin: (requestOrigin) => ALLOWED_ORIGINS.has(requestOrigin), + credentials: true, + allowedHeaders: ["Content-Type", "Authorization", "X-Request-Id"], + exposedHeaders: ["X-Request-Id", "X-RateLimit-Remaining"], + maxAge: 86400, // Cache preflight for 24 hours + }), +); + +// ─── Routes ─────────────────────────────────────────────────────────────────── + +app.get("/api/data", (req, res) => { + res.set("X-Request-Id", crypto.randomUUID()).json({ + data: [ + { id: 1, name: "Item 1" }, + { id: 2, name: "Item 2" }, + ], + }); +}); + +app.post("/api/data", (req, res) => { + const body = req.json(); + res.status(201).json({ created: true, item: body }); +}); + +app.options("/api/data", (req, res) => { + // CORS middleware handles this automatically via preflight + res.status(204).send(); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🌐 CORS example running at ${server.url}`); +console.log(` +Test with: + # Simple GET (no CORS headers without Origin) + curl -v ${server.url}/api/data + + # CORS preflight + curl -v -X OPTIONS \\ + -H "Origin: http://localhost:5173" \\ + -H "Access-Control-Request-Method: POST" \\ + -H "Access-Control-Request-Headers: Content-Type" \\ + ${server.url}/api/data + + # CORS GET + curl -v -H "Origin: http://localhost:5173" ${server.url}/api/data + + # Disallowed origin (no CORS headers) + curl -v -H "Origin: https://evil.com" ${server.url}/api/data +`); diff --git a/examples/error-handling/server.js b/examples/error-handling/server.js new file mode 100644 index 0000000..f219d36 --- /dev/null +++ b/examples/error-handling/server.js @@ -0,0 +1,130 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// ─── Custom Error Classes ───────────────────────────────────────────────────── + +class AppError extends Error { + constructor(message, statusCode = 500, code = "INTERNAL_ERROR") { + super(message); + this.statusCode = statusCode; + this.code = code; + } +} + +class NotFoundError extends AppError { + constructor(resource = "Resource") { + super(`${resource} not found`, 404, "NOT_FOUND"); + } +} + +class ValidationError extends AppError { + constructor(message) { + super(message, 400, "VALIDATION_ERROR"); + } +} + +class UnauthorizedError extends AppError { + constructor(message = "Authentication required") { + super(message, 401, "UNAUTHORIZED"); + } +} + +// ─── Global Error Handler ───────────────────────────────────────────────────── + +app.onError((err, req, res) => { + // Known application errors + if (err instanceof AppError) { + console.error(`[${err.code}] ${req.method} ${req.path}: ${err.message}`); + res.status(err.statusCode).json({ + error: { + code: err.code, + message: err.message, + }, + }); + return; + } + + // Unexpected errors — log full stack, return generic message + console.error(`[UNHANDLED] ${req.method} ${req.path}:`, err); + res.status(500).json({ + error: { + code: "INTERNAL_ERROR", + message: + process.env.NODE_ENV === "production" + ? "An unexpected error occurred" + : err.message, + }, + }); +}); + +// ─── Routes ─────────────────────────────────────────────────────────────────── + +app.get("/", (req, res) => { + res.json({ status: "ok" }); +}); + +// Throws a custom NotFoundError +app.get("/users/:id", (req, res) => { + const id = Number(req.params.id); + if (id !== 1) { + throw new NotFoundError("User"); + } + res.json({ id: 1, name: "Alice" }); +}); + +// Throws a ValidationError +app.post("/users", (req, res) => { + const body = req.json(); + if (!body?.name) { + throw new ValidationError("Name is required"); + } + if (body.name.length < 2) { + throw new ValidationError("Name must be at least 2 characters"); + } + res.status(201).json({ id: 2, name: body.name }); +}); + +// Throws an UnauthorizedError +app.get("/admin", (req, res) => { + const token = req.header("authorization"); + if (!token) { + throw new UnauthorizedError(); + } + res.json({ admin: true }); +}); + +// Throws an unhandled error (caught by generic handler) +app.get("/crash", (req, res) => { + throw new TypeError("Cannot read property 'foo' of undefined"); +}); + +// Async error (also caught!) +app.get("/async-crash", async (req, res) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + throw new Error("Async failure"); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🛡️ Error handling example running at ${server.url}`); +console.log(` +Try these: + # OK + curl ${server.url}/users/1 + + # 404 NotFoundError + curl ${server.url}/users/999 + + # 400 ValidationError + curl -X POST -H "Content-Type: application/json" \\ + -d '{}' ${server.url}/users + + # 401 UnauthorizedError + curl ${server.url}/admin + + # 500 Unhandled error + curl ${server.url}/crash + + # 500 Async error + curl ${server.url}/async-crash +`); diff --git a/examples/middleware/server.js b/examples/middleware/server.js new file mode 100644 index 0000000..a7eade9 --- /dev/null +++ b/examples/middleware/server.js @@ -0,0 +1,106 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// ─── Logging Middleware (global) ────────────────────────────────────────────── + +app.use(async (req, res, next) => { + const start = performance.now(); + console.log(`→ ${req.method} ${req.path}`); + + await next(); + + const duration = (performance.now() - start).toFixed(2); + console.log(`← ${req.method} ${req.path} [${duration}ms]`); +}); + +// ─── Auth Middleware (scoped to /api) ───────────────────────────────────────── + +app.use("/api", async (req, res, next) => { + const token = req.header("authorization"); + + if (!token || !token.startsWith("Bearer ")) { + res.status(401).json({ + error: "Unauthorized", + message: "Missing or invalid Bearer token", + }); + return; + } + + // Simulate token validation + const payload = token.slice(7); + if (payload === "invalid") { + res.status(403).json({ error: "Forbidden", message: "Token is invalid" }); + return; + } + + // Attach user info to response locals for downstream handlers + res.locals.user = { id: 1, name: "John Doe", token: payload }; + await next(); +}); + +// ─── Request ID Middleware (global) ──────────────────────────────────────────── + +app.use(async (req, res, next) => { + const requestId = crypto.randomUUID(); + res.set("X-Request-Id", requestId); + res.locals.requestId = requestId; + await next(); +}); + +// ─── Routes ─────────────────────────────────────────────────────────────────── + +// Public route (only logging + request ID middlewares run) +app.get("/", (req, res) => { + res.json({ message: "Public endpoint", requestId: res.locals.requestId }); +}); + +// Protected route (logging + request ID + auth middlewares run) +app.get("/api/profile", (req, res) => { + res.json({ + user: res.locals.user, + requestId: res.locals.requestId, + }); +}); + +app.get("/api/secret", (req, res) => { + res.json({ + secret: "The cake is a lie", + accessedBy: res.locals.user.name, + }); +}); + +// ─── Error Handler ──────────────────────────────────────────────────────────── + +app.onError((err, req, res) => { + console.error(`❌ Error on ${req.method} ${req.path}:`, err.message); + res.status(500).json({ + error: "Something went wrong", + requestId: res.locals.requestId, + }); +}); + +// Route that deliberately throws +app.get("/api/crash", (req, res) => { + throw new Error("Simulated server error"); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🔐 Middleware example running at ${server.url}`); +console.log(` +Try these routes: + # Public (no auth needed) + curl ${server.url}/ + + # Protected (will fail - no token) + curl ${server.url}/api/profile + + # Protected (with valid token) + curl -H "Authorization: Bearer mytoken123" ${server.url}/api/profile + + # Protected (with invalid token) + curl -H "Authorization: Bearer invalid" ${server.url}/api/profile + + # Error handler test + curl -H "Authorization: Bearer mytoken123" ${server.url}/api/crash +`); diff --git a/examples/rest-api/server.js b/examples/rest-api/server.js new file mode 100644 index 0000000..2e0a144 --- /dev/null +++ b/examples/rest-api/server.js @@ -0,0 +1,148 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// In-memory store +const todos = new Map(); +let nextId = 1; + +// Seed data +todos.set(1, { id: 1, title: "Learn http-native", completed: true }); +todos.set(2, { id: 2, title: "Build something awesome", completed: false }); +todos.set(3, { id: 3, title: "Deploy to production", completed: false }); +nextId = 4; + +// ─── Error Handler ──────────────────────────────────────────────────────────── + +app.onError((err, req, res) => { + console.error(`Error: ${err.message}`); + res.status(500).json({ error: "Internal server error" }); +}); + +// ─── List all todos ─────────────────────────────────────────────────────────── + +app.get("/todos", (req, res) => { + const { completed } = req.query; + let items = [...todos.values()]; + + if (completed === "true") { + items = items.filter((t) => t.completed); + } else if (completed === "false") { + items = items.filter((t) => !t.completed); + } + + res.json({ todos: items, count: items.length }); +}); + +// ─── Get single todo ────────────────────────────────────────────────────────── + +app.get("/todos/:id", (req, res) => { + const id = Number(req.params.id); + const todo = todos.get(id); + + if (!todo) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + res.json(todo); +}); + +// ─── Create todo ────────────────────────────────────────────────────────────── + +app.post("/todos", (req, res) => { + const body = req.json(); + + if (!body || !body.title) { + res.status(400).json({ error: "Title is required" }); + return; + } + + const todo = { + id: nextId++, + title: String(body.title), + completed: Boolean(body.completed ?? false), + }; + + todos.set(todo.id, todo); + res.status(201).json(todo); +}); + +// ─── Update todo ────────────────────────────────────────────────────────────── + +app.put("/todos/:id", (req, res) => { + const id = Number(req.params.id); + const existing = todos.get(id); + + if (!existing) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + const body = req.json(); + const updated = { + ...existing, + title: body?.title ?? existing.title, + completed: body?.completed ?? existing.completed, + }; + + todos.set(id, updated); + res.json(updated); +}); + +// ─── Delete todo ────────────────────────────────────────────────────────────── + +app.delete("/todos/:id", (req, res) => { + const id = Number(req.params.id); + + if (!todos.has(id)) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + todos.delete(id); + res.sendStatus(204); +}); + +// ─── Toggle completion ──────────────────────────────────────────────────────── + +app.patch("/todos/:id/toggle", (req, res) => { + const id = Number(req.params.id); + const todo = todos.get(id); + + if (!todo) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + todo.completed = !todo.completed; + res.json(todo); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`📝 REST API running at ${server.url}`); +console.log(` +CRUD operations: + # List all + curl ${server.url}/todos + + # Filter completed + curl "${server.url}/todos?completed=false" + + # Get one + curl ${server.url}/todos/1 + + # Create + curl -X POST -H "Content-Type: application/json" \\ + -d '{"title":"New todo"}' ${server.url}/todos + + # Update + curl -X PUT -H "Content-Type: application/json" \\ + -d '{"title":"Updated","completed":true}' ${server.url}/todos/1 + + # Toggle + curl -X PATCH ${server.url}/todos/2/toggle + + # Delete + curl -X DELETE ${server.url}/todos/3 +`); diff --git a/examples/validation/server.js b/examples/validation/server.js new file mode 100644 index 0000000..a3e25dc --- /dev/null +++ b/examples/validation/server.js @@ -0,0 +1,126 @@ +import { createApp } from "http-native"; +import { validate } from "http-native/validate"; + +const app = createApp(); + +// ─── Manual Schema (no external deps) ───────────────────────────────────────── +// +// Works with any object that has .parse() or .safeParse(). +// Below is a minimal hand-rolled schema — in production, use Zod: +// +// import { z } from "zod"; +// const CreateUserSchema = z.object({ +// name: z.string().min(1).max(100), +// email: z.string().email(), +// age: z.number().int().min(0).max(150).optional(), +// }); + +// Minimal schema with .parse() compatible API +function createSchema(validator) { + return { + parse(data) { + const errors = validator(data); + if (errors.length > 0) { + const err = new Error("Validation failed"); + err.issues = errors.map((msg) => ({ path: [], message: msg })); + throw err; + } + return data; + }, + }; +} + +const CreateUserSchema = createSchema((data) => { + const errors = []; + if (!data || typeof data !== "object") { + errors.push("Body must be an object"); + return errors; + } + if (typeof data.name !== "string" || data.name.length === 0) { + errors.push("name is required and must be a non-empty string"); + } + if (typeof data.email !== "string" || !data.email.includes("@")) { + errors.push("email is required and must be a valid email"); + } + if (data.age !== undefined && (typeof data.age !== "number" || data.age < 0)) { + errors.push("age must be a non-negative number"); + } + return errors; +}); + +const QuerySchema = createSchema((data) => { + const errors = []; + if (data.page && isNaN(Number(data.page))) { + errors.push("page must be a number"); + } + if (data.limit && isNaN(Number(data.limit))) { + errors.push("limit must be a number"); + } + return errors; +}); + +// ─── Error Handler ──────────────────────────────────────────────────────────── + +app.onError((err, req, res) => { + console.error(`Error: ${err.message}`); + res.status(500).json({ error: "Internal server error" }); +}); + +// ─── Routes with Validation ────────────────────────────────────────────────── + +// Validates body against CreateUserSchema +app.post( + "/users", + validate({ body: CreateUserSchema }), + (req, res) => { + // req.validatedBody is the parsed & validated data + const user = { + id: crypto.randomUUID(), + ...req.validatedBody, + createdAt: new Date().toISOString(), + }; + + res.status(201).json(user); + }, +); + +// Validates query params +app.get( + "/users", + validate({ query: QuerySchema }), + (req, res) => { + const page = Number(req.query.page) || 1; + const limit = Number(req.query.limit) || 10; + + res.json({ + users: [], + pagination: { page, limit, total: 0 }, + }); + }, +); + +const server = await app.listen({ port: 3000 }); +console.log(`✅ Validation example running at ${server.url}`); +console.log(` +Try these: + # Valid user + curl -X POST -H "Content-Type: application/json" \\ + -d '{"name":"Alice","email":"alice@example.com","age":30}' \\ + ${server.url}/users + + # Missing name (400 error) + curl -X POST -H "Content-Type: application/json" \\ + -d '{"email":"bob@example.com"}' \\ + ${server.url}/users + + # Invalid email (400 error) + curl -X POST -H "Content-Type: application/json" \\ + -d '{"name":"Charlie","email":"not-an-email"}' \\ + ${server.url}/users + + # Valid query + curl "${server.url}/users?page=2&limit=20" + + # Invalid query (400 error) + curl "${server.url}/users?page=abc" +`); diff --git a/package.json b/package.json index b384e73..c86de66 100644 --- a/package.json +++ b/package.json @@ -4,33 +4,33 @@ "type": "module", "private": true, "exports": { - ".": "./src/index.js", + ".": { + "types": "./src/index.d.ts", + "default": "./src/index.js" + }, + "./cors": "./src/cors.js", + "./validate": "./src/validate.js", "./http-server.config": "./src/http-server.config.js" }, "scripts": { "build": "bun scripts/build-native.mjs", "build:release": "bun scripts/build-native.mjs --release", - "build:old:release": "cargo build --release --manifest-path old/native/Cargo.toml", "test": "bun run build && bun test/test.js", - "bench:ci": "bun bench/ci.js", - "bench": "bun run build:release && cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js", + "bench": "bun run build:release && bun bench/run.js", "bench:http-native:static": "bun run build:release && bun bench/run.js http-native static 3001", "bench:bun:static": "bun bench/run.js bun static 3000", "bench:xitca:static": "bun bench/run.js xitca static 3003", "bench:monoio:static": "bun bench/run.js monoio static 3004", "bench:zig:static": "bun bench/run.js zig static 3005", - "bench:old:static": "cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js old static 3002", "bench:http-native:dynamic": "bun run build:release && bun bench/run.js http-native dynamic 3011", "bench:bun:dynamic": "bun bench/run.js bun dynamic 3010", "bench:xitca:dynamic": "bun bench/run.js xitca dynamic 3013", "bench:monoio:dynamic": "bun bench/run.js monoio dynamic 3014", "bench:zig:dynamic": "bun bench/run.js zig dynamic 3015", - "bench:old:dynamic": "cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js old dynamic 3012", "bench:http-native:opt": "bun run build:release && bun bench/run.js http-native opt 3021", "bench:bun:opt": "bun bench/run.js bun opt 3020", "bench:xitca:opt": "bun bench/run.js xitca opt 3023", "bench:monoio:opt": "bun bench/run.js monoio opt 3024", - "bench:zig:opt": "bun bench/run.js zig opt 3025", - "bench:old:opt": "cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js old opt 3022" + "bench:zig:opt": "bun bench/run.js zig opt 3025" } -} +} \ No newline at end of file diff --git a/rust-native/Cargo.lock b/rust-native/Cargo.lock index 56e7a75..4148b55 100644 --- a/rust-native/Cargo.lock +++ b/rust-native/Cargo.lock @@ -310,6 +310,8 @@ dependencies = [ "anyhow", "base64", "bytes", + "httparse", + "itoa", "json5", "memchr", "monoio", @@ -322,6 +324,12 @@ dependencies = [ "url", ] +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "icu_collections" version = "2.1.1" diff --git a/rust-native/Cargo.toml b/rust-native/Cargo.toml index 49a8cc2..9ff68a4 100644 --- a/rust-native/Cargo.toml +++ b/rust-native/Cargo.toml @@ -10,6 +10,8 @@ crate-type = ["cdylib"] anyhow = "1.0" base64 = "0.22" bytes = "1.10" +httparse = "1.9" +itoa = "1.0" json5 = "0.4" memchr = "2.7" monoio = { version = "0.2", features = ["sync"] } @@ -28,3 +30,4 @@ lto = "fat" codegen-units = 1 panic = "abort" strip = "symbols" +opt-level = 3 diff --git a/rust-native/src/lib.rs b/rust-native/src/lib.rs index 0798860..9d6eebf 100644 --- a/rust-native/src/lib.rs +++ b/rust-native/src/lib.rs @@ -4,7 +4,7 @@ mod router; use anyhow::{anyhow, Context, Result}; use bytes::Bytes; -use memchr::{memchr, memmem}; +use memchr::memmem; use monoio::io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt}; use monoio::net::{ListenerOpts, TcpListener, TcpStream}; use napi::bindgen_prelude::{Buffer, Function, Promise}; @@ -12,6 +12,7 @@ use napi::threadsafe_function::ThreadsafeFunction; use napi::{Error, Status}; use napi_derive::napi; use std::borrow::Cow; +use std::cell::RefCell; use std::net::{SocketAddr, ToSocketAddrs}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{mpsc, Arc, Mutex}; @@ -19,6 +20,8 @@ use std::sync::{mpsc, Arc, Mutex}; use crate::manifest::{HttpServerConfigInput, ManifestInput}; use crate::router::{ExactStaticRoute, MatchedRoute, Router}; +// ─── Constants ──────────────────────────────────────────────────────────────── + const FALLBACK_DEFAULT_HOST: &str = "127.0.0.1"; const FALLBACK_DEFAULT_BACKLOG: i32 = 2048; const FALLBACK_MAX_HEADER_BYTES: usize = 16 * 1024; @@ -29,10 +32,58 @@ const FALLBACK_HEADER_CONTENT_LENGTH_PREFIX: &str = "content-length:"; const FALLBACK_HEADER_TRANSFER_ENCODING_PREFIX: &str = "transfer-encoding:"; const BRIDGE_VERSION: u8 = 1; const REQUEST_FLAG_QUERY_PRESENT: u16 = 1 << 0; +const REQUEST_FLAG_BODY_PRESENT: u16 = 1 << 1; const NOT_FOUND_BODY: &[u8] = br#"{"error":"Route not found"}"#; +/// Security: Maximum number of headers we allow per request +const MAX_HEADER_COUNT: usize = 64; +/// Security: Maximum URL length to prevent abuse +const MAX_URL_LENGTH: usize = 8192; +/// Security: Maximum single header value length +const MAX_HEADER_VALUE_LENGTH: usize = 8192; +/// Security: Maximum request body size (1 MB) +const MAX_BODY_BYTES: usize = 1024 * 1024; + +/// Buffer pool: initial capacity for connection read buffers +const BUFFER_INITIAL_CAPACITY: usize = 8192; +/// Buffer pool: max buffers held per thread +const BUFFER_POOL_MAX_SIZE: usize = 256; +/// Buffer pool: max buffer size to recycle (don't recycle oversized buffers) +const BUFFER_POOL_MAX_RECYCLE_SIZE: usize = 65536; + type DispatchTsfn = ThreadsafeFunction, Buffer, Status, false, false, 0>; +// ─── Thread-Local Buffer Pool ───────────────────────────────────────────────── +// +// Eliminates per-connection Vec allocations by recycling buffers. + +thread_local! { + static BUFFER_POOL: RefCell>> = RefCell::new(Vec::with_capacity(BUFFER_POOL_MAX_SIZE)); +} + +fn acquire_buffer() -> Vec { + BUFFER_POOL.with(|pool| { + pool.borrow_mut() + .pop() + .unwrap_or_else(|| Vec::with_capacity(BUFFER_INITIAL_CAPACITY)) + }) +} + +fn release_buffer(mut buf: Vec) { + if buf.capacity() > BUFFER_POOL_MAX_RECYCLE_SIZE { + return; // Don't recycle oversized buffers + } + buf.clear(); + BUFFER_POOL.with(|pool| { + let mut pool = pool.borrow_mut(); + if pool.len() < BUFFER_POOL_MAX_SIZE { + pool.push(buf); + } + }); +} + +// ─── Server Configuration ───────────────────────────────────────────────────── + #[derive(Clone)] struct HttpServerConfig { default_host: String, @@ -109,6 +160,8 @@ impl HttpServerConfig { } } +// ─── NAPI Interface ─────────────────────────────────────────────────────────── + #[napi(object)] pub struct NativeListenOptions { pub host: Option, @@ -305,6 +358,8 @@ fn worker_count_for(options: &NativeListenOptions) -> usize { .unwrap_or(1) } +// ─── JS Dispatcher ──────────────────────────────────────────────────────────── + struct JsDispatcher { callback: DispatchTsfn, } @@ -323,6 +378,8 @@ impl JsDispatcher { } } +// ─── Server Loop ────────────────────────────────────────────────────────────── + async fn run_server( listener: TcpListener, router: Arc, @@ -341,10 +398,8 @@ async fn run_server( break; } - if should_enable_nodelay() { - if let Err(error) = stream.set_nodelay(true) { - eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); - } + if let Err(error) = stream.set_nodelay(true) { + eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); } let router = Arc::clone(&router); @@ -372,90 +427,214 @@ async fn run_server( Ok(()) } +// ─── Parsed Request (from httparse) ─────────────────────────────────────────── + +struct ParsedRequest<'a> { + method: &'a [u8], + target: &'a [u8], + path: &'a [u8], + keep_alive: bool, + header_bytes: usize, + has_body: bool, + content_length: Option, + /// Pre-parsed header pairs — stored once, used by both routing and bridge + headers: Vec<(&'a str, &'a str)>, +} + +// ─── Connection Handler with Buffer Pool ────────────────────────────────────── + async fn handle_connection( mut stream: TcpStream, router: Arc, dispatcher: Arc, server_config: Arc, ) -> Result<()> { - let mut buffer: Vec = Vec::with_capacity(8192); - let mut buffer_start = 0usize; + let mut buffer = acquire_buffer(); + + let result = handle_connection_inner( + &mut stream, + &mut buffer, + &router, + &dispatcher, + &server_config, + ) + .await; + release_buffer(buffer); + result +} + +async fn handle_connection_inner( + stream: &mut TcpStream, + buffer: &mut Vec, + router: &Router, + dispatcher: &JsDispatcher, + server_config: &HttpServerConfig, +) -> Result<()> { loop { - let request_head = loop { - let readable = &buffer[buffer_start..]; - let request_head = if router.exact_get_root().is_some() { - parse_hot_root_request_head(readable, server_config.as_ref()) - .or_else(|| parse_request_head(readable)) + // Try hot-path parsing first (GET / with known prefix) + let parsed = loop { + let result = if router.exact_get_root().is_some() { + parse_hot_root_request(buffer, server_config) + .or_else(|| parse_request_httparse(buffer)) } else { - parse_request_head(readable) + parse_request_httparse(buffer) }; - if let Some(request_head) = request_head { - break request_head; + if let Some(parsed) = result { + break parsed; } - if find_header_end(readable).is_some() { + if find_header_end(buffer).is_some() { + // Headers complete but couldn't parse — malformed request stream.shutdown().await?; return Ok(()); } - compact_read_buffer(&mut buffer, &mut buffer_start); - let (read_result, next_buffer) = stream.read(buffer).await; - buffer = next_buffer; + // SAFETY: We take ownership of the buffer, read into it, then put it back + let owned_buf = std::mem::take(buffer); + let (read_result, next_buffer) = stream.read(owned_buf).await; + *buffer = next_buffer; let bytes_read = read_result?; if bytes_read == 0 { return Ok(()); } - if buffer.len().saturating_sub(buffer_start) > server_config.max_header_bytes { + if buffer.len() > server_config.max_header_bytes { + // Security: Request header too large + let response = build_error_response_bytes( + 431, + b"{\"error\":\"Request Header Fields Too Large\"}", + false, + ); + let (write_result, _) = stream.write_all(response).await; + write_result?; stream.shutdown().await?; return Ok(()); } }; - if request_head.has_body { - stream.shutdown().await?; - return Ok(()); + let header_bytes = parsed.header_bytes; + let keep_alive = parsed.keep_alive; + let has_body = parsed.has_body; + let content_length = parsed.content_length; + + // Extract owned copies from parsed (which borrows buffer) before we mutate buffer + let method_owned: Vec = parsed.method.to_vec(); + let target_owned: Vec = parsed.target.to_vec(); + let path_owned: Vec = parsed.path.to_vec(); + let headers_owned: Vec<(String, String)> = parsed + .headers + .iter() + .map(|(n, v)| (n.to_string(), v.to_string())) + .collect(); + + drop(parsed); + + // ── Fast path: static routes (GET /) ── + if !has_body && method_owned == b"GET" { + if path_owned == b"/" { + if let Some(static_route) = router.exact_get_root() { + drain_consumed_bytes(buffer, header_bytes); + write_exact_static_response(stream, static_route, keep_alive).await?; + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } + continue; + } + } + if let Some(static_route) = router.exact_static_route(&method_owned, &path_owned) { + drain_consumed_bytes(buffer, header_bytes); + write_exact_static_response(stream, static_route, keep_alive).await?; + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } + continue; + } } - let header_bytes = request_head.header_bytes; - let keep_alive = request_head.keep_alive; - - if let Some(static_route) = - resolve_static_fast_path(&router, &request_head, server_config.as_ref()) - { - consume_read_buffer(&mut buffer, &mut buffer_start, header_bytes); - write_exact_static_response(&mut stream, static_route, keep_alive).await?; + // ── Read request body if present ────────────────────────────── + let body_bytes: Vec = if has_body { + let content_length = match content_length { + Some(len) => len, + None => { + // Chunked or unknown body length — reject for now + let response = + build_error_response_bytes(411, b"{\"error\":\"Length Required\"}", false); + let (write_result, _) = stream.write_all(response).await; + write_result?; + stream.shutdown().await?; + return Ok(()); + } + }; - if !keep_alive { + // Security: enforce max body size + if content_length > MAX_BODY_BYTES { + let response = + build_error_response_bytes(413, b"{\"error\":\"Payload Too Large\"}", false); + let (write_result, _) = stream.write_all(response).await; + write_result?; stream.shutdown().await?; return Ok(()); } - continue; - } + // Some body bytes may already be in the buffer after the headers + let already_in_buffer = if buffer.len() > header_bytes { + buffer.len() - header_bytes + } else { + 0 + }; + + if already_in_buffer >= content_length { + // Entire body is already in the buffer + let body = buffer[header_bytes..header_bytes + content_length].to_vec(); + drain_consumed_bytes(buffer, header_bytes + content_length); + body + } else { + // Need to read more bytes from the stream + let mut body = Vec::with_capacity(content_length); + if already_in_buffer > 0 { + body.extend_from_slice(&buffer[header_bytes..]); + } + drain_consumed_bytes(buffer, buffer.len()); + + while body.len() < content_length { + let remaining = content_length - body.len(); + let chunk_buf = vec![0u8; remaining.min(65536)]; + let (read_result, returned_buf) = stream.read(chunk_buf).await; + let bytes_read = read_result?; + if bytes_read == 0 { + return Ok(()); // Connection closed mid-body + } + body.extend_from_slice(&returned_buf[..bytes_read]); + } + body.truncate(content_length); + body + } + } else { + drain_consumed_bytes(buffer, header_bytes); + Vec::new() + }; - let dispatch_request = build_manual_dispatch_request( - &router, - &buffer[buffer_start..buffer_start + header_bytes], - &request_head, + // Dynamic path: build bridge envelope and dispatch to JS + let dispatch_request = build_dispatch_request_owned( + router, + &method_owned, + &target_owned, + &path_owned, + &headers_owned, + &body_bytes, )?; - consume_read_buffer(&mut buffer, &mut buffer_start, header_bytes); match dispatch_request { Some(request) => { - write_dynamic_dispatch_response( - &mut stream, - dispatcher.as_ref(), - request, - keep_alive, - ) - .await?; + write_dynamic_dispatch_response(stream, dispatcher, request, keep_alive).await?; } None => { - write_not_found_response(&mut stream, keep_alive).await?; + write_not_found_response(stream, keep_alive).await?; } } @@ -466,185 +645,245 @@ async fn handle_connection( } } -fn resolve_static_fast_path<'a>( - router: &'a Router, - request_head: &RequestHead<'_>, - server_config: &HttpServerConfig, -) -> Option<&'a ExactStaticRoute> { - if request_head.path == b"/" - && request_head.method == b"GET" - && parse_hot_root_request_head_prefix(request_head, server_config) - { - return router.exact_get_root(); - } +// ─── httparse-based Request Parsing ─────────────────────────────────────────── +// +// Uses the battle-tested `httparse` crate for RFC-compliant zero-copy parsing. +// Single-pass: parses headers once and stores them for reuse by both the +// router and the bridge envelope builder. - router.exact_static_route(request_head.method, request_head.path) -} +fn parse_request_httparse(bytes: &[u8]) -> Option> { + let mut raw_headers = [httparse::EMPTY_HEADER; MAX_HEADER_COUNT]; + let mut req = httparse::Request::new(&mut raw_headers); -fn parse_hot_root_request_head_prefix( - request_head: &RequestHead<'_>, - _server_config: &HttpServerConfig, -) -> bool { - request_head.method == b"GET" && request_head.path == b"/" -} + let header_len = match req.parse(bytes) { + Ok(httparse::Status::Complete(len)) => len, + Ok(httparse::Status::Partial) => return None, + Err(_) => return None, // Malformed — caller will handle + }; -fn compact_read_buffer(buffer: &mut Vec, buffer_start: &mut usize) { - if *buffer_start == 0 { - return; - } + let method = req.method?.as_bytes(); + let target = req.path?.as_bytes(); + let version = req.version?; - if *buffer_start >= buffer.len() { - buffer.clear(); - *buffer_start = 0; - return; + // Security: enforce URL length limit + if target.len() > MAX_URL_LENGTH { + return None; } - if *buffer_start < 4096 && buffer.len() < buffer.capacity() { - return; - } + // Extract path (before '?') + let path = target.split(|b| *b == b'?').next()?; - let remaining = buffer.len() - *buffer_start; - buffer.copy_within(*buffer_start.., 0); - buffer.truncate(remaining); - *buffer_start = 0; -} + let mut keep_alive = version >= 1; // HTTP/1.1+ defaults to keep-alive + let mut has_body = false; + let mut content_length: Option = None; + let mut headers = Vec::with_capacity(req.headers.len()); -fn consume_read_buffer(buffer: &mut Vec, buffer_start: &mut usize, consumed: usize) { - *buffer_start = (*buffer_start).saturating_add(consumed); + for header in req.headers.iter() { + if header.name.is_empty() { + break; + } - if *buffer_start >= buffer.len() { - buffer.clear(); - *buffer_start = 0; - } -} + // Security: enforce header value length + if header.value.len() > MAX_HEADER_VALUE_LENGTH { + return None; + } -fn bind_listener( - options: &NativeListenOptions, - server_config: &HttpServerConfig, -) -> Result { - let host = options - .host - .as_deref() - .unwrap_or(server_config.default_host.as_str()); - let bind_addr = resolve_socket_addr(host, options.port) - .with_context(|| format!("failed to resolve bind address {host}:{}", options.port))?; - let mut listener_opts = ListenerOpts::new() - .reuse_addr(true) - .backlog(options.backlog.unwrap_or(server_config.default_backlog)); - if worker_count_for(options) > 1 { - listener_opts = listener_opts.reuse_port(true); - } + let name = header.name; // httparse gives us &str + let value = match std::str::from_utf8(header.value) { + Ok(v) => v, + Err(_) => continue, // Skip non-UTF-8 headers + }; - TcpListener::bind_with_config(bind_addr, &listener_opts) - .with_context(|| format!("failed to bind TCP listener on {bind_addr}")) -} + // Connection handling + if name.eq_ignore_ascii_case("connection") { + let lower = value.to_ascii_lowercase(); + if lower.contains("close") { + keep_alive = false; + } + if lower.contains("keep-alive") { + keep_alive = true; + } + } -fn should_enable_nodelay() -> bool { - std::env::var("HTTP_NATIVE_TCP_NODELAY") - .ok() - .map(|value| { - let normalized = value.trim().to_ascii_lowercase(); - !matches!(normalized.as_str(), "0" | "false" | "off" | "no") - }) - .unwrap_or(true) -} + // Body detection + if name.eq_ignore_ascii_case("content-length") { + let trimmed = value.trim(); + if let Ok(len) = trimmed.parse::() { + content_length = Some(len); + if len > 0 { + has_body = true; + } + } + } -fn resolve_socket_addr(host: &str, port: u16) -> Result { - (host, port) - .to_socket_addrs()? - .next() - .ok_or_else(|| anyhow!("unable to resolve {host}:{port}")) -} + if name.eq_ignore_ascii_case("transfer-encoding") { + let trimmed = value.trim(); + if !trimmed.is_empty() && !trimmed.eq_ignore_ascii_case("identity") { + has_body = true; + } + } -fn validate_manifest(manifest: &ManifestInput) -> Result<()> { - if manifest.version != 1 { - return Err(anyhow!("Unsupported manifest version {}", manifest.version)); + headers.push((name, value)); } - Ok(()) + Some(ParsedRequest { + method, + target, + path, + keep_alive, + header_bytes: header_len, + has_body, + content_length, + headers, + }) } -async fn write_exact_static_response( - stream: &mut TcpStream, - static_route: &ExactStaticRoute, - keep_alive: bool, -) -> Result<()> { - let response = if keep_alive { - static_route.keep_alive_response.clone() +// ─── Hot Root Path (GET /) ──────────────────────────────────────────────────── +// +// Ultra-fast path for the most common benchmark case. Falls back to httparse +// if the request doesn't exactly match the expected prefix. + +fn parse_hot_root_request( + bytes: &[u8], + server_config: &HttpServerConfig, +) -> Option> { + let (_, keep_alive) = if bytes.starts_with(server_config.hot_get_root_http11.as_slice()) { + (server_config.hot_get_root_http11.len(), true) + } else if bytes.starts_with(server_config.hot_get_root_http10.as_slice()) { + (server_config.hot_get_root_http10.len(), false) } else { - static_route.close_response.clone() + return None; }; - let (write_result, _) = stream.write_all(response).await; - write_result?; - Ok(()) -} + let header_end = find_header_end(bytes)?; + let mut keep_alive = keep_alive; + let mut has_body = false; + let mut line_start = bytes.iter().position(|b| *b == b'\n')? + 1; -#[derive(Clone)] -struct DispatchResponseEnvelope { - status: u16, - headers: Vec<(String, String)>, - body: Bytes, -} + while line_start + 2 <= header_end { + let next_end = memmem::find(&bytes[line_start..header_end + 2], b"\r\n")? + line_start; -fn method_code_from_bytes(method: &[u8]) -> Option { - match method { - b"GET" => Some(1), - b"POST" => Some(2), - b"PUT" => Some(3), - b"DELETE" => Some(4), - b"PATCH" => Some(5), - b"OPTIONS" => Some(6), - b"HEAD" => Some(7), - _ => None, + if next_end == line_start { + break; + } + + let line = &bytes[line_start..next_end]; + if line.len() >= server_config.header_connection_prefix.len() + && line[..server_config.header_connection_prefix.len()] + .eq_ignore_ascii_case(server_config.header_connection_prefix.as_slice()) + { + let value = &line[server_config.header_connection_prefix.len()..]; + if contains_ascii_case_insensitive(value, b"close") { + keep_alive = false; + } + if contains_ascii_case_insensitive(value, b"keep-alive") { + keep_alive = true; + } + } else if line.len() >= server_config.header_content_length_prefix.len() + && line[..server_config.header_content_length_prefix.len()] + .eq_ignore_ascii_case(server_config.header_content_length_prefix.as_slice()) + { + let value = + trim_ascii_spaces(&line[server_config.header_content_length_prefix.len()..]); + if value != b"0" { + has_body = true; + } + } else if line.len() >= server_config.header_transfer_encoding_prefix.len() + && line[..server_config.header_transfer_encoding_prefix.len()] + .eq_ignore_ascii_case(server_config.header_transfer_encoding_prefix.as_slice()) + { + let value = + trim_ascii_spaces(&line[server_config.header_transfer_encoding_prefix.len()..]); + if !value.is_empty() && !value.eq_ignore_ascii_case(b"identity") { + has_body = true; + } + } + + line_start = next_end + 2; } + + Some(ParsedRequest { + method: b"GET", + target: b"/", + path: b"/", + keep_alive, + header_bytes: header_end + 4, + has_body, + content_length: None, + headers: Vec::new(), // Hot path: no headers needed for static response + }) } -fn build_manual_dispatch_request( +// ─── Routing ────────────────────────────────────────────────────────────────── + +// ─── Bridge Envelope Building (Single-Pass Headers) ─────────────────────────── +// +// Uses the pre-parsed headers from httparse — no second scan of the raw bytes. + +fn build_dispatch_request_owned( router: &Router, - request_bytes: &[u8], - request_head: &RequestHead<'_>, + method: &[u8], + target: &[u8], + path: &[u8], + headers: &[(String, String)], + body: &[u8], ) -> Result> { - let Some(method_code) = method_code_from_bytes(request_head.method) else { + let Some(method_code) = method_code_from_bytes(method) else { return Ok(None); }; - let path = match std::str::from_utf8(request_head.path) { - Ok(path) => path, + let path_str = match std::str::from_utf8(path) { + Ok(path_str) => path_str, Err(_) => return Ok(None), }; - let url = match std::str::from_utf8(request_head.target) { - Ok(url) => url, + let url_str = match std::str::from_utf8(target) { + Ok(url_str) => url_str, Err(_) => return Ok(None), }; - let normalized_path = normalize_runtime_path(path); + + // Security: strict path validation + let normalized_path = normalize_runtime_path(path_str); + if contains_path_traversal(&normalized_path) { + return Ok(None); + } + let Some(matched_route) = router.match_route(method_code, normalized_path.as_ref()) else { return Ok(None); }; - let header_entries = if matched_route.full_headers || !matched_route.header_keys.is_empty() { - parse_request_header_pairs(request_bytes)? - } else { - Vec::new() - }; - build_dispatch_request_from_pairs(&matched_route, method_code, path, url, &header_entries) - .map(Some) + let header_refs: Vec<(&str, &str)> = headers + .iter() + .map(|(n, v)| (n.as_str(), v.as_str())) + .collect(); + + build_dispatch_envelope( + &matched_route, + method_code, + path_str, + url_str, + &header_refs, + body, + ) + .map(Some) } -fn build_dispatch_request_from_pairs( +fn build_dispatch_envelope( matched_route: &MatchedRoute<'_, '_>, method_code: u8, path: &str, url: &str, header_entries: &[(&str, &str)], + body: &[u8], ) -> Result { let url_bytes = url.as_bytes(); let path_bytes = path.as_bytes(); - let flags = if url.contains('?') { - REQUEST_FLAG_QUERY_PRESENT - } else { - 0 - }; + let mut flags: u16 = 0; + if url.contains('?') { + flags |= REQUEST_FLAG_QUERY_PRESENT; + } + if !body.is_empty() { + flags |= REQUEST_FLAG_BODY_PRESENT; + } if url_bytes.len() > u32::MAX as usize { return Err(anyhow!("request url too large")); @@ -660,8 +899,9 @@ fn build_dispatch_request_from_pairs( return Err(anyhow!("too many headers")); } - let mut frame = - Vec::with_capacity(16 + url_bytes.len() + path_bytes.len() + selected_headers.len() * 16); + let mut frame = Vec::with_capacity( + 20 + url_bytes.len() + path_bytes.len() + selected_headers.len() * 16 + body.len(), + ); frame.push(BRIDGE_VERSION); frame.push(method_code); push_u16(&mut frame, flags); @@ -670,6 +910,7 @@ fn build_dispatch_request_from_pairs( push_u16(&mut frame, path_bytes.len() as u16); push_u16(&mut frame, matched_route.param_values.len() as u16); push_u16(&mut frame, selected_headers.len() as u16); + push_u32(&mut frame, body.len() as u32); // NEW: body length frame.extend_from_slice(url_bytes); frame.extend_from_slice(path_bytes); @@ -681,6 +922,8 @@ fn build_dispatch_request_from_pairs( push_string_pair(&mut frame, name, value)?; } + frame.extend_from_slice(body); // NEW: body bytes at end + Ok(Buffer::from(frame)) } @@ -710,59 +953,29 @@ fn select_header_entries<'a>( selected } -fn parse_dispatch_response(bytes: &[u8]) -> Result { - let mut offset = 0; - let status = read_u16(bytes, &mut offset)?; - let header_count = read_u16(bytes, &mut offset)? as usize; - let body_length = read_u32(bytes, &mut offset)? as usize; - - let mut headers = Vec::with_capacity(header_count); - for _ in 0..header_count { - let name_length = read_u8(bytes, &mut offset)? as usize; - let value_length = read_u16(bytes, &mut offset)? as usize; - let name = read_utf8(bytes, &mut offset, name_length)?; - let value = read_utf8(bytes, &mut offset, value_length)?; - headers.push((name, value)); - } +// ─── Response Writing ───────────────────────────────────────────────────────── - if offset + body_length > bytes.len() { - return Err(anyhow!("response body truncated")); - } +async fn write_exact_static_response( + stream: &mut TcpStream, + static_route: &ExactStaticRoute, + keep_alive: bool, +) -> Result<()> { + let response = if keep_alive { + static_route.keep_alive_response.clone() + } else { + static_route.close_response.clone() + }; - let body = Bytes::copy_from_slice(&bytes[offset..offset + body_length]); - Ok(DispatchResponseEnvelope { - status, - headers, - body, - }) + let (write_result, _) = stream.write_all(response).await; + write_result?; + Ok(()) } -fn parse_request_header_pairs(bytes: &[u8]) -> Result> { - let header_end = find_header_end(bytes).ok_or_else(|| anyhow!("request header incomplete"))?; - let line_end = - memmem::find(bytes, b"\r\n").ok_or_else(|| anyhow!("request line incomplete"))?; - let mut line_start = line_end + 2; - let mut headers = Vec::new(); - - while line_start + 2 <= header_end { - let next_end = memmem::find(&bytes[line_start..header_end + 2], b"\r\n") - .ok_or_else(|| anyhow!("invalid header line"))? - + line_start; - - if next_end == line_start { - break; - } - - let line = &bytes[line_start..next_end]; - let separator = memchr(b':', line).ok_or_else(|| anyhow!("invalid header separator"))?; - let name = std::str::from_utf8(&line[..separator]).context("header name was not utf-8")?; - let value = std::str::from_utf8(trim_ascii_spaces(&line[separator + 1..])) - .context("header value was not utf-8")?; - headers.push((name, value)); - line_start = next_end + 2; - } - - Ok(headers) +#[derive(Clone)] +struct DispatchResponseEnvelope { + status: u16, + headers: Vec<(String, String)>, + body: Bytes, } async fn write_dynamic_dispatch_response( @@ -774,28 +987,24 @@ async fn write_dynamic_dispatch_response( let parsed = match dispatcher.dispatch(request).await { Ok(response) => match parse_dispatch_response(response.as_ref()) { Ok(parsed) => parsed, - Err(error) => DispatchResponseEnvelope { + Err(_) => DispatchResponseEnvelope { status: 500, headers: vec![( "content-type".to_string(), "application/json; charset=utf-8".to_string(), )], - body: Bytes::from(format!( - r#"{{"error":"Invalid response envelope","detail":"{}"}}"#, - escape_json(error.to_string().as_str()) - )), + // Security: sanitized error — no internal details + body: Bytes::from_static(b"{\"error\":\"Internal Server Error\"}"), }, }, - Err(error) => DispatchResponseEnvelope { + Err(_) => DispatchResponseEnvelope { status: 502, headers: vec![( "content-type".to_string(), "application/json; charset=utf-8".to_string(), )], - body: Bytes::from(format!( - r#"{{"error":"Dispatch failed","detail":"{}"}}"#, - escape_json(error.to_string().as_str()) - )), + // Security: sanitized error — no internal details + body: Bytes::from_static(b"{\"error\":\"Bad Gateway\"}"), }, }; @@ -820,6 +1029,19 @@ async fn write_not_found_response(stream: &mut TcpStream, keep_alive: bool) -> R Ok(()) } +/// Build a simple error response without going through the JS bridge +fn build_error_response_bytes(status: u16, body: &[u8], keep_alive: bool) -> Vec { + build_response_bytes( + status, + &[( + "content-type".to_string(), + "application/json; charset=utf-8".to_string(), + )], + Bytes::copy_from_slice(body), + keep_alive, + ) +} + fn build_dispatch_response_bytes(response: DispatchResponseEnvelope, keep_alive: bool) -> Vec { build_response_bytes( response.status, @@ -829,26 +1051,68 @@ fn build_dispatch_response_bytes(response: DispatchResponseEnvelope, keep_alive: ) } +/// Optimized response builder: pre-calculates size and writes in a single pass fn build_response_bytes( status: u16, headers: &[(String, String)], body: Bytes, keep_alive: bool, ) -> Vec { - let mut output = format!( - "HTTP/1.1 {} {}\r\ncontent-length: {}\r\nconnection: {}\r\n", - status, - status_reason(status), - body.len(), - if keep_alive { "keep-alive" } else { "close" } - ) - .into_bytes(); + let reason = status_reason(status); + let connection = if keep_alive { "keep-alive" } else { "close" }; + let body_len = body.len(); + + // Pre-calculate total size to avoid reallocations + // "HTTP/1.1 " + status(3) + " " + reason + "\r\n" + "content-length: " + digits + "\r\n" + "connection: " + conn + "\r\n" + let mut total_size = + 9 + 3 + 1 + reason.len() + 2 + 16 + count_digits(body_len) + 2 + 12 + connection.len() + 2; for (name, value) in headers { if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { continue; } + // Security: skip headers with CRLF injection + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } + total_size += name.len() + 2 + value.len() + 2; + } + + total_size += 2 + body_len; // final \r\n + body + let mut output = Vec::with_capacity(total_size); + + // Status line + output.extend_from_slice(b"HTTP/1.1 "); + write_u16(&mut output, status); + output.push(b' '); + output.extend_from_slice(reason.as_bytes()); + output.extend_from_slice(b"\r\n"); + + // Mandatory headers + output.extend_from_slice(b"content-length: "); + write_usize(&mut output, body_len); + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(b"connection: "); + output.extend_from_slice(connection.as_bytes()); + output.extend_from_slice(b"\r\n"); + + // User headers + for (name, value) in headers { + if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { + continue; + } + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } output.extend_from_slice(name.as_bytes()); output.extend_from_slice(b": "); output.extend_from_slice(value.as_bytes()); @@ -860,6 +1124,177 @@ fn build_response_bytes( output } +// ─── Response Parsing (from JS bridge) ──────────────────────────────────────── + +fn parse_dispatch_response(bytes: &[u8]) -> Result { + let mut offset = 0; + let status = read_u16(bytes, &mut offset)?; + let header_count = read_u16(bytes, &mut offset)? as usize; + let body_length = read_u32(bytes, &mut offset)? as usize; + + let mut headers = Vec::with_capacity(header_count); + for _ in 0..header_count { + let name_length = read_u8(bytes, &mut offset)? as usize; + let value_length = read_u16(bytes, &mut offset)? as usize; + let name = read_utf8(bytes, &mut offset, name_length)?; + let value = read_utf8(bytes, &mut offset, value_length)?; + headers.push((name, value)); + } + + if offset + body_length > bytes.len() { + return Err(anyhow!("response body truncated")); + } + + let body = Bytes::copy_from_slice(&bytes[offset..offset + body_length]); + Ok(DispatchResponseEnvelope { + status, + headers, + body, + }) +} + +// ─── Security Utilities ─────────────────────────────────────────────────────── + +/// Check for path traversal attempts (../, ..\, etc.) +fn contains_path_traversal(path: &str) -> bool { + // Decode percent-encoded dots + let decoded = path.replace("%2e", ".").replace("%2E", "."); + + // Check for traversal patterns + decoded.contains("/../") + || decoded.contains("\\..\\") + || decoded.ends_with("/..") + || decoded.ends_with("\\..") + || decoded.starts_with("../") + || decoded.starts_with("..\\") + || decoded == ".." +} + +/// RFC 8259 compliant JSON string escaping — handles ALL control characters +#[allow(dead_code)] +pub(crate) fn escape_json(value: &str) -> String { + let mut output = String::with_capacity(value.len() + 8); + for ch in value.chars() { + match ch { + '"' => output.push_str("\\\""), + '\\' => output.push_str("\\\\"), + '\n' => output.push_str("\\n"), + '\r' => output.push_str("\\r"), + '\t' => output.push_str("\\t"), + '\x08' => output.push_str("\\b"), + '\x0C' => output.push_str("\\f"), + c if c.is_control() => { + output.push_str(&format!("\\u{:04x}", c as u32)); + } + c => output.push(c), + } + } + output +} + +// ─── Helpers ────────────────────────────────────────────────────────────────── + +fn method_code_from_bytes(method: &[u8]) -> Option { + match method { + b"GET" => Some(1), + b"POST" => Some(2), + b"PUT" => Some(3), + b"DELETE" => Some(4), + b"PATCH" => Some(5), + b"OPTIONS" => Some(6), + b"HEAD" => Some(7), + _ => None, + } +} + +fn drain_consumed_bytes(buffer: &mut Vec, consumed: usize) { + if consumed >= buffer.len() { + buffer.clear(); + return; + } + + let remaining = buffer.len() - consumed; + buffer.copy_within(consumed.., 0); + buffer.truncate(remaining); +} + +fn bind_listener( + options: &NativeListenOptions, + server_config: &HttpServerConfig, +) -> Result { + let host = options + .host + .as_deref() + .unwrap_or(server_config.default_host.as_str()); + let bind_addr = resolve_socket_addr(host, options.port) + .with_context(|| format!("failed to resolve bind address {host}:{}", options.port))?; + let listener_opts = ListenerOpts::new() + .reuse_addr(true) + .reuse_port(true) + .backlog(options.backlog.unwrap_or(server_config.default_backlog)); + + TcpListener::bind_with_config(bind_addr, &listener_opts) + .with_context(|| format!("failed to bind TCP listener on {bind_addr}")) +} + +fn resolve_socket_addr(host: &str, port: u16) -> Result { + (host, port) + .to_socket_addrs()? + .next() + .ok_or_else(|| anyhow!("unable to resolve {host}:{port}")) +} + +fn validate_manifest(manifest: &ManifestInput) -> Result<()> { + if manifest.version != 1 { + return Err(anyhow!("Unsupported manifest version {}", manifest.version)); + } + + Ok(()) +} + +fn find_header_end(bytes: &[u8]) -> Option { + memmem::find(bytes, b"\r\n\r\n") +} + +fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { + if needle.is_empty() || haystack.len() < needle.len() { + return false; + } + + haystack + .windows(needle.len()) + .any(|window| window.eq_ignore_ascii_case(needle)) +} + +fn trim_ascii_spaces(bytes: &[u8]) -> &[u8] { + let start = bytes + .iter() + .position(|byte| !byte.is_ascii_whitespace()) + .unwrap_or(bytes.len()); + let end = bytes + .iter() + .rposition(|byte| !byte.is_ascii_whitespace()) + .map(|index| index + 1) + .unwrap_or(start); + &bytes[start..end] +} + +fn normalize_runtime_path(path: &str) -> Cow<'_, str> { + if path == "/" || !path.ends_with('/') { + return Cow::Borrowed(path); + } + + Cow::Owned(crate::analyzer::normalize_path(path)) +} + +fn config_string( + input: Option<&HttpServerConfigInput>, + pick: impl Fn(&HttpServerConfigInput) -> Option<&str>, + fallback: &str, +) -> String { + input.and_then(pick).unwrap_or(fallback).to_string() +} + fn status_reason(status: u16) -> &'static str { match status { 200 => "OK", @@ -868,6 +1303,10 @@ fn status_reason(status: u16) -> &'static str { 204 => "No Content", 400 => "Bad Request", 404 => "Not Found", + 411 => "Length Required", + 413 => "Payload Too Large", + 415 => "Unsupported Media Type", + 431 => "Request Header Fields Too Large", 500 => "Internal Server Error", 502 => "Bad Gateway", 503 => "Service Unavailable", @@ -875,6 +1314,31 @@ fn status_reason(status: u16) -> &'static str { } } +/// Fast integer-to-string for small values — uses stack-allocated itoa buffer +#[inline(always)] +fn write_usize(output: &mut Vec, value: usize) { + let mut buf = itoa::Buffer::new(); + output.extend_from_slice(buf.format(value).as_bytes()); +} + +#[inline(always)] +fn write_u16(output: &mut Vec, value: u16) { + let mut buf = itoa::Buffer::new(); + output.extend_from_slice(buf.format(value).as_bytes()); +} + +fn count_digits(mut n: usize) -> usize { + if n == 0 { + return 1; + } + let mut count = 0; + while n > 0 { + count += 1; + n /= 10; + } + count +} + fn push_string_pair(frame: &mut Vec, name: &str, value: &str) -> Result<()> { if name.len() > u8::MAX as usize { return Err(anyhow!("field name too long")); @@ -955,207 +1419,6 @@ fn read_utf8(bytes: &[u8], offset: &mut usize, length: usize) -> Result Ok(value) } -struct RequestHead<'a> { - method: &'a [u8], - target: &'a [u8], - path: &'a [u8], - keep_alive: bool, - header_bytes: usize, - has_body: bool, -} - -fn parse_request_head(bytes: &[u8]) -> Option> { - let header_end = find_header_end(bytes)?; - let line_end = memmem::find(bytes, b"\r\n")?; - let request_line = &bytes[..line_end]; - - let first_space = memchr(b' ', request_line)?; - let second_space = memchr(b' ', &request_line[first_space + 1..])? + first_space + 1; - - let method = &request_line[..first_space]; - let target = &request_line[first_space + 1..second_space]; - let version = &request_line[second_space + 1..]; - let path = target.split(|byte| *byte == b'?').next()?; - - let (keep_alive, has_body) = scan_header_controls( - bytes, - line_end + 2, - header_end, - version.eq_ignore_ascii_case(b"HTTP/1.1"), - b"connection:", - b"content-length:", - b"transfer-encoding:", - )?; - - Some(RequestHead { - method, - target, - path, - keep_alive, - header_bytes: header_end + 4, - has_body, - }) -} - -fn parse_hot_root_request_head( - bytes: &[u8], - server_config: &HttpServerConfig, -) -> Option> { - let (request_line_len, keep_alive) = - if bytes.starts_with(server_config.hot_get_root_http11.as_slice()) { - (server_config.hot_get_root_http11.len(), true) - } else if bytes.starts_with(server_config.hot_get_root_http10.as_slice()) { - (server_config.hot_get_root_http10.len(), false) - } else { - return None; - }; - - let header_end = find_header_end(bytes)?; - let (keep_alive, has_body) = scan_header_controls( - bytes, - request_line_len, - header_end, - keep_alive, - server_config.header_connection_prefix.as_slice(), - server_config.header_content_length_prefix.as_slice(), - server_config.header_transfer_encoding_prefix.as_slice(), - )?; - - Some(RequestHead { - method: b"GET", - target: b"/", - path: b"/", - keep_alive, - header_bytes: header_end + 4, - has_body, - }) -} - -fn find_header_end(bytes: &[u8]) -> Option { - memmem::find(bytes, b"\r\n\r\n") -} - -fn scan_header_controls( - bytes: &[u8], - mut line_start: usize, - header_end: usize, - mut keep_alive: bool, - connection_prefix: &[u8], - content_length_prefix: &[u8], - transfer_encoding_prefix: &[u8], -) -> Option<(bool, bool)> { - let mut has_body = false; - - while line_start <= header_end { - let next_end = find_line_end(bytes, line_start, header_end)?; - if next_end == line_start { - break; - } - - let line = &bytes[line_start..next_end]; - match ascii_lowercase(line[0]) { - b'c' => { - if line.len() >= connection_prefix.len() - && line[..connection_prefix.len()].eq_ignore_ascii_case(connection_prefix) - { - let value = &line[connection_prefix.len()..]; - if contains_ascii_case_insensitive(value, b"close") { - keep_alive = false; - } - if contains_ascii_case_insensitive(value, b"keep-alive") { - keep_alive = true; - } - } else if line.len() >= content_length_prefix.len() - && line[..content_length_prefix.len()] - .eq_ignore_ascii_case(content_length_prefix) - { - let value = trim_ascii_spaces(&line[content_length_prefix.len()..]); - if value != b"0" { - has_body = true; - } - } - } - b't' => { - if line.len() >= transfer_encoding_prefix.len() - && line[..transfer_encoding_prefix.len()] - .eq_ignore_ascii_case(transfer_encoding_prefix) - { - let value = trim_ascii_spaces(&line[transfer_encoding_prefix.len()..]); - if !value.is_empty() && !value.eq_ignore_ascii_case(b"identity") { - has_body = true; - } - } - } - _ => {} - } - - line_start = next_end + 2; - } - - Some((keep_alive, has_body)) -} - -fn find_line_end(bytes: &[u8], line_start: usize, header_end: usize) -> Option { - let relative_end = memchr(b'\r', &bytes[line_start..header_end + 2])?; - let line_end = line_start + relative_end; - if bytes.get(line_end + 1) != Some(&b'\n') { - return None; - } - - Some(line_end) -} - -fn ascii_lowercase(byte: u8) -> u8 { - if byte.is_ascii_uppercase() { - byte + 32 - } else { - byte - } -} - -fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { - if needle.is_empty() || haystack.len() < needle.len() { - return false; - } - - haystack - .windows(needle.len()) - .any(|window| window.eq_ignore_ascii_case(needle)) -} - -fn trim_ascii_spaces(bytes: &[u8]) -> &[u8] { - let start = bytes - .iter() - .position(|byte| !byte.is_ascii_whitespace()) - .unwrap_or(bytes.len()); - let end = bytes - .iter() - .rposition(|byte| !byte.is_ascii_whitespace()) - .map(|index| index + 1) - .unwrap_or(start); - &bytes[start..end] -} - -fn config_string( - input: Option<&HttpServerConfigInput>, - pick: impl Fn(&HttpServerConfigInput) -> Option<&str>, - fallback: &str, -) -> String { - input.and_then(pick).unwrap_or(fallback).to_string() -} - -fn normalize_runtime_path(path: &str) -> Cow<'_, str> { - if path == "/" || !path.ends_with('/') { - return Cow::Borrowed(path); - } - - Cow::Owned(crate::analyzer::normalize_path(path)) -} - -fn escape_json(value: &str) -> String { - value.replace('\\', "\\\\").replace('"', "\\\"") -} - fn to_napi_error(error: E) -> Error where E: std::fmt::Display, diff --git a/rust-native/src/manifest.rs b/rust-native/src/manifest.rs index 6176c42..c0d45f0 100644 --- a/rust-native/src/manifest.rs +++ b/rust-native/src/manifest.rs @@ -38,6 +38,7 @@ pub struct RouteInput { pub handler_id: u32, pub handler_source: String, pub param_names: Vec, + #[allow(dead_code)] pub segment_count: u16, pub header_keys: Vec, pub full_headers: bool, diff --git a/rust-native/src/router.rs b/rust-native/src/router.rs index 305fba4..070d839 100644 --- a/rust-native/src/router.rs +++ b/rust-native/src/router.rs @@ -10,32 +10,17 @@ use crate::manifest::{ManifestInput, RouteInput}; const ROUTE_KIND_EXACT: u8 = 1; const ROUTE_KIND_PARAM: u8 = 2; +// ─── Public Types ───────────────────────────────────────────────────────────── + #[derive(Clone)] pub struct Router { exact_get_root: Option, + /// O(1) exact-match routes (HashMap>) dynamic_exact_routes: HashMap, DynamicRouteSpec>>, - dynamic_param_routes: HashMap>>, + /// O(1) static-response routes exact_static_routes: HashMap, ExactStaticRoute>>, -} - -#[derive(Clone)] -struct ParamRoute { - spec: DynamicRouteSpec, - segments: Vec, -} - -#[derive(Clone)] -struct DynamicRouteSpec { - handler_id: u32, - param_names: Box<[Box]>, - header_keys: Box<[Box]>, - full_headers: bool, -} - -#[derive(Clone)] -enum CompiledSegment { - Static(Box), - Param, + /// O(M) radix-tree routes per method (M = path length) + radix_trees: HashMap, } #[derive(Clone)] @@ -51,6 +36,17 @@ pub struct MatchedRoute<'a, 'b> { pub full_headers: bool, } +// ─── Internal Types ─────────────────────────────────────────────────────────── + +#[derive(Clone)] +struct DynamicRouteSpec { + handler_id: u32, + #[allow(dead_code)] + param_names: Box<[Box]>, + header_keys: Box<[Box]>, + full_headers: bool, +} + #[derive(Clone, Copy, Eq, Hash, PartialEq)] enum MethodKey { Delete, @@ -62,12 +58,126 @@ enum MethodKey { Put, } +// ─── Radix Tree ─────────────────────────────────────────────────────────────── +// +// Each node represents either a static prefix or a parameter capture. +// Matching is O(M) where M is the number of path segments, not O(N) routes. + +#[derive(Clone)] +struct RadixNode { + children: Vec, + /// If this node is a terminal route, the handler spec + handler: Option, + /// Parameter capture node for this level + param_child: Option>, +} + +#[derive(Clone)] +struct RadixChild { + /// The static segment this child matches + segment: Box, + node: RadixNode, +} + +#[derive(Clone)] +struct RadixParamChild { + node: RadixNode, +} + +impl RadixNode { + fn new() -> Self { + Self { + children: Vec::new(), + handler: None, + param_child: None, + } + } + + /// Insert a route into the radix tree + fn insert(&mut self, segments: &[RouteSegment], spec: DynamicRouteSpec) { + if segments.is_empty() { + self.handler = Some(spec); + return; + } + + match &segments[0] { + RouteSegment::Static(value) => { + // Find existing child with this segment + for child in &mut self.children { + if child.segment.as_ref() == value.as_str() { + child.node.insert(&segments[1..], spec); + return; + } + } + + // Create new child + let mut child_node = RadixNode::new(); + child_node.insert(&segments[1..], spec); + self.children.push(RadixChild { + segment: value.clone().into_boxed_str(), + node: child_node, + }); + } + RouteSegment::Param(_) => { + if self.param_child.is_none() { + self.param_child = Some(Box::new(RadixParamChild { + node: RadixNode::new(), + })); + } + self.param_child + .as_mut() + .unwrap() + .node + .insert(&segments[1..], spec); + } + } + } + + /// Match a request path against this radix tree — O(M) where M = segment count + fn match_path<'a, 'b>( + &'a self, + segments: &[&'b str], + param_values: &mut Vec<&'b str>, + ) -> Option<&'a DynamicRouteSpec> { + if segments.is_empty() { + return self.handler.as_ref(); + } + + let segment = segments[0]; + let rest = &segments[1..]; + + // Try static children first (higher priority) + for child in &self.children { + if child.segment.as_ref() == segment { + if let Some(spec) = child.node.match_path(rest, param_values) { + return Some(spec); + } + } + } + + // Try parameter capture + if let Some(param_child) = &self.param_child { + let prev_len = param_values.len(); + param_values.push(segment); + if let Some(spec) = param_child.node.match_path(rest, param_values) { + return Some(spec); + } + // Backtrack + param_values.truncate(prev_len); + } + + None + } +} + +// ─── Router Implementation ──────────────────────────────────────────────────── + impl Router { pub fn from_manifest(manifest: &ManifestInput) -> Result { let mut exact_get_root = None; let mut dynamic_exact_routes = HashMap::new(); - let mut dynamic_param_routes = HashMap::new(); let mut exact_static_routes = HashMap::new(); + let mut radix_trees: HashMap = HashMap::new(); for route in &manifest.routes { let method = route.method.to_uppercase(); @@ -119,13 +229,13 @@ impl Router { ); } ROUTE_KIND_PARAM => { - let compiled = compile_param_route(route, path.as_str()); - dynamic_param_routes + // Insert into radix tree instead of linear Vec + let segments = parse_segments(path.as_str()); + let spec = compile_dynamic_route_spec(route); + radix_trees .entry(method_key) - .or_insert_with(HashMap::new) - .entry(route.segment_count) - .or_insert_with(Vec::new) - .push(compiled); + .or_insert_with(RadixNode::new) + .insert(&segments, spec); } _ => {} } @@ -134,8 +244,8 @@ impl Router { Ok(Self { exact_get_root, dynamic_exact_routes, - dynamic_param_routes, exact_static_routes, + radix_trees, }) } @@ -146,6 +256,7 @@ impl Router { ) -> Option> { let method_key = MethodKey::from_code(method_code)?; + // Fast path: exact match (O(1)) if let Some(route_spec) = self .dynamic_exact_routes .get(&method_key) @@ -159,40 +270,18 @@ impl Router { }); } - let path_segments = split_request_segments(path); - let param_routes = self - .dynamic_param_routes - .get(&method_key) - .and_then(|routes| routes.get(&(path_segments.len() as u16)))?; - - for route in param_routes { - let mut param_values = Vec::with_capacity(route.spec.param_names.len()); - let mut matched = true; - - for (segment, path_segment) in route.segments.iter().zip(path_segments.iter()) { - match segment { - CompiledSegment::Static(value) if value.as_ref() == *path_segment => {} - CompiledSegment::Static(_) => { - matched = false; - break; - } - CompiledSegment::Param => { - param_values.push(*path_segment); - } - } - } - - if matched { - return Some(MatchedRoute { - handler_id: route.spec.handler_id, - param_values, - header_keys: route.spec.header_keys.as_ref(), - full_headers: route.spec.full_headers, - }); - } - } - - None + // Radix tree match (O(M) where M = segment count) + let segments = split_request_segments(path); + let tree = self.radix_trees.get(&method_key)?; + let mut param_values = Vec::new(); + let spec = tree.match_path(&segments, &mut param_values)?; + + Some(MatchedRoute { + handler_id: spec.handler_id, + param_values, + header_keys: spec.header_keys.as_ref(), + full_headers: spec.full_headers, + }) } pub fn exact_static_route(&self, method: &[u8], path: &[u8]) -> Option<&ExactStaticRoute> { @@ -211,6 +300,8 @@ impl Router { } } +// ─── MethodKey ──────────────────────────────────────────────────────────────── + impl MethodKey { fn from_method_str(method: &str) -> Option { match method { @@ -252,20 +343,7 @@ impl MethodKey { } } -fn compile_param_route(route: &RouteInput, path: &str) -> ParamRoute { - let segments = parse_segments(path) - .into_iter() - .map(|segment| match segment { - RouteSegment::Static(value) => CompiledSegment::Static(value.into_boxed_str()), - RouteSegment::Param(_) => CompiledSegment::Param, - }) - .collect(); - - ParamRoute { - spec: compile_dynamic_route_spec(route), - segments, - } -} +// ─── Helpers ────────────────────────────────────────────────────────────────── fn compile_dynamic_route_spec(route: &RouteInput) -> DynamicRouteSpec { let param_names = route @@ -328,6 +406,10 @@ fn build_response_bytes( .into_bytes(); for (name, value) in headers { + // Security: skip headers with CRLF injection + if name.contains('\r') || name.contains('\n') || value.contains('\r') || value.contains('\n') { + continue; + } response.extend_from_slice(name.as_bytes()); response.extend_from_slice(b": "); response.extend_from_slice(value.as_bytes()); diff --git a/src/bridge.js b/src/bridge.js index 094d230..86a6d55 100644 --- a/src/bridge.js +++ b/src/bridge.js @@ -7,6 +7,7 @@ const EMPTY_ARRAY = Object.freeze([]); export const BRIDGE_VERSION = 1; export const REQUEST_FLAG_QUERY_PRESENT = 1 << 0; +export const REQUEST_FLAG_BODY_PRESENT = 1 << 1; export const METHOD_CODES = Object.freeze({ GET: 1, @@ -23,15 +24,42 @@ export const ROUTE_KIND = Object.freeze({ PARAM: 2, }); -const EMPTY_OBJECT = Object.freeze({}); +// Security: use null-prototype objects for user-facing data to prevent prototype pollution +const EMPTY_OBJECT = Object.freeze(Object.create(null)); + +// ─── Regex patterns for source analysis ─────────────────────────────────────── const PARAM_DOT_RE = /\breq\.params\.([A-Za-z_$][\w$]*)\b/g; -const PARAM_BRACKET_RE = /\breq\.params\[(["'])([^"'\\]+)\1\]/g; +const PARAM_BRACKET_RE = /\breq\.params\[(['"])([^"'\\]+)\1\]/g; const QUERY_DOT_RE = /\breq\.query\.([A-Za-z_$][\w$]*)\b/g; -const QUERY_BRACKET_RE = /\breq\.query\[(["'])([^"'\\]+)\1\]/g; +const QUERY_BRACKET_RE = /\breq\.query\[(['"])([^"'\\]+)\1\]/g; const HEADER_DOT_RE = /\breq\.headers\.([A-Za-z_$][\w$]*)\b/g; -const HEADER_BRACKET_RE = /\breq\.headers\[(["'])([^"'\\]+)\1\]/g; -const HEADER_CALL_RE = /\breq\.header\((["'])([^"'\\]+)\1\)/g; +const HEADER_BRACKET_RE = /\breq\.headers\[(['"])([^"'\\]+)\1\]/g; +const HEADER_CALL_RE = /\breq\.header\((['"])([^"'\\]+)\1\)/g; + +// Pre-compiled patterns for full-access detection (using RegExp constructor to avoid escaping issues) +const PARAMS_FULL_ACCESS_RE = new RegExp('\\breq\\.params\\b(?!\\s*(?:\\.|\\[))'); +const PARAMS_DYN_BRACKET_RE = new RegExp("\\breq\\.params\\[(?!['\"])"); +const QUERY_FULL_ACCESS_RE = new RegExp('\\breq\\.query\\b(?!\\s*(?:\\.|\\[))'); +const QUERY_DYN_BRACKET_RE = new RegExp("\\breq\\.query\\[(?!['\"])"); +const HEADERS_FULL_ACCESS_RE = new RegExp('\\breq\\.headers\\b(?!\\s*(?:\\.|\\[))'); +const HEADERS_DYN_BRACKET_RE = new RegExp("\\breq\\.headers\\[(?!['\"])"); +const REQ_BRACKET_STR_RE = new RegExp("\\breq\\s*\\[(['\"])[^\"'\\\\]+\\1\\]"); +const REQ_BRACKET_DYN_RE = new RegExp("\\breq\\s*\\[(?!['\"])"); +const HEADER_CALL_DYN_RE = new RegExp("\\breq\\.header\\((?!['\"])"); + +// Security: dangerous prototype keys that must never be allowed in user objects +const DANGEROUS_KEYS = new Set([ + "__proto__", + "constructor", + "prototype", + "__defineGetter__", + "__defineSetter__", + "__lookupGetter__", + "__lookupSetter__", +]); + +// ─── Route Compilation ──────────────────────────────────────────────────────── export function compileRouteShape(method, path) { const methodCode = METHOD_CODES[method]; @@ -62,6 +90,8 @@ export function compileRouteShape(method, path) { }; } +// ─── Request Access Analysis ────────────────────────────────────────────────── + export function analyzeRequestAccess(source) { const plan = createEmptyAccessPlan(); const normalizedSource = String(source ?? ""); @@ -88,22 +118,22 @@ export function analyzeRequestAccess(source) { collectMatches(normalizedSource, HEADER_BRACKET_RE, plan.headerKeys, normalizeHeaderLookup, 2); collectMatches(normalizedSource, HEADER_CALL_RE, plan.headerKeys, normalizeHeaderLookup, 2); - if (/\breq\.params\b(?!\s*(?:\.|\[))/.test(normalizedSource) || /\breq\.params\[(?!["'])/.test(normalizedSource)) { + if (PARAMS_FULL_ACCESS_RE.test(normalizedSource) || PARAMS_DYN_BRACKET_RE.test(normalizedSource)) { plan.fullParams = true; plan.dispatchKind = "generic_fallback"; } - if (/\breq\.query\b(?!\s*(?:\.|\[))/.test(normalizedSource) || /\breq\.query\[(?!["'])/.test(normalizedSource)) { + if (QUERY_FULL_ACCESS_RE.test(normalizedSource) || QUERY_DYN_BRACKET_RE.test(normalizedSource)) { plan.fullQuery = true; plan.dispatchKind = "generic_fallback"; } - if (/\breq\.headers\b(?!\s*(?:\.|\[))/.test(normalizedSource) || /\breq\.headers\[(?!["'])/.test(normalizedSource)) { + if (HEADERS_FULL_ACCESS_RE.test(normalizedSource) || HEADERS_DYN_BRACKET_RE.test(normalizedSource)) { plan.fullHeaders = true; plan.dispatchKind = "generic_fallback"; } - if (/\breq\s*\[(["'])[^"'\\]+\1\]/.test(normalizedSource) || /\breq\s*\[(?!["'])/.test(normalizedSource)) { + if (REQ_BRACKET_STR_RE.test(normalizedSource) || REQ_BRACKET_DYN_RE.test(normalizedSource)) { plan.method = true; plan.path = true; plan.url = true; @@ -113,7 +143,7 @@ export function analyzeRequestAccess(source) { plan.dispatchKind = "generic_fallback"; } - if (/\breq\.header\((?!["'])/.test(normalizedSource)) { + if (HEADER_CALL_DYN_RE.test(normalizedSource)) { plan.fullHeaders = true; plan.dispatchKind = "generic_fallback"; } @@ -152,106 +182,211 @@ export function mergeRequestAccessPlans(plans) { return freezeAccessPlan(merged); } -export function createRequestFactory( - plan, - routeParamNames = EMPTY_ARRAY, - routeMethod = "GET", -) { - return function buildRequest(decoded) { - const needsParams = plan.fullParams || plan.paramKeys.size > 0; - const needsQuery = plan.fullQuery || plan.queryKeys.size > 0; - const needsHeaders = plan.fullHeaders || plan.headerKeys.size > 0; - - let path; - let url; - let params; - let query; - let headers; - - function decodePath() { - if (path === undefined) { - path = textDecoder.decode(decoded.pathBytes); +// ─── Request Factory (with Object Pooling) ──────────────────────────────────── + +// Pool for request objects — avoids per-request allocations +const REQUEST_POOL_MAX = 512; +const requestPool = []; + +function acquireRequestObject() { + return requestPool.pop() || null; +} + +function releaseRequestObject(req) { + if (requestPool.length >= REQUEST_POOL_MAX) { + return; + } + // Reset all fields before pooling + req.method = ""; + req._path = undefined; + req._url = undefined; + req._params = undefined; + req._query = undefined; + req._headers = undefined; + req._decoded = null; + req._routeParamNames = null; + req._plan = null; + req._routeMethod = null; + requestPool.push(req); +} + +function createPooledRequest() { + const req = Object.create(null); + + // Internal state + req._path = undefined; + req._url = undefined; + req._params = undefined; + req._query = undefined; + req._headers = undefined; + req._bodyParsed = undefined; + req._decoded = null; + req._routeParamNames = null; + req._plan = null; + req._routeMethod = null; + req.method = ""; + + Object.defineProperty(req, "path", { + configurable: true, + enumerable: true, + get() { + if (req._path === undefined) { + req._path = textDecoder.decode(req._decoded.pathBytes); + } + return req._path; + }, + }); + + Object.defineProperty(req, "url", { + configurable: true, + enumerable: true, + get() { + if (req._url === undefined) { + req._url = textDecoder.decode(req._decoded.urlBytes); + } + return req._url; + }, + }); + + Object.defineProperty(req, "params", { + configurable: true, + enumerable: true, + get() { + if (req._params === undefined) { + const needsParams = req._plan.fullParams || req._plan.paramKeys.size > 0; + req._params = needsParams + ? materializeParamObject(req._decoded.paramValues, req._routeParamNames, req._plan) + : EMPTY_OBJECT; + } + return req._params; + }, + }); + + Object.defineProperty(req, "query", { + configurable: true, + enumerable: true, + get() { + if (req._query === undefined) { + const needsQuery = req._plan.fullQuery || req._plan.queryKeys.size > 0; + if (!needsQuery) { + req._query = EMPTY_OBJECT; + } else { + // Compute URL for query parsing + if (req._url === undefined) { + req._url = textDecoder.decode(req._decoded.urlBytes); + } + req._query = materializeQueryObject(req._url, req._decoded.flags, req._plan); + } + } + return req._query; + }, + }); + + Object.defineProperty(req, "headers", { + configurable: true, + enumerable: true, + get() { + if (req._headers === undefined) { + const needsHeaders = req._plan.fullHeaders || req._plan.headerKeys.size > 0; + req._headers = needsHeaders + ? materializeHeaderObject(req._decoded.rawHeaders, req._plan) + : EMPTY_OBJECT; } - return path; + return req._headers; + }, + }); + + req.header = function header(name) { + const lookup = normalizeHeaderLookup(name); + if (req._headers && lookup in req._headers) { + return req._headers[lookup]; } + if (req._decoded.rawHeaders.length === 0) { + return undefined; + } + return lookupHeaderValue(req._decoded.rawHeaders, lookup); + }; - function decodeUrl() { - if (url === undefined) { - url = textDecoder.decode(decoded.urlBytes); + // ─── Body APIs ──────────────────────────────────────────────────────── + + Object.defineProperty(req, "body", { + configurable: true, + enumerable: true, + get() { + if (req._decoded.bodyBytes === null) { + return null; } - return url; + return Buffer.from(req._decoded.bodyBytes.buffer, req._decoded.bodyBytes.byteOffset, req._decoded.bodyBytes.byteLength); + }, + }); + + req.json = function json() { + if (req._bodyParsed !== undefined) { + return req._bodyParsed; + } + if (req._decoded.bodyBytes === null || req._decoded.bodyBytes.length === 0) { + req._bodyParsed = null; + return null; } + const text = textDecoder.decode(req._decoded.bodyBytes); + req._bodyParsed = JSON.parse(text); + return req._bodyParsed; + }; - const request = { - method: routeMethod, + req.text = function text() { + if (req._decoded.bodyBytes === null || req._decoded.bodyBytes.length === 0) { + return ""; + } + return textDecoder.decode(req._decoded.bodyBytes); + }; - get path() { - return decodePath(); - }, + req.arrayBuffer = function arrayBuffer() { + if (req._decoded.bodyBytes === null) { + return new ArrayBuffer(0); + } + return req._decoded.bodyBytes.buffer.slice( + req._decoded.bodyBytes.byteOffset, + req._decoded.bodyBytes.byteOffset + req._decoded.bodyBytes.byteLength, + ); + }; - get url() { - return decodeUrl(); - }, + return req; +} - get params() { - if (params === undefined) { - params = needsParams - ? materializeParamObject(decoded.paramValues, routeParamNames, plan) - : EMPTY_OBJECT; - } - return params; - }, - - get query() { - if (query === undefined) { - query = needsQuery - ? materializeQueryObject(decodeUrl(), decoded.flags, plan) - : EMPTY_OBJECT; - } - return query; - }, - - get headers() { - if (headers === undefined) { - headers = needsHeaders - ? materializeHeaderObject(decoded.rawHeaders, plan) - : EMPTY_OBJECT; - } - return headers; - }, +export function createRequestFactory( + plan, + routeParamNames = EMPTY_ARRAY, + routeMethod = "GET", +) { + return function buildRequest(decoded) { + let request = acquireRequestObject(); + if (!request) { + request = createPooledRequest(); + } - header(name) { - const lookup = normalizeHeaderLookup(name); - if (headers && lookup in headers) { - return headers[lookup]; - } - if (decoded.rawHeaders.length === 0) { - return undefined; - } - return lookupHeaderValue(decoded.rawHeaders, lookup); - }, - }; + // Initialize for this request + request._decoded = decoded; + request._routeParamNames = routeParamNames; + request._plan = plan; + request._routeMethod = routeMethod; + request.method = routeMethod; + request._path = undefined; + request._url = undefined; + request._params = undefined; + request._query = undefined; + request._headers = undefined; + request._bodyParsed = undefined; return request; }; } -export function createJsonSerializer(mode = "fallback") { - if (mode === "fallback") { - const serializer = (value) => { - const serialized = JSON.stringify(value); - return Buffer.from(serialized, "utf8"); - }; - serializer.kind = "fallback"; - return serializer; - } +// ─── JSON Serialization ─────────────────────────────────────────────────────── +export function createJsonSerializer(mode = "fallback") { + // Performance: V8's native JSON.stringify is heavily optimized and almost always + // faster than any JS-level reimplementation. Use it directly. const serializer = (value) => { - const fastValue = trySerializeJsonFast(value); - if (fastValue !== null) { - return Buffer.from(fastValue, "utf8"); - } - const serialized = JSON.stringify(value); return Buffer.from(serialized, "utf8"); }; @@ -259,6 +394,8 @@ export function createJsonSerializer(mode = "fallback") { return serializer; } +// ─── Binary Protocol Codec ──────────────────────────────────────────────────── + export function decodeRequestEnvelope(buffer) { const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength); @@ -285,6 +422,9 @@ export function decodeRequestEnvelope(buffer) { const headerCount = readU16(view, offset); offset += 2; + const bodyLength = readU32(view, offset); + offset += 4; + const urlBytes = readBytes(bytes, offset, urlLength); offset += urlLength; const pathBytes = readBytes(bytes, offset, pathLength); @@ -312,6 +452,10 @@ export function decodeRequestEnvelope(buffer) { rawHeaders[index] = [nameBytes, valueBytes]; } + // Read body bytes + const bodyBytes = bodyLength > 0 ? readBytes(bytes, offset, bodyLength) : null; + offset += bodyLength; + switch (methodCode) { case METHOD_CODES.GET: case METHOD_CODES.POST: @@ -333,6 +477,7 @@ export function decodeRequestEnvelope(buffer) { pathBytes, paramValues, rawHeaders, + bodyBytes, }; } @@ -384,6 +529,10 @@ export function encodeResponseEnvelope(snapshot) { return output; } +// ─── Object Materialization (Security-Hardened) ─────────────────────────────── +// +// All user-facing objects use Object.create(null) to prevent prototype pollution. + function materializeParamObject(entries, paramNames, plan) { if (plan.fullParams) { return materializeParamPairs(entries, paramNames); @@ -402,7 +551,7 @@ function materializeHeaderObject(entries, plan) { function materializeQueryObject(url, flags, plan) { if (!(flags & REQUEST_FLAG_QUERY_PRESENT)) { - return {}; + return Object.create(null); } if (plan.fullQuery) { @@ -413,11 +562,16 @@ function materializeQueryObject(url, flags, plan) { } function materializePairs(entries, lowerCaseKeys = false) { - const result = {}; + // Security: null-prototype object prevents prototype pollution + const result = Object.create(null); for (const [rawName, rawValue] of entries) { const name = textDecoder.decode(rawName); const key = lowerCaseKeys ? name.toLowerCase() : name; + // Security: skip dangerous prototype keys + if (DANGEROUS_KEYS.has(key)) { + continue; + } result[key] = textDecoder.decode(rawValue); } @@ -425,10 +579,14 @@ function materializePairs(entries, lowerCaseKeys = false) { } function materializeParamPairs(entries, paramNames) { - const result = {}; + const result = Object.create(null); for (let index = 0; index < entries.length; index += 1) { - result[paramNames[index]] = textDecoder.decode(entries[index]); + const key = paramNames[index]; + if (DANGEROUS_KEYS.has(key)) { + continue; + } + result[key] = textDecoder.decode(entries[index]); } return result; @@ -436,13 +594,13 @@ function materializeParamPairs(entries, paramNames) { function materializeSelectedParamPairs(entries, paramNames, selectedKeys) { if (selectedKeys.size === 0) { - return {}; + return Object.create(null); } - const result = {}; + const result = Object.create(null); for (let index = 0; index < entries.length; index += 1) { const key = paramNames[index]; - if (selectedKeys.has(key)) { + if (selectedKeys.has(key) && !DANGEROUS_KEYS.has(key)) { result[key] = textDecoder.decode(entries[index]); } } @@ -452,14 +610,14 @@ function materializeSelectedParamPairs(entries, paramNames, selectedKeys) { function materializeSelectedPairs(entries, selectedKeys, lowerCaseKeys = false) { if (selectedKeys.size === 0) { - return {}; + return Object.create(null); } - const result = {}; + const result = Object.create(null); for (const [rawName, rawValue] of entries) { const name = textDecoder.decode(rawName); const key = lowerCaseKeys ? name.toLowerCase() : name; - if (selectedKeys.has(key)) { + if (selectedKeys.has(key) && !DANGEROUS_KEYS.has(key)) { result[key] = textDecoder.decode(rawValue); } } @@ -470,13 +628,16 @@ function materializeSelectedPairs(entries, selectedKeys, lowerCaseKeys = false) function parseQuery(url) { const queryStart = url.indexOf("?"); if (queryStart < 0 || queryStart === url.length - 1) { - return {}; + return Object.create(null); } const params = new URLSearchParams(url.slice(queryStart + 1)); - const result = {}; + const result = Object.create(null); for (const [key, value] of params) { + if (DANGEROUS_KEYS.has(key)) { + continue; + } pushQueryEntry(result, key, value); } @@ -485,19 +646,19 @@ function parseQuery(url) { function parseSelectedQuery(url, selectedKeys) { if (selectedKeys.size === 0) { - return {}; + return Object.create(null); } const queryStart = url.indexOf("?"); if (queryStart < 0 || queryStart === url.length - 1) { - return {}; + return Object.create(null); } const params = new URLSearchParams(url.slice(queryStart + 1)); - const result = {}; + const result = Object.create(null); for (const [key, value] of params) { - if (selectedKeys.has(key)) { + if (selectedKeys.has(key) && !DANGEROUS_KEYS.has(key)) { pushQueryEntry(result, key, value); } } @@ -530,6 +691,8 @@ function lookupHeaderValue(entries, targetName) { return undefined; } +// ─── Access Plan ────────────────────────────────────────────────────────────── + function createEmptyAccessPlan() { return { method: false, @@ -590,87 +753,6 @@ function addSetEntries(target, source) { } } -function trySerializeJsonFast(value) { - const stack = new WeakSet(); - return serializeJsonValue(value, stack); -} - -function serializeJsonValue(value, stack) { - if (value === null) { - return "null"; - } - - switch (typeof value) { - case "string": - return JSON.stringify(value); - case "number": - return Number.isFinite(value) ? (Object.is(value, -0) ? "0" : String(value)) : "null"; - case "boolean": - return value ? "true" : "false"; - case "undefined": - case "function": - case "symbol": - return null; - case "bigint": - return null; - case "object": - break; - default: - return null; - } - - if (typeof value.toJSON === "function") { - return null; - } - - if (Array.isArray(value)) { - if (stack.has(value)) { - return null; - } - - stack.add(value); - const items = new Array(value.length); - for (let index = 0; index < value.length; index += 1) { - const descriptor = Object.getOwnPropertyDescriptor(value, String(index)); - if (descriptor && (descriptor.get || descriptor.set)) { - stack.delete(value); - return null; - } - - const serialized = serializeJsonValue(value[index], stack); - items[index] = serialized === null ? "null" : serialized; - } - stack.delete(value); - return `[${items.join(",")}]`; - } - - const prototype = Object.getPrototypeOf(value); - if (prototype !== PLAIN_OBJECT_PROTOTYPE && prototype !== null) { - return null; - } - - if (stack.has(value)) { - return null; - } - - stack.add(value); - const entries = []; - for (const key of Object.keys(value)) { - const descriptor = Object.getOwnPropertyDescriptor(value, key); - if (descriptor && (descriptor.get || descriptor.set)) { - stack.delete(value); - return null; - } - - const serializedValue = serializeJsonValue(value[key], stack); - if (serializedValue !== null) { - entries.push(`${JSON.stringify(key)}:${serializedValue}`); - } - } - stack.delete(value); - return `{${entries.join(",")}}`; -} - function identity(value) { return value; } @@ -679,6 +761,8 @@ function encodeUtf8(value) { return textEncoder.encode(String(value)); } +// ─── Binary Protocol Helpers ────────────────────────────────────────────────── + function readBytes(bytes, offset, length) { if (offset + length > bytes.byteLength) { throw new Error("Request envelope truncated"); diff --git a/src/cors.js b/src/cors.js new file mode 100644 index 0000000..3a36758 --- /dev/null +++ b/src/cors.js @@ -0,0 +1,116 @@ +/** + * http-native CORS middleware + * + * Usage: + * import { cors } from "http-native/cors"; + * app.use(cors({ origin: "*" })); + * app.use(cors({ origin: ["https://example.com"], credentials: true })); + */ + +const DEFAULT_METHODS = "GET,HEAD,PUT,PATCH,POST,DELETE"; +const DEFAULT_HEADERS = "Content-Type,Authorization,Accept,X-Requested-With"; + +/** + * @param {Object} [options] + * @param {string | string[] | ((origin: string) => boolean)} [options.origin="*"] + * @param {string | string[]} [options.methods] + * @param {string | string[]} [options.allowedHeaders] + * @param {string | string[]} [options.exposedHeaders] + * @param {boolean} [options.credentials=false] + * @param {number} [options.maxAge] + * @param {boolean} [options.preflight=true] + */ +export function cors(options = {}) { + const { + origin = "*", + methods = DEFAULT_METHODS, + allowedHeaders = DEFAULT_HEADERS, + exposedHeaders, + credentials = false, + maxAge, + preflight = true, + } = options; + + const methodsString = Array.isArray(methods) ? methods.join(",") : methods; + const allowedHeadersString = Array.isArray(allowedHeaders) + ? allowedHeaders.join(",") + : allowedHeaders; + const exposedHeadersString = exposedHeaders + ? Array.isArray(exposedHeaders) + ? exposedHeaders.join(",") + : exposedHeaders + : null; + + // Pre-compute origin matching function + const matchOrigin = buildOriginMatcher(origin); + + return async function corsMiddleware(req, res, next) { + const requestOrigin = req.header("origin") ?? req.headers?.origin; + + // Not a CORS request + if (!requestOrigin) { + await next(); + return; + } + + const allowed = matchOrigin(requestOrigin); + + if (!allowed) { + await next(); + return; + } + + // Set CORS headers + const effectiveOrigin = + origin === "*" && !credentials ? "*" : requestOrigin; + res.set("Access-Control-Allow-Origin", effectiveOrigin); + + if (credentials) { + res.set("Access-Control-Allow-Credentials", "true"); + } + + if (effectiveOrigin !== "*") { + res.set("Vary", "Origin"); + } + + if (exposedHeadersString) { + res.set("Access-Control-Expose-Headers", exposedHeadersString); + } + + // Handle preflight + if (preflight && req.method === "OPTIONS") { + res.set("Access-Control-Allow-Methods", methodsString); + res.set("Access-Control-Allow-Headers", allowedHeadersString); + + if (maxAge !== undefined) { + res.set("Access-Control-Max-Age", String(maxAge)); + } + + res.status(204).send(); + return; + } + + await next(); + }; +} + +function buildOriginMatcher(origin) { + if (origin === "*") { + return () => true; + } + + if (typeof origin === "function") { + return origin; + } + + if (typeof origin === "string") { + return (requestOrigin) => requestOrigin === origin; + } + + if (Array.isArray(origin)) { + const originSet = new Set(origin); + return (requestOrigin) => originSet.has(requestOrigin); + } + + return () => false; +} diff --git a/src/index.d.ts b/src/index.d.ts new file mode 100644 index 0000000..9bfa5d9 --- /dev/null +++ b/src/index.d.ts @@ -0,0 +1,238 @@ +import { Buffer } from "node:buffer"; + +// ─── Core Types ─────────────────────────────────────────────────────────────── + +export interface Request { + /** HTTP method (GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD) */ + readonly method: string; + + /** URL path without query string */ + readonly path: string; + + /** Full request URL including query string */ + readonly url: string; + + /** Route parameters extracted from path segments (e.g., /users/:id → { id: "42" }) */ + readonly params: Record; + + /** Parsed query string parameters. Multi-value params are arrays. */ + readonly query: Record; + + /** Lowercase request headers as key-value pairs */ + readonly headers: Record; + + /** Raw request body as a Buffer, or null if no body was sent */ + readonly body: Buffer | null; + + /** Get a specific header value by name (case-insensitive) */ + header(name: string): string | undefined; + + /** Parse the request body as JSON */ + json(): T | null; + + /** Get the request body as a UTF-8 string */ + text(): string; + + /** Get the request body as an ArrayBuffer */ + arrayBuffer(): ArrayBuffer; +} + +export interface Response { + /** Whether the response has already been finalized */ + readonly finished: boolean; + + /** Per-request storage for passing data between middlewares */ + locals: Record; + + /** Set the HTTP status code */ + status(code: number): Response; + + /** Set a response header */ + set(name: string, value: string): Response; + + /** Alias for set() */ + header(name: string, value: string): Response; + + /** Get a response header value */ + get(name: string): string | undefined; + + /** Set the Content-Type header */ + type(value: string): Response; + + /** Send a JSON response with proper Content-Type */ + json(data: unknown): Response; + + /** Send a response body (string, Buffer, or object) */ + send(data?: string | Buffer | Uint8Array | null): Response; + + /** Set status and send status code as text body */ + sendStatus(code: number): Response; +} + +export type NextFunction = () => Promise; + +export type Middleware = ( + req: Request, + res: Response, + next: NextFunction, +) => void | Promise; + +export type RouteHandler = ( + req: Request, + res: Response, +) => void | Promise; + +export type ErrorHandler = ( + error: Error, + req: Request, + res: Response, +) => void | Promise; + +// ─── Listen Options ─────────────────────────────────────────────────────────── + +export interface HttpServerConfig { + defaultHost?: string; + defaultBacklog?: number; + maxHeaderBytes?: number; +} + +export interface ListenOptions { + /** Host to bind to (default: "127.0.0.1") */ + host?: string; + + /** Port to bind to (default: 3000, use 0 for random) */ + port?: number; + + /** TCP listen backlog (default: 2048) */ + backlog?: number; + + /** Override server configuration */ + serverConfig?: HttpServerConfig; + + /** Runtime optimization options */ + opt?: Record; +} + +// ─── Server Handle ──────────────────────────────────────────────────────────── + +export interface ServerHandle { + /** Bound hostname */ + readonly host: string; + + /** Bound port number */ + readonly port: number; + + /** Full URL (http://host:port) */ + readonly url: string; + + /** Runtime optimization introspection */ + readonly optimizations: { + /** Get a snapshot of route optimization state */ + snapshot(): OptimizationSnapshot; + /** Get a human-readable optimization summary */ + summary(): string; + }; + + /** Gracefully close the server */ + close(): void; +} + +export interface OptimizationSnapshot { + routes: RouteOptimizationInfo[]; +} + +export interface RouteOptimizationInfo { + method: string; + path: string; + staticFastPath: boolean; + binaryBridge: boolean; + bridgeObserved: boolean; + cacheCandidate: boolean; + hits: number; + recommendation?: string; + dispatchKind?: string; + jsonFastPath?: string; +} + +// ─── Application ────────────────────────────────────────────────────────────── + +export interface Application { + /** Register path-scoped or global middleware */ + use(middleware: Middleware): Application; + use(path: string, middleware: Middleware): Application; + + /** Register a global error handler */ + onError(handler: ErrorHandler): Application; + + /** Register a GET route handler */ + get(path: string, handler: RouteHandler): Application; + + /** Register a POST route handler */ + post(path: string, handler: RouteHandler): Application; + + /** Register a PUT route handler */ + put(path: string, handler: RouteHandler): Application; + + /** Register a DELETE route handler */ + delete(path: string, handler: RouteHandler): Application; + + /** Register a PATCH route handler */ + patch(path: string, handler: RouteHandler): Application; + + /** Register an OPTIONS route handler */ + options(path: string, handler: RouteHandler): Application; + + /** Register a handler for all HTTP methods */ + all(path: string, handler: RouteHandler): Application; + + /** Start the server and listen for connections */ + listen(options?: ListenOptions): Promise; +} + +/** Create a new http-native application */ +export function createApp(): Application; + +// ─── CORS Types ─────────────────────────────────────────────────────────────── + +export interface CorsOptions { + /** Allowed origin(s). Default: "*" */ + origin?: string | string[] | ((origin: string) => boolean); + + /** Allowed HTTP methods */ + methods?: string | string[]; + + /** Allowed request headers */ + allowedHeaders?: string | string[]; + + /** Headers exposed to the browser */ + exposedHeaders?: string | string[]; + + /** Allow credentials (cookies, authorization) */ + credentials?: boolean; + + /** Cache duration for preflight response (seconds) */ + maxAge?: number; + + /** Handle preflight OPTIONS requests. Default: true */ + preflight?: boolean; +} + +/** Create a CORS middleware */ +export function cors(options?: CorsOptions): Middleware; + +// ─── Validation Types ───────────────────────────────────────────────────────── + +export interface ValidationSchema { + parse(data: unknown): T; +} + +export interface ValidateOptions { + body?: ValidationSchema; + query?: ValidationSchema; + params?: ValidationSchema; +} + +/** Create a validation middleware (works with Zod, TypeBox, or any schema with .parse()) */ +export function validate( + schema: ValidateOptions, +): Middleware; diff --git a/src/index.js b/src/index.js index a9e5788..ca8430a 100644 --- a/src/index.js +++ b/src/index.js @@ -19,6 +19,8 @@ const HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]; const ACTIVE_NATIVE_SERVERS = new Set(); const EMPTY_BUFFER = Buffer.alloc(0); +// ─── Path Normalization ─────────────────────────────────────────────────────── + function normalizePathPrefix(path) { if (path === "/") { return "/"; @@ -64,14 +66,45 @@ function normalizeContentType(type) { return type; } -function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback")) { - const state = { +// ─── Response Envelope (Pooled) ─────────────────────────────────────────────── + +const RESPONSE_POOL_MAX = 512; +const responsePool = []; + +function acquireResponseState() { + const pooled = responsePool.pop(); + if (pooled) { + pooled.status = 200; + // Reset headers — use null-prototype object for security + for (const key in pooled.headers) { + delete pooled.headers[key]; + } + pooled.body = EMPTY_BUFFER; + pooled.finished = false; + // Reset locals + for (const key in pooled.locals) { + delete pooled.locals[key]; + } + return pooled; + } + + return { status: 200, - headers: {}, + headers: Object.create(null), body: EMPTY_BUFFER, finished: false, - locals: {}, + locals: Object.create(null), }; +} + +function releaseResponseState(state) { + if (responsePool.length < RESPONSE_POOL_MAX) { + responsePool.push(state); + } +} + +function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback")) { + const state = acquireResponseState(); const response = { locals: state.locals, @@ -86,7 +119,14 @@ function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback" }, set(name, value) { - state.headers[String(name).toLowerCase()] = String(value); + // Security: validate header name/value for CRLF injection + const headerName = String(name).toLowerCase(); + const headerValue = String(value); + if (headerName.includes("\r") || headerName.includes("\n") || + headerValue.includes("\r") || headerValue.includes("\n")) { + return response; // Silently reject — security + } + state.headers[headerName] = headerValue; return response; }, @@ -161,10 +201,42 @@ function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback" body: state.body, }; }, + release() { + releaseResponseState(state); + }, }; } +// ─── Compiled Middleware Runner ──────────────────────────────────────────────── +// +// Generates an optimized runner that avoids function.length checks at runtime +// by pre-classifying middlewares during compilation. + function createMiddlewareRunner(middlewares) { + if (middlewares.length === 0) { + // Fast path: no middlewares — return a no-op + return async function noopMiddleware(_req, _res) {}; + } + + if (middlewares.length === 1) { + // Fast path: single middleware — avoid dispatch overhead + const mw = middlewares[0]; + if (mw.handler.length >= 3) { + return async function runSingleMiddleware(req, res) { + await mw.handler(req, res, () => Promise.resolve()); + }; + } + return async function runSingleMiddleware(req, res) { + await mw.handler(req, res); + }; + } + + // Pre-classify each middleware as "next-aware" or "auto-advance" + const classified = middlewares.map((mw) => ({ + handler: mw.handler, + needsNext: mw.handler.length >= 3, + })); + return async function runCompiledMiddlewares(req, res) { let index = -1; @@ -174,12 +246,12 @@ function createMiddlewareRunner(middlewares) { } index = position; - const middleware = middlewares[position]; + const middleware = classified[position]; if (!middleware || res.finished) { return; } - if (middleware.handler.length >= 3) { + if (middleware.needsNext) { await middleware.handler(req, res, () => dispatch(position + 1)); return; } @@ -194,23 +266,30 @@ function createMiddlewareRunner(middlewares) { }; } +// ─── Error Handling (Security-Hardened) ─────────────────────────────────────── + function serializeErrorResponse(error) { + // Security: NEVER leak internal error details to the client + const isProduction = process.env.NODE_ENV === "production"; + const body = isProduction + ? { error: "Internal Server Error" } + : { + error: "Internal Server Error", + detail: error instanceof Error ? error.message : String(error), + }; + return encodeResponseEnvelope({ status: 500, headers: { "content-type": "application/json; charset=utf-8", }, - body: Buffer.from( - JSON.stringify({ - error: "Internal Server Error", - detail: error instanceof Error ? error.message : String(error), - }), - "utf8", - ), + body: Buffer.from(JSON.stringify(body), "utf8"), }); } -function createDispatcher(compiledRoutes, runtimeOptimizer) { +// ─── Dispatcher ─────────────────────────────────────────────────────────────── + +function createDispatcher(compiledRoutes, runtimeOptimizer, errorHandlers = []) { const routesById = new Map(compiledRoutes.map((route) => [route.handlerId, route])); return async function dispatch(requestBuffer) { @@ -228,7 +307,7 @@ function createDispatcher(compiledRoutes, runtimeOptimizer) { } const req = route.requestFactory(decoded); - const { response: res, snapshot } = createResponseEnvelope(route.jsonSerializer); + const { response: res, snapshot, release } = createResponseEnvelope(route.jsonSerializer); try { await route.runMiddlewares(req, res); @@ -237,16 +316,38 @@ function createDispatcher(compiledRoutes, runtimeOptimizer) { } } catch (error) { if (!res.finished) { - return serializeErrorResponse(error); + // Try registered error handlers first + for (const errorHandler of errorHandlers) { + try { + await errorHandler(error, req, res); + if (res.finished) { + break; + } + } catch (handlerError) { + // Error handler itself threw — fall through to default + release(); + return serializeErrorResponse(handlerError); + } + } + + // If no error handler responded, use default + if (!res.finished) { + release(); + return serializeErrorResponse(error); + } } } const responseSnapshot = snapshot(); runtimeOptimizer?.recordDispatch(route, req, responseSnapshot); - return encodeResponseEnvelope(responseSnapshot); + const encoded = encodeResponseEnvelope(responseSnapshot); + release(); + return encoded; }; } +// ─── Route Registration & Compilation ───────────────────────────────────────── + function normalizeRouteRegistration(method, path, handler) { if (typeof handler !== "function") { throw new TypeError(`Handler for ${method} ${path} must be a function`); @@ -334,6 +435,8 @@ function normalizeListenOptions(options = {}) { }; } +// ─── Application Factory ───────────────────────────────────────────────────── + export function createApp() { const native = loadNativeModule(); let nextHandlerId = 1; @@ -341,6 +444,7 @@ export function createApp() { const app = { _routes: [], _middlewares: [], + _errorHandlers: [], use(pathOrMiddleware, maybeMiddleware) { let pathPrefix = "/"; @@ -359,6 +463,14 @@ export function createApp() { return this; }, + onError(handler) { + if (typeof handler !== "function") { + throw new TypeError("Error handler must be a function"); + } + this._errorHandlers.push(handler); + return this; + }, + get: undefined, post: undefined, put: undefined, @@ -411,7 +523,7 @@ export function createApp() { compiledMiddlewares, normalizedOptions.opt, ); - const dispatcher = createDispatcher(compiledRoutes, runtimeOptimizer); + const dispatcher = createDispatcher(compiledRoutes, runtimeOptimizer, this._errorHandlers); const handle = native.startServer(JSON.stringify(manifest), dispatcher, { host: normalizedOptions.host, port: normalizedOptions.port, diff --git a/src/validate.js b/src/validate.js new file mode 100644 index 0000000..e6b26b5 --- /dev/null +++ b/src/validate.js @@ -0,0 +1,152 @@ +/** + * http-native Validation Middleware + * + * Schema-agnostic: works with Zod, TypeBox, Yup, Joi, or any object with .parse() + * + * Usage: + * import { validate } from "http-native/validate"; + * import { z } from "zod"; + * + * app.post("/users", validate({ + * body: z.object({ name: z.string(), email: z.string().email() }), + * }), async (req, res) => { + * const { name, email } = req.validatedBody; + * res.json({ ok: true, name, email }); + * }); + */ + +/** + * @param {Object} schema + * @param {Object} [schema.body] - Schema to validate req.json() against + * @param {Object} [schema.query] - Schema to validate req.query against + * @param {Object} [schema.params] - Schema to validate req.params against + */ +export function validate(schema = {}) { + const { body: bodySchema, query: querySchema, params: paramsSchema } = schema; + + return async function validationMiddleware(req, res, next) { + try { + // Validate params + if (paramsSchema) { + const result = parseSchema(paramsSchema, req.params, "params"); + if (result.error) { + res.status(400).json({ + error: "Validation Error", + field: "params", + details: result.error, + }); + return; + } + req.validatedParams = result.value; + } + + // Validate query + if (querySchema) { + const result = parseSchema(querySchema, req.query, "query"); + if (result.error) { + res.status(400).json({ + error: "Validation Error", + field: "query", + details: result.error, + }); + return; + } + req.validatedQuery = result.value; + } + + // Validate body + if (bodySchema) { + const bodyData = req.json(); + if (bodyData === null && bodySchema) { + res.status(400).json({ + error: "Validation Error", + field: "body", + details: "Request body is required", + }); + return; + } + + const result = parseSchema(bodySchema, bodyData, "body"); + if (result.error) { + res.status(400).json({ + error: "Validation Error", + field: "body", + details: result.error, + }); + return; + } + req.validatedBody = result.value; + } + + await next(); + } catch (error) { + // JSON parse error or schema error + res.status(400).json({ + error: "Validation Error", + details: error instanceof Error ? error.message : String(error), + }); + } + }; +} + +/** + * Schema-agnostic parser. Supports: + * - Zod: schema.parse() throws ZodError + * - Zod safe: schema.safeParse() returns { success, data, error } + * - TypeBox/Ajv: schema.parse() or custom + * - Any object with .parse(data) that returns the parsed value or throws + */ +function parseSchema(schema, data, _fieldName) { + // Zod-style safeParse + if (typeof schema.safeParse === "function") { + const result = schema.safeParse(data); + if (result.success) { + return { value: result.data, error: null }; + } + // Zod error format + const details = result.error?.issues + ? result.error.issues.map((issue) => ({ + path: issue.path?.join(".") ?? "", + message: issue.message, + })) + : result.error?.message ?? "Validation failed"; + return { value: null, error: details }; + } + + // Standard .parse() — throws on error + if (typeof schema.parse === "function") { + try { + const value = schema.parse(data); + return { value, error: null }; + } catch (error) { + // Zod throws ZodError with .issues + if (error?.issues) { + const details = error.issues.map((issue) => ({ + path: issue.path?.join(".") ?? "", + message: issue.message, + })); + return { value: null, error: details }; + } + return { value: null, error: error?.message ?? "Validation failed" }; + } + } + + // Joi-style .validate() + if (typeof schema.validate === "function") { + const result = schema.validate(data); + if (result.error) { + const details = result.error.details + ? result.error.details.map((detail) => ({ + path: detail.path?.join(".") ?? "", + message: detail.message, + })) + : result.error.message ?? "Validation failed"; + return { value: null, error: details }; + } + return { value: result.value, error: null }; + } + + throw new TypeError( + "Schema must have a .parse(), .safeParse(), or .validate() method", + ); +} diff --git a/testing/better-express b/testing/better-express deleted file mode 160000 index 30254d8..0000000 --- a/testing/better-express +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 30254d89205a85577533a5b2e2d22704d640ec38