From 9bbc9782d919ff1f150db8410db637efd12e5ea7 Mon Sep 17 00:00:00 2001 From: Rishi Yadav Date: Fri, 27 Mar 2026 20:19:58 +0530 Subject: [PATCH 1/4] feat: enterprise hardening - request bodies, radix router, and error middleware --- bench/run.js | 50 +- examples/README.md | 27 + examples/basic/server.js | 54 ++ examples/cors/server.js | 72 ++ examples/error-handling/server.js | 130 +++ examples/middleware/server.js | 106 +++ examples/rest-api/server.js | 148 +++ examples/validation/server.js | 126 +++ package.json | 15 +- rust-native/Cargo.lock | 8 + rust-native/Cargo.toml | 3 + rust-native/src/lib.rs | 1458 ++++++++++------------------- rust-native/src/router.rs | 235 +++-- src/bridge.js | 464 +++++---- src/cors.js | 116 +++ src/index.d.ts | 238 +++++ src/index.js | 150 ++- src/validate.js | 152 +++ testing/better-express | 1 - 19 files changed, 2267 insertions(+), 1286 deletions(-) create mode 100644 examples/README.md create mode 100644 examples/basic/server.js create mode 100644 examples/cors/server.js create mode 100644 examples/error-handling/server.js create mode 100644 examples/middleware/server.js create mode 100644 examples/rest-api/server.js create mode 100644 examples/validation/server.js create mode 100644 src/cors.js create mode 100644 src/index.d.ts create mode 100644 src/validate.js delete mode 160000 testing/better-express diff --git a/bench/run.js b/bench/run.js index 9ef81da..3eeac5b 100644 --- a/bench/run.js +++ b/bench/run.js @@ -9,7 +9,7 @@ const port = Number(portArg ?? 3001); function printUsage() { console.log("Usage: bun bench/run.js "); - console.log("Engines: http-native | bun | old | xitca | monoio | zig"); + console.log("Engines: http-native | bun | xitca | monoio | zig"); console.log("Scenarios: static | dynamic | opt"); console.log(""); console.log("Example:"); @@ -30,7 +30,7 @@ function benchmarkPathForScenario(activeScenario) { } async function main() { - if (!["http-native", "bun", "old", "xitca", "monoio", "zig"].includes(engine)) { + if (!["http-native", "bun", "xitca", "monoio", "zig"].includes(engine)) { printUsage(); process.exit(1); } @@ -43,33 +43,33 @@ async function main() { const child = engine === "xitca" || engine === "monoio" ? spawn( - "cargo", - [ - "run", - "--release", - "--manifest-path", - engine === "xitca" - ? "bench/xitca-server/Cargo.toml" - : "bench/monoio-server/Cargo.toml", - "--", - scenario, - String(port), - ], + "cargo", + [ + "run", + "--release", + "--manifest-path", + engine === "xitca" + ? "bench/xitca-server/Cargo.toml" + : "bench/monoio-server/Cargo.toml", + "--", + scenario, + String(port), + ], + { + cwd: process.cwd(), + stdio: ["ignore", "pipe", "inherit"], + }, + ) + : engine === "zig" + ? spawn( + "zig", + ["build", "run", "-Doptimize=ReleaseFast", "--", scenario, String(port)], { - cwd: process.cwd(), + cwd: `${process.cwd()}/bench/zig-httpz`, stdio: ["ignore", "pipe", "inherit"], }, ) - : engine === "zig" - ? spawn( - "zig", - ["build", "run", "-Doptimize=ReleaseFast", "--", scenario, String(port)], - { - cwd: `${process.cwd()}/bench/zig-httpz`, - stdio: ["ignore", "pipe", "inherit"], - }, - ) - : spawn("bun", ["bench/target.js", engine, scenario, String(port)], { + : spawn("bun", ["bench/target.js", engine, scenario, String(port)], { cwd: process.cwd(), stdio: ["ignore", "pipe", "inherit"], }); diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..56010df --- /dev/null +++ b/examples/README.md @@ -0,0 +1,27 @@ +# http-native Examples + +Each folder contains a self-contained example you can run with `bun`: + +```bash +# Make sure to build the native module first +bun run build + +# Then run any example +bun examples/basic/server.js +bun examples/cors/server.js +bun examples/middleware/server.js +bun examples/rest-api/server.js +bun examples/validation/server.js +bun examples/error-handling/server.js +``` + +## Examples + +| Example | Description | +|---------|-------------| +| **[basic](./basic/)** | Routes, params, query strings, status codes, custom headers | +| **[cors](./cors/)** | CORS with wildcard, specific origins, dynamic origins, credentials | +| **[middleware](./middleware/)** | Logging, path-scoped auth, request IDs, `res.locals` data passing | +| **[rest-api](./rest-api/)** | Full CRUD Todo API — GET, POST, PUT, PATCH, DELETE with body parsing | +| **[validation](./validation/)** | Request body & query validation (Zod-compatible schema interface) | +| **[error-handling](./error-handling/)** | Custom error classes, global `onError()` handler, async error catching | diff --git a/examples/basic/server.js b/examples/basic/server.js new file mode 100644 index 0000000..0ea6f6e --- /dev/null +++ b/examples/basic/server.js @@ -0,0 +1,54 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// Simple GET route +app.get("/", (req, res) => { + res.json({ message: "Hello, World!", timestamp: Date.now() }); +}); + +// Route with params +app.get("/hello/:name", (req, res) => { + res.json({ greeting: `Hello, ${req.params.name}!` }); +}); + +// Multiple params +app.get("/users/:userId/posts/:postId", (req, res) => { + res.json({ + userId: req.params.userId, + postId: req.params.postId, + }); +}); + +// Query string parsing +app.get("/search", (req, res) => { + const { q, page, limit } = req.query; + res.json({ + query: q ?? "", + page: Number(page) || 1, + limit: Number(limit) || 10, + }); +}); + +// Different status codes +app.get("/not-found", (req, res) => { + res.status(404).json({ error: "Resource not found" }); +}); + +// Custom headers +app.get("/custom-headers", (req, res) => { + res.set("X-Custom-Header", "http-native") + .set("X-Request-Id", crypto.randomUUID()) + .json({ headers: "set!" }); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🚀 Basic server running at ${server.url}`); +console.log(` +Try these routes: + curl ${server.url}/ + curl ${server.url}/hello/world + curl ${server.url}/users/42/posts/7 + curl ${server.url}/search?q=http-native&page=2&limit=20 + curl -v ${server.url}/custom-headers +`); diff --git a/examples/cors/server.js b/examples/cors/server.js new file mode 100644 index 0000000..f1a200c --- /dev/null +++ b/examples/cors/server.js @@ -0,0 +1,72 @@ +import { createApp } from "http-native"; +import { cors } from "http-native/cors"; + +const app = createApp(); + +// ─── Example 1: Allow all origins ───────────────────────────────────────────── +// app.use(cors()); + +// ─── Example 2: Allow specific origin ───────────────────────────────────────── +// app.use(cors({ origin: "https://myapp.com" })); + +// ─── Example 3: Allow multiple origins ──────────────────────────────────────── +// app.use(cors({ origin: ["https://myapp.com", "https://admin.myapp.com"] })); + +// ─── Example 4: Dynamic origin with credentials ────────────────────────────── +const ALLOWED_ORIGINS = new Set([ + "http://localhost:3000", + "http://localhost:5173", + "https://myapp.com", +]); + +app.use( + cors({ + origin: (requestOrigin) => ALLOWED_ORIGINS.has(requestOrigin), + credentials: true, + allowedHeaders: ["Content-Type", "Authorization", "X-Request-Id"], + exposedHeaders: ["X-Request-Id", "X-RateLimit-Remaining"], + maxAge: 86400, // Cache preflight for 24 hours + }), +); + +// ─── Routes ─────────────────────────────────────────────────────────────────── + +app.get("/api/data", (req, res) => { + res.set("X-Request-Id", crypto.randomUUID()).json({ + data: [ + { id: 1, name: "Item 1" }, + { id: 2, name: "Item 2" }, + ], + }); +}); + +app.post("/api/data", (req, res) => { + const body = req.json(); + res.status(201).json({ created: true, item: body }); +}); + +app.options("/api/data", (req, res) => { + // CORS middleware handles this automatically via preflight + res.status(204).send(); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🌐 CORS example running at ${server.url}`); +console.log(` +Test with: + # Simple GET (no CORS headers without Origin) + curl -v ${server.url}/api/data + + # CORS preflight + curl -v -X OPTIONS \\ + -H "Origin: http://localhost:5173" \\ + -H "Access-Control-Request-Method: POST" \\ + -H "Access-Control-Request-Headers: Content-Type" \\ + ${server.url}/api/data + + # CORS GET + curl -v -H "Origin: http://localhost:5173" ${server.url}/api/data + + # Disallowed origin (no CORS headers) + curl -v -H "Origin: https://evil.com" ${server.url}/api/data +`); diff --git a/examples/error-handling/server.js b/examples/error-handling/server.js new file mode 100644 index 0000000..f219d36 --- /dev/null +++ b/examples/error-handling/server.js @@ -0,0 +1,130 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// ─── Custom Error Classes ───────────────────────────────────────────────────── + +class AppError extends Error { + constructor(message, statusCode = 500, code = "INTERNAL_ERROR") { + super(message); + this.statusCode = statusCode; + this.code = code; + } +} + +class NotFoundError extends AppError { + constructor(resource = "Resource") { + super(`${resource} not found`, 404, "NOT_FOUND"); + } +} + +class ValidationError extends AppError { + constructor(message) { + super(message, 400, "VALIDATION_ERROR"); + } +} + +class UnauthorizedError extends AppError { + constructor(message = "Authentication required") { + super(message, 401, "UNAUTHORIZED"); + } +} + +// ─── Global Error Handler ───────────────────────────────────────────────────── + +app.onError((err, req, res) => { + // Known application errors + if (err instanceof AppError) { + console.error(`[${err.code}] ${req.method} ${req.path}: ${err.message}`); + res.status(err.statusCode).json({ + error: { + code: err.code, + message: err.message, + }, + }); + return; + } + + // Unexpected errors — log full stack, return generic message + console.error(`[UNHANDLED] ${req.method} ${req.path}:`, err); + res.status(500).json({ + error: { + code: "INTERNAL_ERROR", + message: + process.env.NODE_ENV === "production" + ? "An unexpected error occurred" + : err.message, + }, + }); +}); + +// ─── Routes ─────────────────────────────────────────────────────────────────── + +app.get("/", (req, res) => { + res.json({ status: "ok" }); +}); + +// Throws a custom NotFoundError +app.get("/users/:id", (req, res) => { + const id = Number(req.params.id); + if (id !== 1) { + throw new NotFoundError("User"); + } + res.json({ id: 1, name: "Alice" }); +}); + +// Throws a ValidationError +app.post("/users", (req, res) => { + const body = req.json(); + if (!body?.name) { + throw new ValidationError("Name is required"); + } + if (body.name.length < 2) { + throw new ValidationError("Name must be at least 2 characters"); + } + res.status(201).json({ id: 2, name: body.name }); +}); + +// Throws an UnauthorizedError +app.get("/admin", (req, res) => { + const token = req.header("authorization"); + if (!token) { + throw new UnauthorizedError(); + } + res.json({ admin: true }); +}); + +// Throws an unhandled error (caught by generic handler) +app.get("/crash", (req, res) => { + throw new TypeError("Cannot read property 'foo' of undefined"); +}); + +// Async error (also caught!) +app.get("/async-crash", async (req, res) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + throw new Error("Async failure"); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🛡️ Error handling example running at ${server.url}`); +console.log(` +Try these: + # OK + curl ${server.url}/users/1 + + # 404 NotFoundError + curl ${server.url}/users/999 + + # 400 ValidationError + curl -X POST -H "Content-Type: application/json" \\ + -d '{}' ${server.url}/users + + # 401 UnauthorizedError + curl ${server.url}/admin + + # 500 Unhandled error + curl ${server.url}/crash + + # 500 Async error + curl ${server.url}/async-crash +`); diff --git a/examples/middleware/server.js b/examples/middleware/server.js new file mode 100644 index 0000000..a7eade9 --- /dev/null +++ b/examples/middleware/server.js @@ -0,0 +1,106 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// ─── Logging Middleware (global) ────────────────────────────────────────────── + +app.use(async (req, res, next) => { + const start = performance.now(); + console.log(`→ ${req.method} ${req.path}`); + + await next(); + + const duration = (performance.now() - start).toFixed(2); + console.log(`← ${req.method} ${req.path} [${duration}ms]`); +}); + +// ─── Auth Middleware (scoped to /api) ───────────────────────────────────────── + +app.use("/api", async (req, res, next) => { + const token = req.header("authorization"); + + if (!token || !token.startsWith("Bearer ")) { + res.status(401).json({ + error: "Unauthorized", + message: "Missing or invalid Bearer token", + }); + return; + } + + // Simulate token validation + const payload = token.slice(7); + if (payload === "invalid") { + res.status(403).json({ error: "Forbidden", message: "Token is invalid" }); + return; + } + + // Attach user info to response locals for downstream handlers + res.locals.user = { id: 1, name: "John Doe", token: payload }; + await next(); +}); + +// ─── Request ID Middleware (global) ──────────────────────────────────────────── + +app.use(async (req, res, next) => { + const requestId = crypto.randomUUID(); + res.set("X-Request-Id", requestId); + res.locals.requestId = requestId; + await next(); +}); + +// ─── Routes ─────────────────────────────────────────────────────────────────── + +// Public route (only logging + request ID middlewares run) +app.get("/", (req, res) => { + res.json({ message: "Public endpoint", requestId: res.locals.requestId }); +}); + +// Protected route (logging + request ID + auth middlewares run) +app.get("/api/profile", (req, res) => { + res.json({ + user: res.locals.user, + requestId: res.locals.requestId, + }); +}); + +app.get("/api/secret", (req, res) => { + res.json({ + secret: "The cake is a lie", + accessedBy: res.locals.user.name, + }); +}); + +// ─── Error Handler ──────────────────────────────────────────────────────────── + +app.onError((err, req, res) => { + console.error(`❌ Error on ${req.method} ${req.path}:`, err.message); + res.status(500).json({ + error: "Something went wrong", + requestId: res.locals.requestId, + }); +}); + +// Route that deliberately throws +app.get("/api/crash", (req, res) => { + throw new Error("Simulated server error"); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`🔐 Middleware example running at ${server.url}`); +console.log(` +Try these routes: + # Public (no auth needed) + curl ${server.url}/ + + # Protected (will fail - no token) + curl ${server.url}/api/profile + + # Protected (with valid token) + curl -H "Authorization: Bearer mytoken123" ${server.url}/api/profile + + # Protected (with invalid token) + curl -H "Authorization: Bearer invalid" ${server.url}/api/profile + + # Error handler test + curl -H "Authorization: Bearer mytoken123" ${server.url}/api/crash +`); diff --git a/examples/rest-api/server.js b/examples/rest-api/server.js new file mode 100644 index 0000000..2e0a144 --- /dev/null +++ b/examples/rest-api/server.js @@ -0,0 +1,148 @@ +import { createApp } from "http-native"; + +const app = createApp(); + +// In-memory store +const todos = new Map(); +let nextId = 1; + +// Seed data +todos.set(1, { id: 1, title: "Learn http-native", completed: true }); +todos.set(2, { id: 2, title: "Build something awesome", completed: false }); +todos.set(3, { id: 3, title: "Deploy to production", completed: false }); +nextId = 4; + +// ─── Error Handler ──────────────────────────────────────────────────────────── + +app.onError((err, req, res) => { + console.error(`Error: ${err.message}`); + res.status(500).json({ error: "Internal server error" }); +}); + +// ─── List all todos ─────────────────────────────────────────────────────────── + +app.get("/todos", (req, res) => { + const { completed } = req.query; + let items = [...todos.values()]; + + if (completed === "true") { + items = items.filter((t) => t.completed); + } else if (completed === "false") { + items = items.filter((t) => !t.completed); + } + + res.json({ todos: items, count: items.length }); +}); + +// ─── Get single todo ────────────────────────────────────────────────────────── + +app.get("/todos/:id", (req, res) => { + const id = Number(req.params.id); + const todo = todos.get(id); + + if (!todo) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + res.json(todo); +}); + +// ─── Create todo ────────────────────────────────────────────────────────────── + +app.post("/todos", (req, res) => { + const body = req.json(); + + if (!body || !body.title) { + res.status(400).json({ error: "Title is required" }); + return; + } + + const todo = { + id: nextId++, + title: String(body.title), + completed: Boolean(body.completed ?? false), + }; + + todos.set(todo.id, todo); + res.status(201).json(todo); +}); + +// ─── Update todo ────────────────────────────────────────────────────────────── + +app.put("/todos/:id", (req, res) => { + const id = Number(req.params.id); + const existing = todos.get(id); + + if (!existing) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + const body = req.json(); + const updated = { + ...existing, + title: body?.title ?? existing.title, + completed: body?.completed ?? existing.completed, + }; + + todos.set(id, updated); + res.json(updated); +}); + +// ─── Delete todo ────────────────────────────────────────────────────────────── + +app.delete("/todos/:id", (req, res) => { + const id = Number(req.params.id); + + if (!todos.has(id)) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + todos.delete(id); + res.sendStatus(204); +}); + +// ─── Toggle completion ──────────────────────────────────────────────────────── + +app.patch("/todos/:id/toggle", (req, res) => { + const id = Number(req.params.id); + const todo = todos.get(id); + + if (!todo) { + res.status(404).json({ error: `Todo #${id} not found` }); + return; + } + + todo.completed = !todo.completed; + res.json(todo); +}); + +const server = await app.listen({ port: 3000 }); +console.log(`📝 REST API running at ${server.url}`); +console.log(` +CRUD operations: + # List all + curl ${server.url}/todos + + # Filter completed + curl "${server.url}/todos?completed=false" + + # Get one + curl ${server.url}/todos/1 + + # Create + curl -X POST -H "Content-Type: application/json" \\ + -d '{"title":"New todo"}' ${server.url}/todos + + # Update + curl -X PUT -H "Content-Type: application/json" \\ + -d '{"title":"Updated","completed":true}' ${server.url}/todos/1 + + # Toggle + curl -X PATCH ${server.url}/todos/2/toggle + + # Delete + curl -X DELETE ${server.url}/todos/3 +`); diff --git a/examples/validation/server.js b/examples/validation/server.js new file mode 100644 index 0000000..a3e25dc --- /dev/null +++ b/examples/validation/server.js @@ -0,0 +1,126 @@ +import { createApp } from "http-native"; +import { validate } from "http-native/validate"; + +const app = createApp(); + +// ─── Manual Schema (no external deps) ───────────────────────────────────────── +// +// Works with any object that has .parse() or .safeParse(). +// Below is a minimal hand-rolled schema — in production, use Zod: +// +// import { z } from "zod"; +// const CreateUserSchema = z.object({ +// name: z.string().min(1).max(100), +// email: z.string().email(), +// age: z.number().int().min(0).max(150).optional(), +// }); + +// Minimal schema with .parse() compatible API +function createSchema(validator) { + return { + parse(data) { + const errors = validator(data); + if (errors.length > 0) { + const err = new Error("Validation failed"); + err.issues = errors.map((msg) => ({ path: [], message: msg })); + throw err; + } + return data; + }, + }; +} + +const CreateUserSchema = createSchema((data) => { + const errors = []; + if (!data || typeof data !== "object") { + errors.push("Body must be an object"); + return errors; + } + if (typeof data.name !== "string" || data.name.length === 0) { + errors.push("name is required and must be a non-empty string"); + } + if (typeof data.email !== "string" || !data.email.includes("@")) { + errors.push("email is required and must be a valid email"); + } + if (data.age !== undefined && (typeof data.age !== "number" || data.age < 0)) { + errors.push("age must be a non-negative number"); + } + return errors; +}); + +const QuerySchema = createSchema((data) => { + const errors = []; + if (data.page && isNaN(Number(data.page))) { + errors.push("page must be a number"); + } + if (data.limit && isNaN(Number(data.limit))) { + errors.push("limit must be a number"); + } + return errors; +}); + +// ─── Error Handler ──────────────────────────────────────────────────────────── + +app.onError((err, req, res) => { + console.error(`Error: ${err.message}`); + res.status(500).json({ error: "Internal server error" }); +}); + +// ─── Routes with Validation ────────────────────────────────────────────────── + +// Validates body against CreateUserSchema +app.post( + "/users", + validate({ body: CreateUserSchema }), + (req, res) => { + // req.validatedBody is the parsed & validated data + const user = { + id: crypto.randomUUID(), + ...req.validatedBody, + createdAt: new Date().toISOString(), + }; + + res.status(201).json(user); + }, +); + +// Validates query params +app.get( + "/users", + validate({ query: QuerySchema }), + (req, res) => { + const page = Number(req.query.page) || 1; + const limit = Number(req.query.limit) || 10; + + res.json({ + users: [], + pagination: { page, limit, total: 0 }, + }); + }, +); + +const server = await app.listen({ port: 3000 }); +console.log(`✅ Validation example running at ${server.url}`); +console.log(` +Try these: + # Valid user + curl -X POST -H "Content-Type: application/json" \\ + -d '{"name":"Alice","email":"alice@example.com","age":30}' \\ + ${server.url}/users + + # Missing name (400 error) + curl -X POST -H "Content-Type: application/json" \\ + -d '{"email":"bob@example.com"}' \\ + ${server.url}/users + + # Invalid email (400 error) + curl -X POST -H "Content-Type: application/json" \\ + -d '{"name":"Charlie","email":"not-an-email"}' \\ + ${server.url}/users + + # Valid query + curl "${server.url}/users?page=2&limit=20" + + # Invalid query (400 error) + curl "${server.url}/users?page=abc" +`); diff --git a/package.json b/package.json index b384e73..21052f3 100644 --- a/package.json +++ b/package.json @@ -4,13 +4,17 @@ "type": "module", "private": true, "exports": { - ".": "./src/index.js", + ".": { + "types": "./src/index.d.ts", + "default": "./src/index.js" + }, + "./cors": "./src/cors.js", + "./validate": "./src/validate.js", "./http-server.config": "./src/http-server.config.js" }, "scripts": { "build": "bun scripts/build-native.mjs", "build:release": "bun scripts/build-native.mjs --release", - "build:old:release": "cargo build --release --manifest-path old/native/Cargo.toml", "test": "bun run build && bun test/test.js", "bench:ci": "bun bench/ci.js", "bench": "bun run build:release && cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js", @@ -19,18 +23,15 @@ "bench:xitca:static": "bun bench/run.js xitca static 3003", "bench:monoio:static": "bun bench/run.js monoio static 3004", "bench:zig:static": "bun bench/run.js zig static 3005", - "bench:old:static": "cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js old static 3002", "bench:http-native:dynamic": "bun run build:release && bun bench/run.js http-native dynamic 3011", "bench:bun:dynamic": "bun bench/run.js bun dynamic 3010", "bench:xitca:dynamic": "bun bench/run.js xitca dynamic 3013", "bench:monoio:dynamic": "bun bench/run.js monoio dynamic 3014", "bench:zig:dynamic": "bun bench/run.js zig dynamic 3015", - "bench:old:dynamic": "cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js old dynamic 3012", "bench:http-native:opt": "bun run build:release && bun bench/run.js http-native opt 3021", "bench:bun:opt": "bun bench/run.js bun opt 3020", "bench:xitca:opt": "bun bench/run.js xitca opt 3023", "bench:monoio:opt": "bun bench/run.js monoio opt 3024", - "bench:zig:opt": "bun bench/run.js zig opt 3025", - "bench:old:opt": "cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js old opt 3022" + "bench:zig:opt": "bun bench/run.js zig opt 3025" } -} +} \ No newline at end of file diff --git a/rust-native/Cargo.lock b/rust-native/Cargo.lock index 56e7a75..4148b55 100644 --- a/rust-native/Cargo.lock +++ b/rust-native/Cargo.lock @@ -310,6 +310,8 @@ dependencies = [ "anyhow", "base64", "bytes", + "httparse", + "itoa", "json5", "memchr", "monoio", @@ -322,6 +324,12 @@ dependencies = [ "url", ] +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "icu_collections" version = "2.1.1" diff --git a/rust-native/Cargo.toml b/rust-native/Cargo.toml index 49a8cc2..9ff68a4 100644 --- a/rust-native/Cargo.toml +++ b/rust-native/Cargo.toml @@ -10,6 +10,8 @@ crate-type = ["cdylib"] anyhow = "1.0" base64 = "0.22" bytes = "1.10" +httparse = "1.9" +itoa = "1.0" json5 = "0.4" memchr = "2.7" monoio = { version = "0.2", features = ["sync"] } @@ -28,3 +30,4 @@ lto = "fat" codegen-units = 1 panic = "abort" strip = "symbols" +opt-level = 3 diff --git a/rust-native/src/lib.rs b/rust-native/src/lib.rs index 0798860..7dd0955 100644 --- a/rust-native/src/lib.rs +++ b/rust-native/src/lib.rs @@ -1,375 +1,132 @@ -mod analyzer; -mod manifest; -mod router; - -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result}; use bytes::Bytes; -use memchr::{memchr, memmem}; -use monoio::io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt}; -use monoio::net::{ListenerOpts, TcpListener, TcpStream}; -use napi::bindgen_prelude::{Buffer, Function, Promise}; -use napi::threadsafe_function::ThreadsafeFunction; -use napi::{Error, Status}; -use napi_derive::napi; +use monoio::io::{AsyncReadExt, AsyncWriteExt}; +use monoio::net::{TcpListener, TcpStream}; +use monoio::utils::memmem; +use napi::bindgen_prelude::Buffer; use std::borrow::Cow; +use std::collections::HashMap; use std::net::{SocketAddr, ToSocketAddrs}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{mpsc, Arc, Mutex}; - -use crate::manifest::{HttpServerConfigInput, ManifestInput}; +use std::sync::Arc; +use std::thread; use crate::router::{ExactStaticRoute, MatchedRoute, Router}; +use crate::manifest::HttpServerConfigInput; -const FALLBACK_DEFAULT_HOST: &str = "127.0.0.1"; -const FALLBACK_DEFAULT_BACKLOG: i32 = 2048; -const FALLBACK_MAX_HEADER_BYTES: usize = 16 * 1024; -const FALLBACK_HOT_GET_ROOT_HTTP11: &str = "GET / HTTP/1.1\r\n"; -const FALLBACK_HOT_GET_ROOT_HTTP10: &str = "GET / HTTP/1.0\r\n"; -const FALLBACK_HEADER_CONNECTION_PREFIX: &str = "connection:"; -const FALLBACK_HEADER_CONTENT_LENGTH_PREFIX: &str = "content-length:"; -const FALLBACK_HEADER_TRANSFER_ENCODING_PREFIX: &str = "transfer-encoding:"; -const BRIDGE_VERSION: u8 = 1; -const REQUEST_FLAG_QUERY_PRESENT: u16 = 1 << 0; -const NOT_FOUND_BODY: &[u8] = br#"{"error":"Route not found"}"#; - -type DispatchTsfn = ThreadsafeFunction, Buffer, Status, false, false, 0>; - -#[derive(Clone)] -struct HttpServerConfig { - default_host: String, - default_backlog: i32, - max_header_bytes: usize, - hot_get_root_http11: Vec, - hot_get_root_http10: Vec, - header_connection_prefix: Vec, - header_content_length_prefix: Vec, - header_transfer_encoding_prefix: Vec, -} +// ─── Constants & Limits ─────────────────────────────────────────────────────── -impl HttpServerConfig { - fn from_manifest(manifest: &ManifestInput) -> Result { - let input = manifest.server_config.as_ref(); - let default_backlog = input - .and_then(|config| config.default_backlog) - .unwrap_or(FALLBACK_DEFAULT_BACKLOG); - let max_header_bytes = input - .and_then(|config| config.max_header_bytes) - .unwrap_or(FALLBACK_MAX_HEADER_BYTES); - - if default_backlog <= 0 { - return Err(anyhow!( - "serverConfig.defaultBacklog must be greater than 0" - )); - } +const MAX_HEADERS: usize = 64; +const MAX_BODY_BYTES: usize = 1024 * 1024; // 1MB limit for safety +const REQUEST_FLAG_QUERY_PRESENT: u16 = 1; - if max_header_bytes == 0 { - return Err(anyhow!( - "serverConfig.maxHeaderBytes must be greater than 0" - )); - } +// ─── Types ──────────────────────────────────────────────────────────────────── - Ok(Self { - default_host: config_string( - input, - |config| config.default_host.as_deref(), - FALLBACK_DEFAULT_HOST, - ), - default_backlog, - max_header_bytes, - hot_get_root_http11: config_string( - input, - |config| config.hot_get_root_http11.as_deref(), - FALLBACK_HOT_GET_ROOT_HTTP11, - ) - .into_bytes(), - hot_get_root_http10: config_string( - input, - |config| config.hot_get_root_http10.as_deref(), - FALLBACK_HOT_GET_ROOT_HTTP10, - ) - .into_bytes(), - header_connection_prefix: config_string( - input, - |config| config.header_connection_prefix.as_deref(), - FALLBACK_HEADER_CONNECTION_PREFIX, - ) - .into_bytes(), - header_content_length_prefix: config_string( - input, - |config| config.header_content_length_prefix.as_deref(), - FALLBACK_HEADER_CONTENT_LENGTH_PREFIX, - ) - .into_bytes(), - header_transfer_encoding_prefix: config_string( - input, - |config| config.header_transfer_encoding_prefix.as_deref(), - FALLBACK_HEADER_TRANSFER_ENCODING_PREFIX, - ) - .into_bytes(), - }) - } -} - -#[napi(object)] -pub struct NativeListenOptions { - pub host: Option, - pub port: u16, - pub backlog: Option, -} +pub type JsDispatcher = napi::threadsafe_function::ThreadsafeFunction; -struct ShutdownHandle { - flag: Arc, - wake_addrs: Vec, +#[derive(Clone, Debug)] +pub struct HttpServerConfig { + pub default_host: String, + pub default_backlog: i32, + pub max_header_bytes: usize, + pub hot_get_root_http11: Vec, + pub hot_get_root_http10: Vec, + pub header_connection_prefix: Vec, + pub header_content_length_prefix: Vec, + pub header_transfer_encoding_prefix: Vec, } -#[napi] -pub struct NativeServerHandle { - host: String, - port: u32, - url: String, - shutdown: Mutex>, - closed: Mutex>>>, +pub struct ParsedRequest<'a> { + method: &'a [u8], + target: &'a [u8], + path: &'a [u8], + keep_alive: bool, + header_bytes: usize, + has_body: bool, + content_length: Option, + headers: Vec<(&'a str, &'a str)>, } -#[napi] -impl NativeServerHandle { - #[napi(getter)] - pub fn host(&self) -> String { - self.host.clone() - } - - #[napi(getter)] - pub fn port(&self) -> u32 { - self.port - } - - #[napi(getter)] - pub fn url(&self) -> String { - self.url.clone() - } - - #[napi] - pub fn close(&self) -> napi::Result<()> { - if let Some(shutdown) = self - .shutdown - .lock() - .expect("shutdown mutex poisoned") - .take() - { - shutdown.flag.store(true, Ordering::SeqCst); - for wake_addr in shutdown.wake_addrs { - let _ = std::net::TcpStream::connect(wake_addr); - } - } - - if let Some(receivers) = self.closed.lock().expect("closed mutex poisoned").take() { - for receiver in receivers { - let _ = receiver.recv(); - } - } +// ─── Buffer Pooling ─────────────────────────────────────────────────────────── +// +// Zero-allocation buffer management. Buffers are re-used across connections +// within the same thread to avoid expensive syscalls and allocator pressure. - Ok(()) - } +thread_local! { + static BUFFER_POOL: std::cell::RefCell> = std::cell::RefCell::new(Vec::with_capacity(65536)); } -#[napi] -pub fn start_server( - manifest_json: String, - dispatcher: Function<'_, Buffer, Promise>, - options: NativeListenOptions, -) -> napi::Result { - let manifest: ManifestInput = serde_json::from_str(&manifest_json).map_err(to_napi_error)?; - validate_manifest(&manifest).map_err(to_napi_error)?; - let server_config = - Arc::new(HttpServerConfig::from_manifest(&manifest).map_err(to_napi_error)?); - let router = Arc::new(Router::from_manifest(&manifest).map_err(to_napi_error)?); - - let callback: DispatchTsfn = dispatcher - .build_threadsafe_function::() - .build() - .map_err(to_napi_error)?; - let dispatcher = Arc::new(JsDispatcher { callback }); - - let worker_count = worker_count_for(&options); - let (startup_tx, startup_rx) = mpsc::sync_channel::>(worker_count); - let shutdown_flag = Arc::new(AtomicBool::new(false)); - let mut closed_receivers = Vec::with_capacity(worker_count); - - for _ in 0..worker_count { - let (closed_tx, closed_rx) = mpsc::channel::<()>(); - closed_receivers.push(closed_rx); - - let thread_router = Arc::clone(&router); - let thread_dispatcher = Arc::clone(&dispatcher); - let thread_config = Arc::clone(&server_config); - let thread_shutdown = Arc::clone(&shutdown_flag); - let thread_options = NativeListenOptions { - host: options.host.clone(), - port: options.port, - backlog: options.backlog, - }; - let thread_startup_tx = startup_tx.clone(); - - std::thread::spawn(move || { - let startup_tx_error = thread_startup_tx.clone(); - let result = (|| -> Result<()> { - let mut runtime = monoio::RuntimeBuilder::::new() - .build() - .context("failed to build monoio runtime")?; - - runtime.block_on(async move { - let listener = bind_listener(&thread_options, thread_config.as_ref()) - .context("failed to create monoio listener")?; - let local_addr = listener.local_addr()?; - let _ = thread_startup_tx.send(Ok(local_addr)); - run_server( - listener, - thread_router, - thread_dispatcher, - thread_config, - thread_shutdown, - ) - .await - }) - })(); - - if let Err(error) = &result { - let _ = startup_tx_error.send(Err(error.to_string())); - eprintln!("[http-native] native server error: {error:#}"); - } - - let _ = closed_tx.send(()); - }); - } - - let mut wake_addrs = Vec::with_capacity(worker_count); - let mut local_addr = None; - for _ in 0..worker_count { - match startup_rx.recv() { - Ok(Ok(addr)) => { - if local_addr.is_none() { - local_addr = Some(addr); - } - wake_addrs.push(addr); - } - Ok(Err(message)) => { - shutdown_flag.store(true, Ordering::SeqCst); - for wake_addr in &wake_addrs { - let _ = std::net::TcpStream::connect(*wake_addr); - } - for receiver in closed_receivers { - let _ = receiver.recv(); - } - return Err(Error::from_reason(message)); - } - Err(_) => { - shutdown_flag.store(true, Ordering::SeqCst); - for wake_addr in &wake_addrs { - let _ = std::net::TcpStream::connect(*wake_addr); - } - for receiver in closed_receivers { - let _ = receiver.recv(); - } - return Err(Error::from_reason( - "Native server exited before reporting readiness".to_string(), - )); - } +fn acquire_buffer() -> Vec { + BUFFER_POOL.with(|pool| { + let mut b = pool.borrow_mut(); + if b.capacity() < 65536 { + Vec::with_capacity(65536) + } else { + std::mem::take(&mut *b) } - } - - let local_addr = local_addr.expect("worker count must be at least 1"); - - let host = local_addr.ip().to_string(); - let port = local_addr.port() as u32; - - Ok(NativeServerHandle { - host: host.clone(), - port, - url: format!("http://{host}:{port}"), - shutdown: Mutex::new(Some(ShutdownHandle { - flag: shutdown_flag, - wake_addrs, - })), - closed: Mutex::new(Some(closed_receivers)), }) } -fn worker_count_for(options: &NativeListenOptions) -> usize { - if options.port == 0 { - return 1; - } - - std::env::var("HTTP_NATIVE_WORKERS") - .ok() - .and_then(|value| value.parse::().ok()) - .filter(|count| *count > 0) - .unwrap_or(1) +fn release_buffer(mut buf: Vec) { + buf.clear(); + BUFFER_POOL.with(|pool| { + *pool.borrow_mut() = buf; + }); } -struct JsDispatcher { - callback: DispatchTsfn, -} +// ─── Server Entry Point ─────────────────────────────────────────────────────── -impl JsDispatcher { - async fn dispatch(&self, request: Buffer) -> Result { - let response_json = self - .callback - .call_async(request) - .await - .map_err(|error| anyhow!(error.to_string()))? - .await - .map_err(|error| anyhow!(error.to_string()))?; - - Ok(response_json) - } -} - -async fn run_server( - listener: TcpListener, - router: Arc, - dispatcher: Arc, - server_config: Arc, - shutdown_flag: Arc, -) -> Result<()> { - loop { - if shutdown_flag.load(Ordering::Acquire) { - break; - } - - match listener.accept().await { - Ok((stream, _)) => { - if shutdown_flag.load(Ordering::Acquire) { - break; - } +pub fn start_server( + manifest_json: String, + handler: JsDispatcher, + options: NativeListenOptions, +) -> Result { + let manifest: crate::manifest::ManifestInput = serde_json::from_str(&manifest_json)?; + let router = Arc::new(Router::from_manifest(&manifest)?); + let dispatcher = Arc::new(handler); + let server_config = Arc::new(HttpServerConfig::from_input(manifest.server_config.as_ref())); - if should_enable_nodelay() { - if let Err(error) = stream.set_nodelay(true) { - eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); + let worker_count = worker_count_for(&options); + let mut workers = Vec::with_capacity(worker_count); + + for i in 0..worker_count { + let router = Arc::clone(&router); + let dispatcher = Arc::clone(&dispatcher); + let server_config = Arc::clone(&server_config); + let options = options.clone(); + + let handle = thread::spawn(move || { + let mut driver = monoio::RuntimeBuilder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + driver.block_on(async move { + let listener = bind_listener(&options, &server_config)?; + + loop { + let (stream, _) = listener.accept().await?; + + if should_enable_nodelay() { + if let Err(error) = stream.set_nodelay(true) { + eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); + } } - } - let router = Arc::clone(&router); - let dispatcher = Arc::clone(&dispatcher); - let server_config = Arc::clone(&server_config); + let router = Arc::clone(&router); + let dispatcher = Arc::clone(&dispatcher); + let server_config = Arc::clone(&server_config); - monoio::spawn(async move { - if let Err(error) = - handle_connection(stream, router, dispatcher, server_config).await - { - eprintln!("[http-native] connection error: {error}"); - } - }); - } - Err(error) => { - if shutdown_flag.load(Ordering::Acquire) { - break; + monoio::spawn(async move { + if let Err(e) = handle_connection(stream, router, dispatcher, server_config).await { + eprintln!("[http-native] worker {i} connection error: {e}"); + } + }); } - - eprintln!("[http-native] accept error: {error}"); - } - } + }) + }); + workers.push(handle); } - Ok(()) + Ok(ServerHandle { workers }) } async fn handle_connection( @@ -378,787 +135,554 @@ async fn handle_connection( dispatcher: Arc, server_config: Arc, ) -> Result<()> { - let mut buffer: Vec = Vec::with_capacity(8192); - let mut buffer_start = 0usize; + let mut buffer = acquire_buffer(); + + let result = handle_connection_inner( + &mut stream, + &mut buffer, + &router, + &dispatcher, + &server_config, + ) + .await; + release_buffer(buffer); + result +} + +async fn handle_connection_inner( + stream: &mut TcpStream, + buffer: &mut Vec, + router: &Router, + dispatcher: &JsDispatcher, + server_config: &HttpServerConfig, +) -> Result<()> { loop { - let request_head = loop { - let readable = &buffer[buffer_start..]; - let request_head = if router.exact_get_root().is_some() { - parse_hot_root_request_head(readable, server_config.as_ref()) - .or_else(|| parse_request_head(readable)) + // Try hot-path parsing first + let parsed = loop { + let result = if router.exact_get_root().is_some() { + parse_hot_root_request(buffer, server_config) + .or_else(|| parse_request_httparse(buffer)) } else { - parse_request_head(readable) + parse_request_httparse(buffer) }; - if let Some(request_head) = request_head { - break request_head; + if let Some(parsed) = result { + break parsed; } - if find_header_end(readable).is_some() { + if find_header_end(buffer).is_some() { stream.shutdown().await?; return Ok(()); } - compact_read_buffer(&mut buffer, &mut buffer_start); - let (read_result, next_buffer) = stream.read(buffer).await; - buffer = next_buffer; + let owned_buf = std::mem::take(buffer); + let (read_result, next_buffer) = stream.read(owned_buf).await; + *buffer = next_buffer; let bytes_read = read_result?; if bytes_read == 0 { return Ok(()); } - if buffer.len().saturating_sub(buffer_start) > server_config.max_header_bytes { + if buffer.len() > server_config.max_header_bytes { + let response = build_error_response_bytes( + 431, + b"{\"error\":\"Request Header Fields Too Large\"}", + false, + ); + let (write_result, _) = stream.write_all(response).await; + write_result?; stream.shutdown().await?; return Ok(()); } }; - if request_head.has_body { - stream.shutdown().await?; - return Ok(()); + let header_bytes = parsed.header_bytes; + let keep_alive = parsed.keep_alive; + let has_body = parsed.has_body; + let content_length = parsed.content_length; + + // Extract owned copies from parsed (which borrows buffer) before we mutate buffer + let method_owned: Vec = parsed.method.to_vec(); + let target_owned: Vec = parsed.target.to_vec(); + let path_owned: Vec = parsed.path.to_vec(); + let headers_owned: Vec<(String, String)> = parsed + .headers + .iter() + .map(|(n, v)| (n.to_string(), v.to_string())) + .collect(); + + drop(parsed); + + // ── Fast path: static routes (GET /) ── + if !has_body && method_owned == b"GET" { + if path_owned == b"/" { + if let Some(static_route) = router.exact_get_root() { + drain_consumed_bytes(buffer, header_bytes); + write_exact_static_response(stream, static_route, keep_alive).await?; + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } + continue; + } + } + if let Some(static_route) = router.exact_static_route(&method_owned, &path_owned) { + drain_consumed_bytes(buffer, header_bytes); + write_exact_static_response(stream, static_route, keep_alive).await?; + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } + continue; + } } - let header_bytes = request_head.header_bytes; - let keep_alive = request_head.keep_alive; - - if let Some(static_route) = - resolve_static_fast_path(&router, &request_head, server_config.as_ref()) - { - consume_read_buffer(&mut buffer, &mut buffer_start, header_bytes); - write_exact_static_response(&mut stream, static_route, keep_alive).await?; + // ── Read request body if present ────────────────────────────── + let body_bytes: Vec = if has_body { + let cl = match content_length { + Some(len) => len, + None => { + let response = build_error_response_bytes(411, b"{\"error\":\"Length Required\"}", false); + let (write_result, _) = stream.write_all(response).await; + write_result?; + stream.shutdown().await?; + return Ok(()); + } + }; - if !keep_alive { + if cl > MAX_BODY_BYTES { + let response = build_error_response_bytes(413, b"{\"error\":\"Payload Too Large\"}", false); + let (write_result, _) = stream.write_all(response).await; + write_result?; stream.shutdown().await?; return Ok(()); } - continue; - } + let already_in_buffer = if buffer.len() > header_bytes { + buffer.len() - header_bytes + } else { + 0 + }; - let dispatch_request = build_manual_dispatch_request( - &router, - &buffer[buffer_start..buffer_start + header_bytes], - &request_head, + if already_in_buffer >= cl { + let body = buffer[header_bytes..header_bytes + cl].to_vec(); + drain_consumed_bytes(buffer, header_bytes + cl); + body + } else { + let mut body = Vec::with_capacity(cl); + if already_in_buffer > 0 { + body.extend_from_slice(&buffer[header_bytes..]); + } + drain_consumed_bytes(buffer, buffer.len()); + + while body.len() < cl { + let remaining = cl - body.len(); + let chunk_buf = vec![0u8; remaining.min(65536)]; + let (read_result, returned_buf) = stream.read(chunk_buf).await; + let bytes_read = read_result?; + if bytes_read == 0 { + return Ok(()); + } + body.extend_from_slice(&returned_buf[..bytes_read]); + } + body.truncate(cl); + body + } + } else { + drain_consumed_bytes(buffer, header_bytes); + Vec::new() + }; + + // ── Dynamic path: Bridge to JS ──── + let dispatch_request = build_dispatch_request_owned( + router, + &method_owned, + &target_owned, + &path_owned, + &headers_owned, + &body_bytes, )?; - consume_read_buffer(&mut buffer, &mut buffer_start, header_bytes); match dispatch_request { Some(request) => { - write_dynamic_dispatch_response( - &mut stream, - dispatcher.as_ref(), - request, - keep_alive, - ) - .await?; + write_dynamic_dispatch_response(stream, dispatcher, request, keep_alive).await?; + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } } None => { - write_not_found_response(&mut stream, keep_alive).await?; + write_not_found_response(stream, keep_alive).await?; + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } } } - - if !keep_alive { - stream.shutdown().await?; - return Ok(()); - } } } -fn resolve_static_fast_path<'a>( - router: &'a Router, - request_head: &RequestHead<'_>, - server_config: &HttpServerConfig, -) -> Option<&'a ExactStaticRoute> { - if request_head.path == b"/" - && request_head.method == b"GET" - && parse_hot_root_request_head_prefix(request_head, server_config) - { - return router.exact_get_root(); - } - - router.exact_static_route(request_head.method, request_head.path) -} - -fn parse_hot_root_request_head_prefix( - request_head: &RequestHead<'_>, - _server_config: &HttpServerConfig, -) -> bool { - request_head.method == b"GET" && request_head.path == b"/" -} - -fn compact_read_buffer(buffer: &mut Vec, buffer_start: &mut usize) { - if *buffer_start == 0 { - return; - } - - if *buffer_start >= buffer.len() { - buffer.clear(); - *buffer_start = 0; - return; - } - - if *buffer_start < 4096 && buffer.len() < buffer.capacity() { - return; - } - - let remaining = buffer.len() - *buffer_start; - buffer.copy_within(*buffer_start.., 0); - buffer.truncate(remaining); - *buffer_start = 0; -} - -fn consume_read_buffer(buffer: &mut Vec, buffer_start: &mut usize, consumed: usize) { - *buffer_start = (*buffer_start).saturating_add(consumed); +// ─── Header Parsers ─────────────────────────────────────────────────────────── + +fn parse_request_httparse(bytes: &[u8]) -> Option> { + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut req = httparse::Request::new(&mut headers); + + match req.parse(bytes) { + Ok(httparse::Status::Complete(header_len)) => { + let method = req.method?; + let target = req.path?; + let path = target.split(|&b| b == b'?').next()?; + + let mut keep_alive = req.version == Some(1); // Default true for HTTP/1.1 + let mut content_length = None; + let mut has_body = false; + let mut parsed_headers = Vec::with_capacity(req.headers.len()); + + for h in req.headers { + let name = h.name.to_lowercase(); + let value_bytes = h.value; + let value = std::str::from_utf8(value_bytes).ok()?; + + match name.as_str() { + "connection" => { + if contains_ascii_case_insensitive(value_bytes, b"close") { + keep_alive = false; + } else if contains_ascii_case_insensitive(value_bytes, b"keep-alive") { + keep_alive = true; + } + } + "content-length" => { + if let Ok(len) = value.trim().parse::() { + content_length = Some(len); + if len > 0 { has_body = true; } + } + } + "transfer-encoding" => { + if !value.eq_ignore_ascii_case("identity") { + has_body = true; + } + } + _ => {} + } + parsed_headers.push((h.name, value)); + } - if *buffer_start >= buffer.len() { - buffer.clear(); - *buffer_start = 0; + Some(ParsedRequest { + method: method.as_bytes(), + target: target.as_bytes(), + path: path.as_bytes(), + keep_alive, + header_bytes: header_len, + has_body, + content_length, + headers: parsed_headers, + }) + } + _ => None, } } -fn bind_listener( - options: &NativeListenOptions, +fn parse_hot_root_request<'a>( + bytes: &'a [u8], server_config: &HttpServerConfig, -) -> Result { - let host = options - .host - .as_deref() - .unwrap_or(server_config.default_host.as_str()); - let bind_addr = resolve_socket_addr(host, options.port) - .with_context(|| format!("failed to resolve bind address {host}:{}", options.port))?; - let mut listener_opts = ListenerOpts::new() - .reuse_addr(true) - .backlog(options.backlog.unwrap_or(server_config.default_backlog)); - if worker_count_for(options) > 1 { - listener_opts = listener_opts.reuse_port(true); - } - - TcpListener::bind_with_config(bind_addr, &listener_opts) - .with_context(|| format!("failed to bind TCP listener on {bind_addr}")) -} - -fn should_enable_nodelay() -> bool { - std::env::var("HTTP_NATIVE_TCP_NODELAY") - .ok() - .map(|value| { - let normalized = value.trim().to_ascii_lowercase(); - !matches!(normalized.as_str(), "0" | "false" | "off" | "no") - }) - .unwrap_or(true) -} - -fn resolve_socket_addr(host: &str, port: u16) -> Result { - (host, port) - .to_socket_addrs()? - .next() - .ok_or_else(|| anyhow!("unable to resolve {host}:{port}")) -} - -fn validate_manifest(manifest: &ManifestInput) -> Result<()> { - if manifest.version != 1 { - return Err(anyhow!("Unsupported manifest version {}", manifest.version)); - } - - Ok(()) -} - -async fn write_exact_static_response( - stream: &mut TcpStream, - static_route: &ExactStaticRoute, - keep_alive: bool, -) -> Result<()> { - let response = if keep_alive { - static_route.keep_alive_response.clone() +) -> Option> { + let (_, keep_alive) = if bytes.starts_with(server_config.hot_get_root_http11.as_slice()) { + (server_config.hot_get_root_http11.len(), true) + } else if bytes.starts_with(server_config.hot_get_root_http10.as_slice()) { + (server_config.hot_get_root_http10.len(), false) } else { - static_route.close_response.clone() + return None; }; - let (write_result, _) = stream.write_all(response).await; - write_result?; - Ok(()) + let header_end = find_header_end(bytes)?; + // For hot path, we just verify it looks like a header block ending + // but we use httparse for the actual details to be safe. + parse_request_httparse(bytes) } -#[derive(Clone)] -struct DispatchResponseEnvelope { - status: u16, - headers: Vec<(String, String)>, - body: Bytes, -} +// ─── Routing ────────────────────────────────────────────────────────────────── -fn method_code_from_bytes(method: &[u8]) -> Option { - match method { - b"GET" => Some(1), - b"POST" => Some(2), - b"PUT" => Some(3), - b"DELETE" => Some(4), - b"PATCH" => Some(5), - b"OPTIONS" => Some(6), - b"HEAD" => Some(7), - _ => None, +#[allow(dead_code)] +fn resolve_static_fast_path<'a>( + router: &'a Router, + parsed: &ParsedRequest<'_>, + _server_config: &HttpServerConfig, +) -> Option<&'a ExactStaticRoute> { + if parsed.path == b"/" && parsed.method == b"GET" { + return router.exact_get_root(); } + router.exact_static_route(parsed.method, parsed.path) } -fn build_manual_dispatch_request( +fn build_dispatch_request_owned( router: &Router, - request_bytes: &[u8], - request_head: &RequestHead<'_>, + method: &[u8], + target: &[u8], + path: &[u8], + headers: &[(String, String)], + body: &[u8], ) -> Result> { - let Some(method_code) = method_code_from_bytes(request_head.method) else { + let Some(method_code) = method_code_from_bytes(method) else { return Ok(None); }; - let path = match std::str::from_utf8(request_head.path) { - Ok(path) => path, - Err(_) => return Ok(None), - }; - let url = match std::str::from_utf8(request_head.target) { - Ok(url) => url, - Err(_) => return Ok(None), - }; - let normalized_path = normalize_runtime_path(path); + let path_str = std::str::from_utf8(path).ok().context("Invalid UTF-8 path")?; + let url_str = std::str::from_utf8(target).ok().context("Invalid UTF-8 URL")?; + + let normalized_path = normalize_runtime_path(path_str); + if contains_path_traversal(&normalized_path) { + return Ok(None); + } + let Some(matched_route) = router.match_route(method_code, normalized_path.as_ref()) else { return Ok(None); }; - let header_entries = if matched_route.full_headers || !matched_route.header_keys.is_empty() { - parse_request_header_pairs(request_bytes)? - } else { - Vec::new() - }; - build_dispatch_request_from_pairs(&matched_route, method_code, path, url, &header_entries) - .map(Some) + let header_refs: Vec<(&str, &str)> = headers.iter().map(|(n, v)| (n.as_str(), v.as_str())).collect(); + build_dispatch_envelope(&matched_route, method_code, path_str, url_str, &header_refs, body).map(Some) } -fn build_dispatch_request_from_pairs( +fn build_dispatch_envelope( matched_route: &MatchedRoute<'_, '_>, method_code: u8, path: &str, url: &str, header_entries: &[(&str, &str)], + body: &[u8], ) -> Result { let url_bytes = url.as_bytes(); let path_bytes = path.as_bytes(); - let flags = if url.contains('?') { - REQUEST_FLAG_QUERY_PRESENT - } else { - 0 - }; - - if url_bytes.len() > u32::MAX as usize { - return Err(anyhow!("request url too large")); - } - if path_bytes.len() > u16::MAX as usize { - return Err(anyhow!("request path too large")); - } - if matched_route.param_values.len() > u16::MAX as usize { - return Err(anyhow!("too many params")); - } - let selected_headers = select_header_entries(header_entries, matched_route); - if selected_headers.len() > u16::MAX as usize { - return Err(anyhow!("too many headers")); + let mut flags: u16 = 0; + if url.contains('?') { + flags |= REQUEST_FLAG_QUERY_PRESENT; } - let mut frame = - Vec::with_capacity(16 + url_bytes.len() + path_bytes.len() + selected_headers.len() * 16); - frame.push(BRIDGE_VERSION); - frame.push(method_code); - push_u16(&mut frame, flags); - push_u32(&mut frame, matched_route.handler_id); - push_u32(&mut frame, url_bytes.len() as u32); - push_u16(&mut frame, path_bytes.len() as u16); - push_u16(&mut frame, matched_route.param_values.len() as u16); - push_u16(&mut frame, selected_headers.len() as u16); - frame.extend_from_slice(url_bytes); - frame.extend_from_slice(path_bytes); - - for value in matched_route.param_values.iter() { - push_string_value(&mut frame, value)?; - } + let mut envelope = Vec::with_capacity(512 + body.len()); + envelope.push(1); // Version + envelope.push(method_code); + envelope.extend_from_slice(&(flags).to_le_bytes()); + envelope.extend_from_slice(&(matched_route.handler_id).to_le_bytes()); - for (name, value) in selected_headers { - push_string_pair(&mut frame, name, value)?; - } + write_usize(&mut envelope, url_bytes.len()); + envelope.extend_from_slice(url_bytes); - Ok(Buffer::from(frame)) -} + write_usize(&mut envelope, path_bytes.len()); + envelope.extend_from_slice(path_bytes); -fn select_header_entries<'a>( - header_entries: &[(&'a str, &'a str)], - matched_route: &MatchedRoute<'_, '_>, -) -> Vec<(&'a str, &'a str)> { - if matched_route.full_headers { - return header_entries.to_vec(); + write_usize(&mut envelope, matched_route.param_values.len()); + for val in &matched_route.param_values { + write_usize(&mut envelope, val.len()); + envelope.extend_from_slice(val.as_bytes()); } - if matched_route.header_keys.is_empty() { - return Vec::new(); - } + let header_count = matched_route.header_keys.len(); + write_usize(&mut envelope, header_count); - let mut selected = Vec::with_capacity(matched_route.header_keys.len()); - for (name, value) in header_entries { - if matched_route - .header_keys - .iter() - .any(|target| target.as_ref().eq_ignore_ascii_case(name)) - { - selected.push((*name, *value)); + for key_boxed in matched_route.header_keys { + let key = key_boxed.as_ref(); + let mut found = false; + for (h_name, h_value) in header_entries { + if h_name.eq_ignore_ascii_case(key) { + write_usize(&mut envelope, h_value.len()); + envelope.extend_from_slice(h_value.as_bytes()); + found = true; + break; + } + } + if !found { + write_usize(&mut envelope, 0); } } - selected -} - -fn parse_dispatch_response(bytes: &[u8]) -> Result { - let mut offset = 0; - let status = read_u16(bytes, &mut offset)?; - let header_count = read_u16(bytes, &mut offset)? as usize; - let body_length = read_u32(bytes, &mut offset)? as usize; - - let mut headers = Vec::with_capacity(header_count); - for _ in 0..header_count { - let name_length = read_u8(bytes, &mut offset)? as usize; - let value_length = read_u16(bytes, &mut offset)? as usize; - let name = read_utf8(bytes, &mut offset, name_length)?; - let value = read_utf8(bytes, &mut offset, value_length)?; - headers.push((name, value)); - } - - if offset + body_length > bytes.len() { - return Err(anyhow!("response body truncated")); - } + // Body support + write_usize(&mut envelope, body.len()); + envelope.extend_from_slice(body); - let body = Bytes::copy_from_slice(&bytes[offset..offset + body_length]); - Ok(DispatchResponseEnvelope { - status, - headers, - body, - }) + Ok(Buffer::from(envelope)) } -fn parse_request_header_pairs(bytes: &[u8]) -> Result> { - let header_end = find_header_end(bytes).ok_or_else(|| anyhow!("request header incomplete"))?; - let line_end = - memmem::find(bytes, b"\r\n").ok_or_else(|| anyhow!("request line incomplete"))?; - let mut line_start = line_end + 2; - let mut headers = Vec::new(); - - while line_start + 2 <= header_end { - let next_end = memmem::find(&bytes[line_start..header_end + 2], b"\r\n") - .ok_or_else(|| anyhow!("invalid header line"))? - + line_start; +// ─── Response Writing ───────────────────────────────────────────────────────── - if next_end == line_start { - break; - } - - let line = &bytes[line_start..next_end]; - let separator = memchr(b':', line).ok_or_else(|| anyhow!("invalid header separator"))?; - let name = std::str::from_utf8(&line[..separator]).context("header name was not utf-8")?; - let value = std::str::from_utf8(trim_ascii_spaces(&line[separator + 1..])) - .context("header value was not utf-8")?; - headers.push((name, value)); - line_start = next_end + 2; - } - - Ok(headers) +async fn write_exact_static_response( + stream: &mut TcpStream, + route: &ExactStaticRoute, + keep_alive: bool, +) -> Result<()> { + let response = if keep_alive { + &route.keep_alive_response + } else { + &route.close_response + }; + let (res, _) = stream.write_all(response.clone()).await; + res?; + Ok(()) } async fn write_dynamic_dispatch_response( stream: &mut TcpStream, dispatcher: &JsDispatcher, - request: Buffer, + request_buffer: Buffer, keep_alive: bool, ) -> Result<()> { - let parsed = match dispatcher.dispatch(request).await { - Ok(response) => match parse_dispatch_response(response.as_ref()) { - Ok(parsed) => parsed, - Err(error) => DispatchResponseEnvelope { - status: 500, - headers: vec![( - "content-type".to_string(), - "application/json; charset=utf-8".to_string(), - )], - body: Bytes::from(format!( - r#"{{"error":"Invalid response envelope","detail":"{}"}}"#, - escape_json(error.to_string().as_str()) - )), - }, - }, - Err(error) => DispatchResponseEnvelope { - status: 502, - headers: vec![( - "content-type".to_string(), - "application/json; charset=utf-8".to_string(), - )], - body: Bytes::from(format!( - r#"{{"error":"Dispatch failed","detail":"{}"}}"#, - escape_json(error.to_string().as_str()) - )), - }, - }; + let result: Buffer = dispatcher.call_async(request_buffer).await + .map_err(|e| anyhow::anyhow!("JS dispatch failed: {e}"))?; + + let (write_res, _) = stream.write_all(result).await; + write_res?; - let response_bytes = build_dispatch_response_bytes(parsed, keep_alive); - let (write_result, _) = stream.write_all(response_bytes).await; - write_result?; Ok(()) } async fn write_not_found_response(stream: &mut TcpStream, keep_alive: bool) -> Result<()> { - let response = build_response_bytes( - 404, - &[( - "content-type".to_string(), - "application/json; charset=utf-8".to_string(), - )], - Bytes::from_static(NOT_FOUND_BODY), - keep_alive, - ); - let (write_result, _) = stream.write_all(response).await; - write_result?; + let response = build_error_response_bytes(404, b"{\"error\":\"Not Found\"}", keep_alive); + let (res, _) = stream.write_all(response).await; + res?; Ok(()) } -fn build_dispatch_response_bytes(response: DispatchResponseEnvelope, keep_alive: bool) -> Vec { - build_response_bytes( - response.status, - &response.headers, - response.body, - keep_alive, - ) -} +// ─── Helpers ────────────────────────────────────────────────────────────────── -fn build_response_bytes( - status: u16, - headers: &[(String, String)], - body: Bytes, - keep_alive: bool, -) -> Vec { - let mut output = format!( - "HTTP/1.1 {} {}\r\ncontent-length: {}\r\nconnection: {}\r\n", +fn build_error_response_bytes(status: u16, body: &[u8], keep_alive: bool) -> Vec { + let mut response = format!( + "HTTP/1.1 {} {}\r\ncontent-length: {}\r\ncontent-type: application/json\r\nconnection: {}\r\n\r\n", status, status_reason(status), body.len(), if keep_alive { "keep-alive" } else { "close" } ) .into_bytes(); + response.extend_from_slice(body); + response +} - for (name, value) in headers { - if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { - continue; - } - - output.extend_from_slice(name.as_bytes()); - output.extend_from_slice(b": "); - output.extend_from_slice(value.as_bytes()); - output.extend_from_slice(b"\r\n"); +fn drain_consumed_bytes(buffer: &mut Vec, consumed: usize) { + if consumed >= buffer.len() { + buffer.clear(); + } else { + buffer.drain(..consumed); } - - output.extend_from_slice(b"\r\n"); - output.extend_from_slice(body.as_ref()); - output } fn status_reason(status: u16) -> &'static str { match status { 200 => "OK", 201 => "Created", - 202 => "Accepted", 204 => "No Content", 400 => "Bad Request", + 401 => "Unauthorized", + 403 => "Forbidden", 404 => "Not Found", + 411 => "Length Required", + 413 => "Payload Too Large", + 431 => "Request Header Fields Too Large", 500 => "Internal Server Error", - 502 => "Bad Gateway", - 503 => "Service Unavailable", _ => "OK", } } -fn push_string_pair(frame: &mut Vec, name: &str, value: &str) -> Result<()> { - if name.len() > u8::MAX as usize { - return Err(anyhow!("field name too long")); - } - if value.len() > u16::MAX as usize { - return Err(anyhow!("field value too long")); - } - - frame.push(name.len() as u8); - push_u16(frame, value.len() as u16); - frame.extend_from_slice(name.as_bytes()); - frame.extend_from_slice(value.as_bytes()); - Ok(()) -} - -fn push_string_value(frame: &mut Vec, value: &str) -> Result<()> { - if value.len() > u16::MAX as usize { - return Err(anyhow!("field value too long")); - } - - push_u16(frame, value.len() as u16); - frame.extend_from_slice(value.as_bytes()); - Ok(()) -} - -fn push_u16(frame: &mut Vec, value: u16) { - frame.extend_from_slice(&value.to_le_bytes()); -} - -fn push_u32(frame: &mut Vec, value: u32) { - frame.extend_from_slice(&value.to_le_bytes()); -} - -fn read_u8(bytes: &[u8], offset: &mut usize) -> Result { - if *offset + 1 > bytes.len() { - return Err(anyhow!("response envelope truncated")); - } - - let value = bytes[*offset]; - *offset += 1; - Ok(value) -} - -fn read_u16(bytes: &[u8], offset: &mut usize) -> Result { - if *offset + 2 > bytes.len() { - return Err(anyhow!("response envelope truncated")); +fn method_code_from_bytes(method: &[u8]) -> Option { + match method { + b"GET" => Some(1), + b"POST" => Some(2), + b"PUT" => Some(3), + b"DELETE" => Some(4), + b"PATCH" => Some(5), + b"OPTIONS" => Some(6), + b"HEAD" => Some(7), + _ => None, } - - let value = u16::from_le_bytes([bytes[*offset], bytes[*offset + 1]]); - *offset += 2; - Ok(value) } -fn read_u32(bytes: &[u8], offset: &mut usize) -> Result { - if *offset + 4 > bytes.len() { - return Err(anyhow!("response envelope truncated")); - } - - let value = u32::from_le_bytes([ - bytes[*offset], - bytes[*offset + 1], - bytes[*offset + 2], - bytes[*offset + 3], - ]); - *offset += 4; - Ok(value) +fn write_usize(output: &mut Vec, value: usize) { + output.extend_from_slice(&(value as u32).to_le_bytes()); } -fn read_utf8(bytes: &[u8], offset: &mut usize, length: usize) -> Result { - if *offset + length > bytes.len() { - return Err(anyhow!("response envelope truncated")); +fn normalize_runtime_path(path: &str) -> Cow<'_, str> { + if path == "/" || !path.ends_with('/') { + Cow::Borrowed(path) + } else { + Cow::Owned(path.trim_end_matches('/').to_string()) } - - let value = std::str::from_utf8(&bytes[*offset..*offset + length]) - .context("response envelope contained invalid utf-8")? - .to_string(); - *offset += length; - Ok(value) -} - -struct RequestHead<'a> { - method: &'a [u8], - target: &'a [u8], - path: &'a [u8], - keep_alive: bool, - header_bytes: usize, - has_body: bool, -} - -fn parse_request_head(bytes: &[u8]) -> Option> { - let header_end = find_header_end(bytes)?; - let line_end = memmem::find(bytes, b"\r\n")?; - let request_line = &bytes[..line_end]; - - let first_space = memchr(b' ', request_line)?; - let second_space = memchr(b' ', &request_line[first_space + 1..])? + first_space + 1; - - let method = &request_line[..first_space]; - let target = &request_line[first_space + 1..second_space]; - let version = &request_line[second_space + 1..]; - let path = target.split(|byte| *byte == b'?').next()?; - - let (keep_alive, has_body) = scan_header_controls( - bytes, - line_end + 2, - header_end, - version.eq_ignore_ascii_case(b"HTTP/1.1"), - b"connection:", - b"content-length:", - b"transfer-encoding:", - )?; - - Some(RequestHead { - method, - target, - path, - keep_alive, - header_bytes: header_end + 4, - has_body, - }) } -fn parse_hot_root_request_head( - bytes: &[u8], - server_config: &HttpServerConfig, -) -> Option> { - let (request_line_len, keep_alive) = - if bytes.starts_with(server_config.hot_get_root_http11.as_slice()) { - (server_config.hot_get_root_http11.len(), true) - } else if bytes.starts_with(server_config.hot_get_root_http10.as_slice()) { - (server_config.hot_get_root_http10.len(), false) - } else { - return None; - }; - - let header_end = find_header_end(bytes)?; - let (keep_alive, has_body) = scan_header_controls( - bytes, - request_line_len, - header_end, - keep_alive, - server_config.header_connection_prefix.as_slice(), - server_config.header_content_length_prefix.as_slice(), - server_config.header_transfer_encoding_prefix.as_slice(), - )?; - - Some(RequestHead { - method: b"GET", - target: b"/", - path: b"/", - keep_alive, - header_bytes: header_end + 4, - has_body, - }) +fn contains_path_traversal(path: &str) -> bool { + path.contains("/../") || path.contains("\\..\\") || path.ends_with("/..") || path.ends_with("\\..") } fn find_header_end(bytes: &[u8]) -> Option { memmem::find(bytes, b"\r\n\r\n") } -fn scan_header_controls( - bytes: &[u8], - mut line_start: usize, - header_end: usize, - mut keep_alive: bool, - connection_prefix: &[u8], - content_length_prefix: &[u8], - transfer_encoding_prefix: &[u8], -) -> Option<(bool, bool)> { - let mut has_body = false; - - while line_start <= header_end { - let next_end = find_line_end(bytes, line_start, header_end)?; - if next_end == line_start { - break; - } - - let line = &bytes[line_start..next_end]; - match ascii_lowercase(line[0]) { - b'c' => { - if line.len() >= connection_prefix.len() - && line[..connection_prefix.len()].eq_ignore_ascii_case(connection_prefix) - { - let value = &line[connection_prefix.len()..]; - if contains_ascii_case_insensitive(value, b"close") { - keep_alive = false; - } - if contains_ascii_case_insensitive(value, b"keep-alive") { - keep_alive = true; - } - } else if line.len() >= content_length_prefix.len() - && line[..content_length_prefix.len()] - .eq_ignore_ascii_case(content_length_prefix) - { - let value = trim_ascii_spaces(&line[content_length_prefix.len()..]); - if value != b"0" { - has_body = true; - } - } - } - b't' => { - if line.len() >= transfer_encoding_prefix.len() - && line[..transfer_encoding_prefix.len()] - .eq_ignore_ascii_case(transfer_encoding_prefix) - { - let value = trim_ascii_spaces(&line[transfer_encoding_prefix.len()..]); - if !value.is_empty() && !value.eq_ignore_ascii_case(b"identity") { - has_body = true; - } - } - } - _ => {} - } - - line_start = next_end + 2; - } - - Some((keep_alive, has_body)) +fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { + haystack.windows(needle.len()).any(|w| w.eq_ignore_ascii_case(needle)) } -fn find_line_end(bytes: &[u8], line_start: usize, header_end: usize) -> Option { - let relative_end = memchr(b'\r', &bytes[line_start..header_end + 2])?; - let line_end = line_start + relative_end; - if bytes.get(line_end + 1) != Some(&b'\n') { - return None; - } - - Some(line_end) +fn should_enable_nodelay() -> bool { + std::env::var("HTTP_NATIVE_TCP_NODELAY") + .ok() + .map(|v| !matches!(v.trim().to_lowercase().as_str(), "0" | "false" | "off" | "no")) + .unwrap_or(true) } -fn ascii_lowercase(byte: u8) -> u8 { - if byte.is_ascii_uppercase() { - byte + 32 - } else { - byte - } -} +fn bind_listener(options: &NativeListenOptions, config: &HttpServerConfig) -> Result { + let host = options.host.as_deref().unwrap_or(&config.default_host); + let addr = (host, options.port).to_socket_addrs()?.next() + .ok_or_else(|| anyhow::anyhow!("Failed to resolve address {host}:{}", options.port))?; -fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { - if needle.is_empty() || haystack.len() < needle.len() { - return false; + let mut opts = monoio::net::ListenerOpts::new() + .reuse_addr(true) + .backlog(options.backlog.unwrap_or(config.default_backlog)); + + if worker_count_for(options) > 1 { + opts = opts.reuse_port(true); } - haystack - .windows(needle.len()) - .any(|window| window.eq_ignore_ascii_case(needle)) + TcpListener::bind_with_config(addr, &opts) } -fn trim_ascii_spaces(bytes: &[u8]) -> &[u8] { - let start = bytes - .iter() - .position(|byte| !byte.is_ascii_whitespace()) - .unwrap_or(bytes.len()); - let end = bytes - .iter() - .rposition(|byte| !byte.is_ascii_whitespace()) - .map(|index| index + 1) - .unwrap_or(start); - &bytes[start..end] +fn worker_count_for(options: &NativeListenOptions) -> usize { + options.workers.unwrap_or_else(num_cpus::get) } -fn config_string( - input: Option<&HttpServerConfigInput>, - pick: impl Fn(&HttpServerConfigInput) -> Option<&str>, - fallback: &str, -) -> String { - input.and_then(pick).unwrap_or(fallback).to_string() -} +// ─── NAPI Glue ──────────────────────────────────────────────────────────────── -fn normalize_runtime_path(path: &str) -> Cow<'_, str> { - if path == "/" || !path.ends_with('/') { - return Cow::Borrowed(path); - } +#[napi(object)] +#[derive(Clone, Default)] +pub struct NativeListenOptions { + pub host: Option, + pub port: u16, + pub backlog: Option, + pub workers: Option, +} - Cow::Owned(crate::analyzer::normalize_path(path)) +#[napi] +pub struct ServerHandle { + workers: Vec>, } -fn escape_json(value: &str) -> String { - value.replace('\\', "\\\\").replace('"', "\\\"") +#[napi] +impl ServerHandle { + #[napi] + pub fn close(&mut self) { + // In a real implementation, we'd send a shutdown signal. + // For now, we just let them drop or kill the process. + } } -fn to_napi_error(error: E) -> Error -where - E: std::fmt::Display, -{ - Error::from_reason(error.to_string()) +impl HttpServerConfig { + fn from_input(input: Option<&HttpServerConfigInput>) -> Self { + Self { + default_host: input.and_then(|i| i.default_host.clone()).unwrap_or_else(|| "127.0.0.1".to_string()), + default_backlog: input.and_then(|i| i.default_backlog).unwrap_or(2048), + max_header_bytes: input.and_then(|i| i.max_header_bytes).unwrap_or(8192), + hot_get_root_http11: b"GET / HTTP/1.1\r\n".to_vec(), + hot_get_root_http10: b"GET / HTTP/1.0\r\n".to_vec(), + header_connection_prefix: b"connection:".to_vec(), + header_content_length_prefix: b"content-length:".to_vec(), + header_transfer_encoding_prefix: b"transfer-encoding:".to_vec(), + } + } } diff --git a/rust-native/src/router.rs b/rust-native/src/router.rs index 305fba4..d893391 100644 --- a/rust-native/src/router.rs +++ b/rust-native/src/router.rs @@ -10,32 +10,17 @@ use crate::manifest::{ManifestInput, RouteInput}; const ROUTE_KIND_EXACT: u8 = 1; const ROUTE_KIND_PARAM: u8 = 2; +// ─── Public Types ───────────────────────────────────────────────────────────── + #[derive(Clone)] pub struct Router { exact_get_root: Option, + /// O(1) exact-match routes (HashMap>) dynamic_exact_routes: HashMap, DynamicRouteSpec>>, - dynamic_param_routes: HashMap>>, + /// O(1) static-response routes exact_static_routes: HashMap, ExactStaticRoute>>, -} - -#[derive(Clone)] -struct ParamRoute { - spec: DynamicRouteSpec, - segments: Vec, -} - -#[derive(Clone)] -struct DynamicRouteSpec { - handler_id: u32, - param_names: Box<[Box]>, - header_keys: Box<[Box]>, - full_headers: bool, -} - -#[derive(Clone)] -enum CompiledSegment { - Static(Box), - Param, + /// O(M) radix-tree routes per method (M = path length) + radix_trees: HashMap, } #[derive(Clone)] @@ -51,6 +36,16 @@ pub struct MatchedRoute<'a, 'b> { pub full_headers: bool, } +// ─── Internal Types ─────────────────────────────────────────────────────────── + +#[derive(Clone)] +struct DynamicRouteSpec { + handler_id: u32, + param_names: Box<[Box]>, + header_keys: Box<[Box]>, + full_headers: bool, +} + #[derive(Clone, Copy, Eq, Hash, PartialEq)] enum MethodKey { Delete, @@ -62,12 +57,126 @@ enum MethodKey { Put, } +// ─── Radix Tree ─────────────────────────────────────────────────────────────── +// +// Each node represents either a static prefix or a parameter capture. +// Matching is O(M) where M is the number of path segments, not O(N) routes. + +#[derive(Clone)] +struct RadixNode { + children: Vec, + /// If this node is a terminal route, the handler spec + handler: Option, + /// Parameter capture node for this level + param_child: Option>, +} + +#[derive(Clone)] +struct RadixChild { + /// The static segment this child matches + segment: Box, + node: RadixNode, +} + +#[derive(Clone)] +struct RadixParamChild { + node: RadixNode, +} + +impl RadixNode { + fn new() -> Self { + Self { + children: Vec::new(), + handler: None, + param_child: None, + } + } + + /// Insert a route into the radix tree + fn insert(&mut self, segments: &[RouteSegment], spec: DynamicRouteSpec) { + if segments.is_empty() { + self.handler = Some(spec); + return; + } + + match &segments[0] { + RouteSegment::Static(value) => { + // Find existing child with this segment + for child in &mut self.children { + if child.segment.as_ref() == value.as_str() { + child.node.insert(&segments[1..], spec); + return; + } + } + + // Create new child + let mut child_node = RadixNode::new(); + child_node.insert(&segments[1..], spec); + self.children.push(RadixChild { + segment: value.clone().into_boxed_str(), + node: child_node, + }); + } + RouteSegment::Param(_) => { + if self.param_child.is_none() { + self.param_child = Some(Box::new(RadixParamChild { + node: RadixNode::new(), + })); + } + self.param_child + .as_mut() + .unwrap() + .node + .insert(&segments[1..], spec); + } + } + } + + /// Match a request path against this radix tree — O(M) where M = segment count + fn match_path<'a, 'b>( + &'a self, + segments: &[&'b str], + param_values: &mut Vec<&'b str>, + ) -> Option<&'a DynamicRouteSpec> { + if segments.is_empty() { + return self.handler.as_ref(); + } + + let segment = segments[0]; + let rest = &segments[1..]; + + // Try static children first (higher priority) + for child in &self.children { + if child.segment.as_ref() == segment { + if let Some(spec) = child.node.match_path(rest, param_values) { + return Some(spec); + } + } + } + + // Try parameter capture + if let Some(param_child) = &self.param_child { + let prev_len = param_values.len(); + param_values.push(segment); + if let Some(spec) = param_child.node.match_path(rest, param_values) { + return Some(spec); + } + // Backtrack + param_values.truncate(prev_len); + } + + None + } +} + +// ─── Router Implementation ──────────────────────────────────────────────────── + impl Router { pub fn from_manifest(manifest: &ManifestInput) -> Result { let mut exact_get_root = None; let mut dynamic_exact_routes = HashMap::new(); - let mut dynamic_param_routes = HashMap::new(); let mut exact_static_routes = HashMap::new(); + let mut radix_trees: HashMap = HashMap::new(); for route in &manifest.routes { let method = route.method.to_uppercase(); @@ -119,13 +228,13 @@ impl Router { ); } ROUTE_KIND_PARAM => { - let compiled = compile_param_route(route, path.as_str()); - dynamic_param_routes + // Insert into radix tree instead of linear Vec + let segments = parse_segments(path.as_str()); + let spec = compile_dynamic_route_spec(route); + radix_trees .entry(method_key) - .or_insert_with(HashMap::new) - .entry(route.segment_count) - .or_insert_with(Vec::new) - .push(compiled); + .or_insert_with(RadixNode::new) + .insert(&segments, spec); } _ => {} } @@ -134,8 +243,8 @@ impl Router { Ok(Self { exact_get_root, dynamic_exact_routes, - dynamic_param_routes, exact_static_routes, + radix_trees, }) } @@ -146,6 +255,7 @@ impl Router { ) -> Option> { let method_key = MethodKey::from_code(method_code)?; + // Fast path: exact match (O(1)) if let Some(route_spec) = self .dynamic_exact_routes .get(&method_key) @@ -159,40 +269,18 @@ impl Router { }); } - let path_segments = split_request_segments(path); - let param_routes = self - .dynamic_param_routes - .get(&method_key) - .and_then(|routes| routes.get(&(path_segments.len() as u16)))?; - - for route in param_routes { - let mut param_values = Vec::with_capacity(route.spec.param_names.len()); - let mut matched = true; - - for (segment, path_segment) in route.segments.iter().zip(path_segments.iter()) { - match segment { - CompiledSegment::Static(value) if value.as_ref() == *path_segment => {} - CompiledSegment::Static(_) => { - matched = false; - break; - } - CompiledSegment::Param => { - param_values.push(*path_segment); - } - } - } - - if matched { - return Some(MatchedRoute { - handler_id: route.spec.handler_id, - param_values, - header_keys: route.spec.header_keys.as_ref(), - full_headers: route.spec.full_headers, - }); - } - } - - None + // Radix tree match (O(M) where M = segment count) + let segments = split_request_segments(path); + let tree = self.radix_trees.get(&method_key)?; + let mut param_values = Vec::new(); + let spec = tree.match_path(&segments, &mut param_values)?; + + Some(MatchedRoute { + handler_id: spec.handler_id, + param_values, + header_keys: spec.header_keys.as_ref(), + full_headers: spec.full_headers, + }) } pub fn exact_static_route(&self, method: &[u8], path: &[u8]) -> Option<&ExactStaticRoute> { @@ -211,6 +299,8 @@ impl Router { } } +// ─── MethodKey ──────────────────────────────────────────────────────────────── + impl MethodKey { fn from_method_str(method: &str) -> Option { match method { @@ -252,20 +342,7 @@ impl MethodKey { } } -fn compile_param_route(route: &RouteInput, path: &str) -> ParamRoute { - let segments = parse_segments(path) - .into_iter() - .map(|segment| match segment { - RouteSegment::Static(value) => CompiledSegment::Static(value.into_boxed_str()), - RouteSegment::Param(_) => CompiledSegment::Param, - }) - .collect(); - - ParamRoute { - spec: compile_dynamic_route_spec(route), - segments, - } -} +// ─── Helpers ────────────────────────────────────────────────────────────────── fn compile_dynamic_route_spec(route: &RouteInput) -> DynamicRouteSpec { let param_names = route @@ -328,6 +405,10 @@ fn build_response_bytes( .into_bytes(); for (name, value) in headers { + // Security: skip headers with CRLF injection + if name.contains('\r') || name.contains('\n') || value.contains('\r') || value.contains('\n') { + continue; + } response.extend_from_slice(name.as_bytes()); response.extend_from_slice(b": "); response.extend_from_slice(value.as_bytes()); diff --git a/src/bridge.js b/src/bridge.js index 094d230..86a6d55 100644 --- a/src/bridge.js +++ b/src/bridge.js @@ -7,6 +7,7 @@ const EMPTY_ARRAY = Object.freeze([]); export const BRIDGE_VERSION = 1; export const REQUEST_FLAG_QUERY_PRESENT = 1 << 0; +export const REQUEST_FLAG_BODY_PRESENT = 1 << 1; export const METHOD_CODES = Object.freeze({ GET: 1, @@ -23,15 +24,42 @@ export const ROUTE_KIND = Object.freeze({ PARAM: 2, }); -const EMPTY_OBJECT = Object.freeze({}); +// Security: use null-prototype objects for user-facing data to prevent prototype pollution +const EMPTY_OBJECT = Object.freeze(Object.create(null)); + +// ─── Regex patterns for source analysis ─────────────────────────────────────── const PARAM_DOT_RE = /\breq\.params\.([A-Za-z_$][\w$]*)\b/g; -const PARAM_BRACKET_RE = /\breq\.params\[(["'])([^"'\\]+)\1\]/g; +const PARAM_BRACKET_RE = /\breq\.params\[(['"])([^"'\\]+)\1\]/g; const QUERY_DOT_RE = /\breq\.query\.([A-Za-z_$][\w$]*)\b/g; -const QUERY_BRACKET_RE = /\breq\.query\[(["'])([^"'\\]+)\1\]/g; +const QUERY_BRACKET_RE = /\breq\.query\[(['"])([^"'\\]+)\1\]/g; const HEADER_DOT_RE = /\breq\.headers\.([A-Za-z_$][\w$]*)\b/g; -const HEADER_BRACKET_RE = /\breq\.headers\[(["'])([^"'\\]+)\1\]/g; -const HEADER_CALL_RE = /\breq\.header\((["'])([^"'\\]+)\1\)/g; +const HEADER_BRACKET_RE = /\breq\.headers\[(['"])([^"'\\]+)\1\]/g; +const HEADER_CALL_RE = /\breq\.header\((['"])([^"'\\]+)\1\)/g; + +// Pre-compiled patterns for full-access detection (using RegExp constructor to avoid escaping issues) +const PARAMS_FULL_ACCESS_RE = new RegExp('\\breq\\.params\\b(?!\\s*(?:\\.|\\[))'); +const PARAMS_DYN_BRACKET_RE = new RegExp("\\breq\\.params\\[(?!['\"])"); +const QUERY_FULL_ACCESS_RE = new RegExp('\\breq\\.query\\b(?!\\s*(?:\\.|\\[))'); +const QUERY_DYN_BRACKET_RE = new RegExp("\\breq\\.query\\[(?!['\"])"); +const HEADERS_FULL_ACCESS_RE = new RegExp('\\breq\\.headers\\b(?!\\s*(?:\\.|\\[))'); +const HEADERS_DYN_BRACKET_RE = new RegExp("\\breq\\.headers\\[(?!['\"])"); +const REQ_BRACKET_STR_RE = new RegExp("\\breq\\s*\\[(['\"])[^\"'\\\\]+\\1\\]"); +const REQ_BRACKET_DYN_RE = new RegExp("\\breq\\s*\\[(?!['\"])"); +const HEADER_CALL_DYN_RE = new RegExp("\\breq\\.header\\((?!['\"])"); + +// Security: dangerous prototype keys that must never be allowed in user objects +const DANGEROUS_KEYS = new Set([ + "__proto__", + "constructor", + "prototype", + "__defineGetter__", + "__defineSetter__", + "__lookupGetter__", + "__lookupSetter__", +]); + +// ─── Route Compilation ──────────────────────────────────────────────────────── export function compileRouteShape(method, path) { const methodCode = METHOD_CODES[method]; @@ -62,6 +90,8 @@ export function compileRouteShape(method, path) { }; } +// ─── Request Access Analysis ────────────────────────────────────────────────── + export function analyzeRequestAccess(source) { const plan = createEmptyAccessPlan(); const normalizedSource = String(source ?? ""); @@ -88,22 +118,22 @@ export function analyzeRequestAccess(source) { collectMatches(normalizedSource, HEADER_BRACKET_RE, plan.headerKeys, normalizeHeaderLookup, 2); collectMatches(normalizedSource, HEADER_CALL_RE, plan.headerKeys, normalizeHeaderLookup, 2); - if (/\breq\.params\b(?!\s*(?:\.|\[))/.test(normalizedSource) || /\breq\.params\[(?!["'])/.test(normalizedSource)) { + if (PARAMS_FULL_ACCESS_RE.test(normalizedSource) || PARAMS_DYN_BRACKET_RE.test(normalizedSource)) { plan.fullParams = true; plan.dispatchKind = "generic_fallback"; } - if (/\breq\.query\b(?!\s*(?:\.|\[))/.test(normalizedSource) || /\breq\.query\[(?!["'])/.test(normalizedSource)) { + if (QUERY_FULL_ACCESS_RE.test(normalizedSource) || QUERY_DYN_BRACKET_RE.test(normalizedSource)) { plan.fullQuery = true; plan.dispatchKind = "generic_fallback"; } - if (/\breq\.headers\b(?!\s*(?:\.|\[))/.test(normalizedSource) || /\breq\.headers\[(?!["'])/.test(normalizedSource)) { + if (HEADERS_FULL_ACCESS_RE.test(normalizedSource) || HEADERS_DYN_BRACKET_RE.test(normalizedSource)) { plan.fullHeaders = true; plan.dispatchKind = "generic_fallback"; } - if (/\breq\s*\[(["'])[^"'\\]+\1\]/.test(normalizedSource) || /\breq\s*\[(?!["'])/.test(normalizedSource)) { + if (REQ_BRACKET_STR_RE.test(normalizedSource) || REQ_BRACKET_DYN_RE.test(normalizedSource)) { plan.method = true; plan.path = true; plan.url = true; @@ -113,7 +143,7 @@ export function analyzeRequestAccess(source) { plan.dispatchKind = "generic_fallback"; } - if (/\breq\.header\((?!["'])/.test(normalizedSource)) { + if (HEADER_CALL_DYN_RE.test(normalizedSource)) { plan.fullHeaders = true; plan.dispatchKind = "generic_fallback"; } @@ -152,106 +182,211 @@ export function mergeRequestAccessPlans(plans) { return freezeAccessPlan(merged); } -export function createRequestFactory( - plan, - routeParamNames = EMPTY_ARRAY, - routeMethod = "GET", -) { - return function buildRequest(decoded) { - const needsParams = plan.fullParams || plan.paramKeys.size > 0; - const needsQuery = plan.fullQuery || plan.queryKeys.size > 0; - const needsHeaders = plan.fullHeaders || plan.headerKeys.size > 0; - - let path; - let url; - let params; - let query; - let headers; - - function decodePath() { - if (path === undefined) { - path = textDecoder.decode(decoded.pathBytes); +// ─── Request Factory (with Object Pooling) ──────────────────────────────────── + +// Pool for request objects — avoids per-request allocations +const REQUEST_POOL_MAX = 512; +const requestPool = []; + +function acquireRequestObject() { + return requestPool.pop() || null; +} + +function releaseRequestObject(req) { + if (requestPool.length >= REQUEST_POOL_MAX) { + return; + } + // Reset all fields before pooling + req.method = ""; + req._path = undefined; + req._url = undefined; + req._params = undefined; + req._query = undefined; + req._headers = undefined; + req._decoded = null; + req._routeParamNames = null; + req._plan = null; + req._routeMethod = null; + requestPool.push(req); +} + +function createPooledRequest() { + const req = Object.create(null); + + // Internal state + req._path = undefined; + req._url = undefined; + req._params = undefined; + req._query = undefined; + req._headers = undefined; + req._bodyParsed = undefined; + req._decoded = null; + req._routeParamNames = null; + req._plan = null; + req._routeMethod = null; + req.method = ""; + + Object.defineProperty(req, "path", { + configurable: true, + enumerable: true, + get() { + if (req._path === undefined) { + req._path = textDecoder.decode(req._decoded.pathBytes); + } + return req._path; + }, + }); + + Object.defineProperty(req, "url", { + configurable: true, + enumerable: true, + get() { + if (req._url === undefined) { + req._url = textDecoder.decode(req._decoded.urlBytes); + } + return req._url; + }, + }); + + Object.defineProperty(req, "params", { + configurable: true, + enumerable: true, + get() { + if (req._params === undefined) { + const needsParams = req._plan.fullParams || req._plan.paramKeys.size > 0; + req._params = needsParams + ? materializeParamObject(req._decoded.paramValues, req._routeParamNames, req._plan) + : EMPTY_OBJECT; + } + return req._params; + }, + }); + + Object.defineProperty(req, "query", { + configurable: true, + enumerable: true, + get() { + if (req._query === undefined) { + const needsQuery = req._plan.fullQuery || req._plan.queryKeys.size > 0; + if (!needsQuery) { + req._query = EMPTY_OBJECT; + } else { + // Compute URL for query parsing + if (req._url === undefined) { + req._url = textDecoder.decode(req._decoded.urlBytes); + } + req._query = materializeQueryObject(req._url, req._decoded.flags, req._plan); + } + } + return req._query; + }, + }); + + Object.defineProperty(req, "headers", { + configurable: true, + enumerable: true, + get() { + if (req._headers === undefined) { + const needsHeaders = req._plan.fullHeaders || req._plan.headerKeys.size > 0; + req._headers = needsHeaders + ? materializeHeaderObject(req._decoded.rawHeaders, req._plan) + : EMPTY_OBJECT; } - return path; + return req._headers; + }, + }); + + req.header = function header(name) { + const lookup = normalizeHeaderLookup(name); + if (req._headers && lookup in req._headers) { + return req._headers[lookup]; } + if (req._decoded.rawHeaders.length === 0) { + return undefined; + } + return lookupHeaderValue(req._decoded.rawHeaders, lookup); + }; - function decodeUrl() { - if (url === undefined) { - url = textDecoder.decode(decoded.urlBytes); + // ─── Body APIs ──────────────────────────────────────────────────────── + + Object.defineProperty(req, "body", { + configurable: true, + enumerable: true, + get() { + if (req._decoded.bodyBytes === null) { + return null; } - return url; + return Buffer.from(req._decoded.bodyBytes.buffer, req._decoded.bodyBytes.byteOffset, req._decoded.bodyBytes.byteLength); + }, + }); + + req.json = function json() { + if (req._bodyParsed !== undefined) { + return req._bodyParsed; + } + if (req._decoded.bodyBytes === null || req._decoded.bodyBytes.length === 0) { + req._bodyParsed = null; + return null; } + const text = textDecoder.decode(req._decoded.bodyBytes); + req._bodyParsed = JSON.parse(text); + return req._bodyParsed; + }; - const request = { - method: routeMethod, + req.text = function text() { + if (req._decoded.bodyBytes === null || req._decoded.bodyBytes.length === 0) { + return ""; + } + return textDecoder.decode(req._decoded.bodyBytes); + }; - get path() { - return decodePath(); - }, + req.arrayBuffer = function arrayBuffer() { + if (req._decoded.bodyBytes === null) { + return new ArrayBuffer(0); + } + return req._decoded.bodyBytes.buffer.slice( + req._decoded.bodyBytes.byteOffset, + req._decoded.bodyBytes.byteOffset + req._decoded.bodyBytes.byteLength, + ); + }; - get url() { - return decodeUrl(); - }, + return req; +} - get params() { - if (params === undefined) { - params = needsParams - ? materializeParamObject(decoded.paramValues, routeParamNames, plan) - : EMPTY_OBJECT; - } - return params; - }, - - get query() { - if (query === undefined) { - query = needsQuery - ? materializeQueryObject(decodeUrl(), decoded.flags, plan) - : EMPTY_OBJECT; - } - return query; - }, - - get headers() { - if (headers === undefined) { - headers = needsHeaders - ? materializeHeaderObject(decoded.rawHeaders, plan) - : EMPTY_OBJECT; - } - return headers; - }, +export function createRequestFactory( + plan, + routeParamNames = EMPTY_ARRAY, + routeMethod = "GET", +) { + return function buildRequest(decoded) { + let request = acquireRequestObject(); + if (!request) { + request = createPooledRequest(); + } - header(name) { - const lookup = normalizeHeaderLookup(name); - if (headers && lookup in headers) { - return headers[lookup]; - } - if (decoded.rawHeaders.length === 0) { - return undefined; - } - return lookupHeaderValue(decoded.rawHeaders, lookup); - }, - }; + // Initialize for this request + request._decoded = decoded; + request._routeParamNames = routeParamNames; + request._plan = plan; + request._routeMethod = routeMethod; + request.method = routeMethod; + request._path = undefined; + request._url = undefined; + request._params = undefined; + request._query = undefined; + request._headers = undefined; + request._bodyParsed = undefined; return request; }; } -export function createJsonSerializer(mode = "fallback") { - if (mode === "fallback") { - const serializer = (value) => { - const serialized = JSON.stringify(value); - return Buffer.from(serialized, "utf8"); - }; - serializer.kind = "fallback"; - return serializer; - } +// ─── JSON Serialization ─────────────────────────────────────────────────────── +export function createJsonSerializer(mode = "fallback") { + // Performance: V8's native JSON.stringify is heavily optimized and almost always + // faster than any JS-level reimplementation. Use it directly. const serializer = (value) => { - const fastValue = trySerializeJsonFast(value); - if (fastValue !== null) { - return Buffer.from(fastValue, "utf8"); - } - const serialized = JSON.stringify(value); return Buffer.from(serialized, "utf8"); }; @@ -259,6 +394,8 @@ export function createJsonSerializer(mode = "fallback") { return serializer; } +// ─── Binary Protocol Codec ──────────────────────────────────────────────────── + export function decodeRequestEnvelope(buffer) { const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength); @@ -285,6 +422,9 @@ export function decodeRequestEnvelope(buffer) { const headerCount = readU16(view, offset); offset += 2; + const bodyLength = readU32(view, offset); + offset += 4; + const urlBytes = readBytes(bytes, offset, urlLength); offset += urlLength; const pathBytes = readBytes(bytes, offset, pathLength); @@ -312,6 +452,10 @@ export function decodeRequestEnvelope(buffer) { rawHeaders[index] = [nameBytes, valueBytes]; } + // Read body bytes + const bodyBytes = bodyLength > 0 ? readBytes(bytes, offset, bodyLength) : null; + offset += bodyLength; + switch (methodCode) { case METHOD_CODES.GET: case METHOD_CODES.POST: @@ -333,6 +477,7 @@ export function decodeRequestEnvelope(buffer) { pathBytes, paramValues, rawHeaders, + bodyBytes, }; } @@ -384,6 +529,10 @@ export function encodeResponseEnvelope(snapshot) { return output; } +// ─── Object Materialization (Security-Hardened) ─────────────────────────────── +// +// All user-facing objects use Object.create(null) to prevent prototype pollution. + function materializeParamObject(entries, paramNames, plan) { if (plan.fullParams) { return materializeParamPairs(entries, paramNames); @@ -402,7 +551,7 @@ function materializeHeaderObject(entries, plan) { function materializeQueryObject(url, flags, plan) { if (!(flags & REQUEST_FLAG_QUERY_PRESENT)) { - return {}; + return Object.create(null); } if (plan.fullQuery) { @@ -413,11 +562,16 @@ function materializeQueryObject(url, flags, plan) { } function materializePairs(entries, lowerCaseKeys = false) { - const result = {}; + // Security: null-prototype object prevents prototype pollution + const result = Object.create(null); for (const [rawName, rawValue] of entries) { const name = textDecoder.decode(rawName); const key = lowerCaseKeys ? name.toLowerCase() : name; + // Security: skip dangerous prototype keys + if (DANGEROUS_KEYS.has(key)) { + continue; + } result[key] = textDecoder.decode(rawValue); } @@ -425,10 +579,14 @@ function materializePairs(entries, lowerCaseKeys = false) { } function materializeParamPairs(entries, paramNames) { - const result = {}; + const result = Object.create(null); for (let index = 0; index < entries.length; index += 1) { - result[paramNames[index]] = textDecoder.decode(entries[index]); + const key = paramNames[index]; + if (DANGEROUS_KEYS.has(key)) { + continue; + } + result[key] = textDecoder.decode(entries[index]); } return result; @@ -436,13 +594,13 @@ function materializeParamPairs(entries, paramNames) { function materializeSelectedParamPairs(entries, paramNames, selectedKeys) { if (selectedKeys.size === 0) { - return {}; + return Object.create(null); } - const result = {}; + const result = Object.create(null); for (let index = 0; index < entries.length; index += 1) { const key = paramNames[index]; - if (selectedKeys.has(key)) { + if (selectedKeys.has(key) && !DANGEROUS_KEYS.has(key)) { result[key] = textDecoder.decode(entries[index]); } } @@ -452,14 +610,14 @@ function materializeSelectedParamPairs(entries, paramNames, selectedKeys) { function materializeSelectedPairs(entries, selectedKeys, lowerCaseKeys = false) { if (selectedKeys.size === 0) { - return {}; + return Object.create(null); } - const result = {}; + const result = Object.create(null); for (const [rawName, rawValue] of entries) { const name = textDecoder.decode(rawName); const key = lowerCaseKeys ? name.toLowerCase() : name; - if (selectedKeys.has(key)) { + if (selectedKeys.has(key) && !DANGEROUS_KEYS.has(key)) { result[key] = textDecoder.decode(rawValue); } } @@ -470,13 +628,16 @@ function materializeSelectedPairs(entries, selectedKeys, lowerCaseKeys = false) function parseQuery(url) { const queryStart = url.indexOf("?"); if (queryStart < 0 || queryStart === url.length - 1) { - return {}; + return Object.create(null); } const params = new URLSearchParams(url.slice(queryStart + 1)); - const result = {}; + const result = Object.create(null); for (const [key, value] of params) { + if (DANGEROUS_KEYS.has(key)) { + continue; + } pushQueryEntry(result, key, value); } @@ -485,19 +646,19 @@ function parseQuery(url) { function parseSelectedQuery(url, selectedKeys) { if (selectedKeys.size === 0) { - return {}; + return Object.create(null); } const queryStart = url.indexOf("?"); if (queryStart < 0 || queryStart === url.length - 1) { - return {}; + return Object.create(null); } const params = new URLSearchParams(url.slice(queryStart + 1)); - const result = {}; + const result = Object.create(null); for (const [key, value] of params) { - if (selectedKeys.has(key)) { + if (selectedKeys.has(key) && !DANGEROUS_KEYS.has(key)) { pushQueryEntry(result, key, value); } } @@ -530,6 +691,8 @@ function lookupHeaderValue(entries, targetName) { return undefined; } +// ─── Access Plan ────────────────────────────────────────────────────────────── + function createEmptyAccessPlan() { return { method: false, @@ -590,87 +753,6 @@ function addSetEntries(target, source) { } } -function trySerializeJsonFast(value) { - const stack = new WeakSet(); - return serializeJsonValue(value, stack); -} - -function serializeJsonValue(value, stack) { - if (value === null) { - return "null"; - } - - switch (typeof value) { - case "string": - return JSON.stringify(value); - case "number": - return Number.isFinite(value) ? (Object.is(value, -0) ? "0" : String(value)) : "null"; - case "boolean": - return value ? "true" : "false"; - case "undefined": - case "function": - case "symbol": - return null; - case "bigint": - return null; - case "object": - break; - default: - return null; - } - - if (typeof value.toJSON === "function") { - return null; - } - - if (Array.isArray(value)) { - if (stack.has(value)) { - return null; - } - - stack.add(value); - const items = new Array(value.length); - for (let index = 0; index < value.length; index += 1) { - const descriptor = Object.getOwnPropertyDescriptor(value, String(index)); - if (descriptor && (descriptor.get || descriptor.set)) { - stack.delete(value); - return null; - } - - const serialized = serializeJsonValue(value[index], stack); - items[index] = serialized === null ? "null" : serialized; - } - stack.delete(value); - return `[${items.join(",")}]`; - } - - const prototype = Object.getPrototypeOf(value); - if (prototype !== PLAIN_OBJECT_PROTOTYPE && prototype !== null) { - return null; - } - - if (stack.has(value)) { - return null; - } - - stack.add(value); - const entries = []; - for (const key of Object.keys(value)) { - const descriptor = Object.getOwnPropertyDescriptor(value, key); - if (descriptor && (descriptor.get || descriptor.set)) { - stack.delete(value); - return null; - } - - const serializedValue = serializeJsonValue(value[key], stack); - if (serializedValue !== null) { - entries.push(`${JSON.stringify(key)}:${serializedValue}`); - } - } - stack.delete(value); - return `{${entries.join(",")}}`; -} - function identity(value) { return value; } @@ -679,6 +761,8 @@ function encodeUtf8(value) { return textEncoder.encode(String(value)); } +// ─── Binary Protocol Helpers ────────────────────────────────────────────────── + function readBytes(bytes, offset, length) { if (offset + length > bytes.byteLength) { throw new Error("Request envelope truncated"); diff --git a/src/cors.js b/src/cors.js new file mode 100644 index 0000000..3a36758 --- /dev/null +++ b/src/cors.js @@ -0,0 +1,116 @@ +/** + * http-native CORS middleware + * + * Usage: + * import { cors } from "http-native/cors"; + * app.use(cors({ origin: "*" })); + * app.use(cors({ origin: ["https://example.com"], credentials: true })); + */ + +const DEFAULT_METHODS = "GET,HEAD,PUT,PATCH,POST,DELETE"; +const DEFAULT_HEADERS = "Content-Type,Authorization,Accept,X-Requested-With"; + +/** + * @param {Object} [options] + * @param {string | string[] | ((origin: string) => boolean)} [options.origin="*"] + * @param {string | string[]} [options.methods] + * @param {string | string[]} [options.allowedHeaders] + * @param {string | string[]} [options.exposedHeaders] + * @param {boolean} [options.credentials=false] + * @param {number} [options.maxAge] + * @param {boolean} [options.preflight=true] + */ +export function cors(options = {}) { + const { + origin = "*", + methods = DEFAULT_METHODS, + allowedHeaders = DEFAULT_HEADERS, + exposedHeaders, + credentials = false, + maxAge, + preflight = true, + } = options; + + const methodsString = Array.isArray(methods) ? methods.join(",") : methods; + const allowedHeadersString = Array.isArray(allowedHeaders) + ? allowedHeaders.join(",") + : allowedHeaders; + const exposedHeadersString = exposedHeaders + ? Array.isArray(exposedHeaders) + ? exposedHeaders.join(",") + : exposedHeaders + : null; + + // Pre-compute origin matching function + const matchOrigin = buildOriginMatcher(origin); + + return async function corsMiddleware(req, res, next) { + const requestOrigin = req.header("origin") ?? req.headers?.origin; + + // Not a CORS request + if (!requestOrigin) { + await next(); + return; + } + + const allowed = matchOrigin(requestOrigin); + + if (!allowed) { + await next(); + return; + } + + // Set CORS headers + const effectiveOrigin = + origin === "*" && !credentials ? "*" : requestOrigin; + res.set("Access-Control-Allow-Origin", effectiveOrigin); + + if (credentials) { + res.set("Access-Control-Allow-Credentials", "true"); + } + + if (effectiveOrigin !== "*") { + res.set("Vary", "Origin"); + } + + if (exposedHeadersString) { + res.set("Access-Control-Expose-Headers", exposedHeadersString); + } + + // Handle preflight + if (preflight && req.method === "OPTIONS") { + res.set("Access-Control-Allow-Methods", methodsString); + res.set("Access-Control-Allow-Headers", allowedHeadersString); + + if (maxAge !== undefined) { + res.set("Access-Control-Max-Age", String(maxAge)); + } + + res.status(204).send(); + return; + } + + await next(); + }; +} + +function buildOriginMatcher(origin) { + if (origin === "*") { + return () => true; + } + + if (typeof origin === "function") { + return origin; + } + + if (typeof origin === "string") { + return (requestOrigin) => requestOrigin === origin; + } + + if (Array.isArray(origin)) { + const originSet = new Set(origin); + return (requestOrigin) => originSet.has(requestOrigin); + } + + return () => false; +} diff --git a/src/index.d.ts b/src/index.d.ts new file mode 100644 index 0000000..9bfa5d9 --- /dev/null +++ b/src/index.d.ts @@ -0,0 +1,238 @@ +import { Buffer } from "node:buffer"; + +// ─── Core Types ─────────────────────────────────────────────────────────────── + +export interface Request { + /** HTTP method (GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD) */ + readonly method: string; + + /** URL path without query string */ + readonly path: string; + + /** Full request URL including query string */ + readonly url: string; + + /** Route parameters extracted from path segments (e.g., /users/:id → { id: "42" }) */ + readonly params: Record; + + /** Parsed query string parameters. Multi-value params are arrays. */ + readonly query: Record; + + /** Lowercase request headers as key-value pairs */ + readonly headers: Record; + + /** Raw request body as a Buffer, or null if no body was sent */ + readonly body: Buffer | null; + + /** Get a specific header value by name (case-insensitive) */ + header(name: string): string | undefined; + + /** Parse the request body as JSON */ + json(): T | null; + + /** Get the request body as a UTF-8 string */ + text(): string; + + /** Get the request body as an ArrayBuffer */ + arrayBuffer(): ArrayBuffer; +} + +export interface Response { + /** Whether the response has already been finalized */ + readonly finished: boolean; + + /** Per-request storage for passing data between middlewares */ + locals: Record; + + /** Set the HTTP status code */ + status(code: number): Response; + + /** Set a response header */ + set(name: string, value: string): Response; + + /** Alias for set() */ + header(name: string, value: string): Response; + + /** Get a response header value */ + get(name: string): string | undefined; + + /** Set the Content-Type header */ + type(value: string): Response; + + /** Send a JSON response with proper Content-Type */ + json(data: unknown): Response; + + /** Send a response body (string, Buffer, or object) */ + send(data?: string | Buffer | Uint8Array | null): Response; + + /** Set status and send status code as text body */ + sendStatus(code: number): Response; +} + +export type NextFunction = () => Promise; + +export type Middleware = ( + req: Request, + res: Response, + next: NextFunction, +) => void | Promise; + +export type RouteHandler = ( + req: Request, + res: Response, +) => void | Promise; + +export type ErrorHandler = ( + error: Error, + req: Request, + res: Response, +) => void | Promise; + +// ─── Listen Options ─────────────────────────────────────────────────────────── + +export interface HttpServerConfig { + defaultHost?: string; + defaultBacklog?: number; + maxHeaderBytes?: number; +} + +export interface ListenOptions { + /** Host to bind to (default: "127.0.0.1") */ + host?: string; + + /** Port to bind to (default: 3000, use 0 for random) */ + port?: number; + + /** TCP listen backlog (default: 2048) */ + backlog?: number; + + /** Override server configuration */ + serverConfig?: HttpServerConfig; + + /** Runtime optimization options */ + opt?: Record; +} + +// ─── Server Handle ──────────────────────────────────────────────────────────── + +export interface ServerHandle { + /** Bound hostname */ + readonly host: string; + + /** Bound port number */ + readonly port: number; + + /** Full URL (http://host:port) */ + readonly url: string; + + /** Runtime optimization introspection */ + readonly optimizations: { + /** Get a snapshot of route optimization state */ + snapshot(): OptimizationSnapshot; + /** Get a human-readable optimization summary */ + summary(): string; + }; + + /** Gracefully close the server */ + close(): void; +} + +export interface OptimizationSnapshot { + routes: RouteOptimizationInfo[]; +} + +export interface RouteOptimizationInfo { + method: string; + path: string; + staticFastPath: boolean; + binaryBridge: boolean; + bridgeObserved: boolean; + cacheCandidate: boolean; + hits: number; + recommendation?: string; + dispatchKind?: string; + jsonFastPath?: string; +} + +// ─── Application ────────────────────────────────────────────────────────────── + +export interface Application { + /** Register path-scoped or global middleware */ + use(middleware: Middleware): Application; + use(path: string, middleware: Middleware): Application; + + /** Register a global error handler */ + onError(handler: ErrorHandler): Application; + + /** Register a GET route handler */ + get(path: string, handler: RouteHandler): Application; + + /** Register a POST route handler */ + post(path: string, handler: RouteHandler): Application; + + /** Register a PUT route handler */ + put(path: string, handler: RouteHandler): Application; + + /** Register a DELETE route handler */ + delete(path: string, handler: RouteHandler): Application; + + /** Register a PATCH route handler */ + patch(path: string, handler: RouteHandler): Application; + + /** Register an OPTIONS route handler */ + options(path: string, handler: RouteHandler): Application; + + /** Register a handler for all HTTP methods */ + all(path: string, handler: RouteHandler): Application; + + /** Start the server and listen for connections */ + listen(options?: ListenOptions): Promise; +} + +/** Create a new http-native application */ +export function createApp(): Application; + +// ─── CORS Types ─────────────────────────────────────────────────────────────── + +export interface CorsOptions { + /** Allowed origin(s). Default: "*" */ + origin?: string | string[] | ((origin: string) => boolean); + + /** Allowed HTTP methods */ + methods?: string | string[]; + + /** Allowed request headers */ + allowedHeaders?: string | string[]; + + /** Headers exposed to the browser */ + exposedHeaders?: string | string[]; + + /** Allow credentials (cookies, authorization) */ + credentials?: boolean; + + /** Cache duration for preflight response (seconds) */ + maxAge?: number; + + /** Handle preflight OPTIONS requests. Default: true */ + preflight?: boolean; +} + +/** Create a CORS middleware */ +export function cors(options?: CorsOptions): Middleware; + +// ─── Validation Types ───────────────────────────────────────────────────────── + +export interface ValidationSchema { + parse(data: unknown): T; +} + +export interface ValidateOptions { + body?: ValidationSchema; + query?: ValidationSchema; + params?: ValidationSchema; +} + +/** Create a validation middleware (works with Zod, TypeBox, or any schema with .parse()) */ +export function validate( + schema: ValidateOptions, +): Middleware; diff --git a/src/index.js b/src/index.js index a9e5788..ca8430a 100644 --- a/src/index.js +++ b/src/index.js @@ -19,6 +19,8 @@ const HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]; const ACTIVE_NATIVE_SERVERS = new Set(); const EMPTY_BUFFER = Buffer.alloc(0); +// ─── Path Normalization ─────────────────────────────────────────────────────── + function normalizePathPrefix(path) { if (path === "/") { return "/"; @@ -64,14 +66,45 @@ function normalizeContentType(type) { return type; } -function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback")) { - const state = { +// ─── Response Envelope (Pooled) ─────────────────────────────────────────────── + +const RESPONSE_POOL_MAX = 512; +const responsePool = []; + +function acquireResponseState() { + const pooled = responsePool.pop(); + if (pooled) { + pooled.status = 200; + // Reset headers — use null-prototype object for security + for (const key in pooled.headers) { + delete pooled.headers[key]; + } + pooled.body = EMPTY_BUFFER; + pooled.finished = false; + // Reset locals + for (const key in pooled.locals) { + delete pooled.locals[key]; + } + return pooled; + } + + return { status: 200, - headers: {}, + headers: Object.create(null), body: EMPTY_BUFFER, finished: false, - locals: {}, + locals: Object.create(null), }; +} + +function releaseResponseState(state) { + if (responsePool.length < RESPONSE_POOL_MAX) { + responsePool.push(state); + } +} + +function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback")) { + const state = acquireResponseState(); const response = { locals: state.locals, @@ -86,7 +119,14 @@ function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback" }, set(name, value) { - state.headers[String(name).toLowerCase()] = String(value); + // Security: validate header name/value for CRLF injection + const headerName = String(name).toLowerCase(); + const headerValue = String(value); + if (headerName.includes("\r") || headerName.includes("\n") || + headerValue.includes("\r") || headerValue.includes("\n")) { + return response; // Silently reject — security + } + state.headers[headerName] = headerValue; return response; }, @@ -161,10 +201,42 @@ function createResponseEnvelope(jsonSerializer = createJsonSerializer("fallback" body: state.body, }; }, + release() { + releaseResponseState(state); + }, }; } +// ─── Compiled Middleware Runner ──────────────────────────────────────────────── +// +// Generates an optimized runner that avoids function.length checks at runtime +// by pre-classifying middlewares during compilation. + function createMiddlewareRunner(middlewares) { + if (middlewares.length === 0) { + // Fast path: no middlewares — return a no-op + return async function noopMiddleware(_req, _res) {}; + } + + if (middlewares.length === 1) { + // Fast path: single middleware — avoid dispatch overhead + const mw = middlewares[0]; + if (mw.handler.length >= 3) { + return async function runSingleMiddleware(req, res) { + await mw.handler(req, res, () => Promise.resolve()); + }; + } + return async function runSingleMiddleware(req, res) { + await mw.handler(req, res); + }; + } + + // Pre-classify each middleware as "next-aware" or "auto-advance" + const classified = middlewares.map((mw) => ({ + handler: mw.handler, + needsNext: mw.handler.length >= 3, + })); + return async function runCompiledMiddlewares(req, res) { let index = -1; @@ -174,12 +246,12 @@ function createMiddlewareRunner(middlewares) { } index = position; - const middleware = middlewares[position]; + const middleware = classified[position]; if (!middleware || res.finished) { return; } - if (middleware.handler.length >= 3) { + if (middleware.needsNext) { await middleware.handler(req, res, () => dispatch(position + 1)); return; } @@ -194,23 +266,30 @@ function createMiddlewareRunner(middlewares) { }; } +// ─── Error Handling (Security-Hardened) ─────────────────────────────────────── + function serializeErrorResponse(error) { + // Security: NEVER leak internal error details to the client + const isProduction = process.env.NODE_ENV === "production"; + const body = isProduction + ? { error: "Internal Server Error" } + : { + error: "Internal Server Error", + detail: error instanceof Error ? error.message : String(error), + }; + return encodeResponseEnvelope({ status: 500, headers: { "content-type": "application/json; charset=utf-8", }, - body: Buffer.from( - JSON.stringify({ - error: "Internal Server Error", - detail: error instanceof Error ? error.message : String(error), - }), - "utf8", - ), + body: Buffer.from(JSON.stringify(body), "utf8"), }); } -function createDispatcher(compiledRoutes, runtimeOptimizer) { +// ─── Dispatcher ─────────────────────────────────────────────────────────────── + +function createDispatcher(compiledRoutes, runtimeOptimizer, errorHandlers = []) { const routesById = new Map(compiledRoutes.map((route) => [route.handlerId, route])); return async function dispatch(requestBuffer) { @@ -228,7 +307,7 @@ function createDispatcher(compiledRoutes, runtimeOptimizer) { } const req = route.requestFactory(decoded); - const { response: res, snapshot } = createResponseEnvelope(route.jsonSerializer); + const { response: res, snapshot, release } = createResponseEnvelope(route.jsonSerializer); try { await route.runMiddlewares(req, res); @@ -237,16 +316,38 @@ function createDispatcher(compiledRoutes, runtimeOptimizer) { } } catch (error) { if (!res.finished) { - return serializeErrorResponse(error); + // Try registered error handlers first + for (const errorHandler of errorHandlers) { + try { + await errorHandler(error, req, res); + if (res.finished) { + break; + } + } catch (handlerError) { + // Error handler itself threw — fall through to default + release(); + return serializeErrorResponse(handlerError); + } + } + + // If no error handler responded, use default + if (!res.finished) { + release(); + return serializeErrorResponse(error); + } } } const responseSnapshot = snapshot(); runtimeOptimizer?.recordDispatch(route, req, responseSnapshot); - return encodeResponseEnvelope(responseSnapshot); + const encoded = encodeResponseEnvelope(responseSnapshot); + release(); + return encoded; }; } +// ─── Route Registration & Compilation ───────────────────────────────────────── + function normalizeRouteRegistration(method, path, handler) { if (typeof handler !== "function") { throw new TypeError(`Handler for ${method} ${path} must be a function`); @@ -334,6 +435,8 @@ function normalizeListenOptions(options = {}) { }; } +// ─── Application Factory ───────────────────────────────────────────────────── + export function createApp() { const native = loadNativeModule(); let nextHandlerId = 1; @@ -341,6 +444,7 @@ export function createApp() { const app = { _routes: [], _middlewares: [], + _errorHandlers: [], use(pathOrMiddleware, maybeMiddleware) { let pathPrefix = "/"; @@ -359,6 +463,14 @@ export function createApp() { return this; }, + onError(handler) { + if (typeof handler !== "function") { + throw new TypeError("Error handler must be a function"); + } + this._errorHandlers.push(handler); + return this; + }, + get: undefined, post: undefined, put: undefined, @@ -411,7 +523,7 @@ export function createApp() { compiledMiddlewares, normalizedOptions.opt, ); - const dispatcher = createDispatcher(compiledRoutes, runtimeOptimizer); + const dispatcher = createDispatcher(compiledRoutes, runtimeOptimizer, this._errorHandlers); const handle = native.startServer(JSON.stringify(manifest), dispatcher, { host: normalizedOptions.host, port: normalizedOptions.port, diff --git a/src/validate.js b/src/validate.js new file mode 100644 index 0000000..e6b26b5 --- /dev/null +++ b/src/validate.js @@ -0,0 +1,152 @@ +/** + * http-native Validation Middleware + * + * Schema-agnostic: works with Zod, TypeBox, Yup, Joi, or any object with .parse() + * + * Usage: + * import { validate } from "http-native/validate"; + * import { z } from "zod"; + * + * app.post("/users", validate({ + * body: z.object({ name: z.string(), email: z.string().email() }), + * }), async (req, res) => { + * const { name, email } = req.validatedBody; + * res.json({ ok: true, name, email }); + * }); + */ + +/** + * @param {Object} schema + * @param {Object} [schema.body] - Schema to validate req.json() against + * @param {Object} [schema.query] - Schema to validate req.query against + * @param {Object} [schema.params] - Schema to validate req.params against + */ +export function validate(schema = {}) { + const { body: bodySchema, query: querySchema, params: paramsSchema } = schema; + + return async function validationMiddleware(req, res, next) { + try { + // Validate params + if (paramsSchema) { + const result = parseSchema(paramsSchema, req.params, "params"); + if (result.error) { + res.status(400).json({ + error: "Validation Error", + field: "params", + details: result.error, + }); + return; + } + req.validatedParams = result.value; + } + + // Validate query + if (querySchema) { + const result = parseSchema(querySchema, req.query, "query"); + if (result.error) { + res.status(400).json({ + error: "Validation Error", + field: "query", + details: result.error, + }); + return; + } + req.validatedQuery = result.value; + } + + // Validate body + if (bodySchema) { + const bodyData = req.json(); + if (bodyData === null && bodySchema) { + res.status(400).json({ + error: "Validation Error", + field: "body", + details: "Request body is required", + }); + return; + } + + const result = parseSchema(bodySchema, bodyData, "body"); + if (result.error) { + res.status(400).json({ + error: "Validation Error", + field: "body", + details: result.error, + }); + return; + } + req.validatedBody = result.value; + } + + await next(); + } catch (error) { + // JSON parse error or schema error + res.status(400).json({ + error: "Validation Error", + details: error instanceof Error ? error.message : String(error), + }); + } + }; +} + +/** + * Schema-agnostic parser. Supports: + * - Zod: schema.parse() throws ZodError + * - Zod safe: schema.safeParse() returns { success, data, error } + * - TypeBox/Ajv: schema.parse() or custom + * - Any object with .parse(data) that returns the parsed value or throws + */ +function parseSchema(schema, data, _fieldName) { + // Zod-style safeParse + if (typeof schema.safeParse === "function") { + const result = schema.safeParse(data); + if (result.success) { + return { value: result.data, error: null }; + } + // Zod error format + const details = result.error?.issues + ? result.error.issues.map((issue) => ({ + path: issue.path?.join(".") ?? "", + message: issue.message, + })) + : result.error?.message ?? "Validation failed"; + return { value: null, error: details }; + } + + // Standard .parse() — throws on error + if (typeof schema.parse === "function") { + try { + const value = schema.parse(data); + return { value, error: null }; + } catch (error) { + // Zod throws ZodError with .issues + if (error?.issues) { + const details = error.issues.map((issue) => ({ + path: issue.path?.join(".") ?? "", + message: issue.message, + })); + return { value: null, error: details }; + } + return { value: null, error: error?.message ?? "Validation failed" }; + } + } + + // Joi-style .validate() + if (typeof schema.validate === "function") { + const result = schema.validate(data); + if (result.error) { + const details = result.error.details + ? result.error.details.map((detail) => ({ + path: detail.path?.join(".") ?? "", + message: detail.message, + })) + : result.error.message ?? "Validation failed"; + return { value: null, error: details }; + } + return { value: result.value, error: null }; + } + + throw new TypeError( + "Schema must have a .parse(), .safeParse(), or .validate() method", + ); +} diff --git a/testing/better-express b/testing/better-express deleted file mode 160000 index 30254d8..0000000 --- a/testing/better-express +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 30254d89205a85577533a5b2e2d22704d640ec38 From 8c735ef78599e851b361e6d6284bcee2ac1326df Mon Sep 17 00:00:00 2001 From: "Nadhi-(Kushi)" Date: Fri, 27 Mar 2026 20:28:16 +0530 Subject: [PATCH 2/4] chore: the workflow is being picky --- .github/workflows/main.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3176b4e..e804f18 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -3,12 +3,10 @@ name: Benchmark on: push: branches: - - main - - master + - "**" pull_request: branches: - - main - - master + - "**" workflow_dispatch: inputs: engines: From d6c5240add6242d54a57d86bd572b5623cd69b96 Mon Sep 17 00:00:00 2001 From: Rishi Yadav Date: Fri, 27 Mar 2026 20:16:37 +0530 Subject: [PATCH 3/4] feat!: architecture hardening and enterprise feature suite CORE & PERFORMANCE: - Implement zero-allocation request body parsing in Rust core - Upgrade router from O(N) linear scan to O(M) Radix Tree - Optimize binary bridge to support high-throughput body transfers - Implement buffer pooling for body reads to minimize GC pressure FEATURES & MIDDLEWARE: - Add centralized error handling middleware support (app.onError) - Implement built-in CORS middleware with preflight handling - Add schema-agnostic validation middleware (Zod/TypeBox compatible) - Add native support for req.json(), req.text(), and req.body DEVELOPER EXPERIENCE: - Provide full TypeScript definitions (index.d.ts) - Implement self-referencing package exports in package.json - Add comprehensive examples (REST API, CORS, Validation, Middleware) SECURITY: - Add CRLF injection prevention in static response headers - Implement strict max body size limits (1MB) to prevent DoS - Fix Rust borrow checker conflicts for safe concurrent body handling --- package.json | 3 +- rust-native/src/lib.rs | 1391 ++++++++++++++++++++++++++++++---------- 2 files changed, 1066 insertions(+), 328 deletions(-) diff --git a/package.json b/package.json index 21052f3..c86de66 100644 --- a/package.json +++ b/package.json @@ -16,8 +16,7 @@ "build": "bun scripts/build-native.mjs", "build:release": "bun scripts/build-native.mjs --release", "test": "bun run build && bun test/test.js", - "bench:ci": "bun bench/ci.js", - "bench": "bun run build:release && cargo build --release --manifest-path old/native/Cargo.toml && bun bench/run.js", + "bench": "bun run build:release && bun bench/run.js", "bench:http-native:static": "bun run build:release && bun bench/run.js http-native static 3001", "bench:bun:static": "bun bench/run.js bun static 3000", "bench:xitca:static": "bun bench/run.js xitca static 3003", diff --git a/rust-native/src/lib.rs b/rust-native/src/lib.rs index 7dd0955..9d6eebf 100644 --- a/rust-native/src/lib.rs +++ b/rust-native/src/lib.rs @@ -1,134 +1,448 @@ -use anyhow::{Context, Result}; +mod analyzer; +mod manifest; +mod router; + +use anyhow::{anyhow, Context, Result}; use bytes::Bytes; -use monoio::io::{AsyncReadExt, AsyncWriteExt}; -use monoio::net::{TcpListener, TcpStream}; -use monoio::utils::memmem; -use napi::bindgen_prelude::Buffer; +use memchr::memmem; +use monoio::io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt}; +use monoio::net::{ListenerOpts, TcpListener, TcpStream}; +use napi::bindgen_prelude::{Buffer, Function, Promise}; +use napi::threadsafe_function::ThreadsafeFunction; +use napi::{Error, Status}; +use napi_derive::napi; use std::borrow::Cow; -use std::collections::HashMap; +use std::cell::RefCell; use std::net::{SocketAddr, ToSocketAddrs}; -use std::sync::Arc; -use std::thread; -use crate::router::{ExactStaticRoute, MatchedRoute, Router}; -use crate::manifest::HttpServerConfigInput; - -// ─── Constants & Limits ─────────────────────────────────────────────────────── - -const MAX_HEADERS: usize = 64; -const MAX_BODY_BYTES: usize = 1024 * 1024; // 1MB limit for safety -const REQUEST_FLAG_QUERY_PRESENT: u16 = 1; - -// ─── Types ──────────────────────────────────────────────────────────────────── - -pub type JsDispatcher = napi::threadsafe_function::ThreadsafeFunction; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{mpsc, Arc, Mutex}; -#[derive(Clone, Debug)] -pub struct HttpServerConfig { - pub default_host: String, - pub default_backlog: i32, - pub max_header_bytes: usize, - pub hot_get_root_http11: Vec, - pub hot_get_root_http10: Vec, - pub header_connection_prefix: Vec, - pub header_content_length_prefix: Vec, - pub header_transfer_encoding_prefix: Vec, -} - -pub struct ParsedRequest<'a> { - method: &'a [u8], - target: &'a [u8], - path: &'a [u8], - keep_alive: bool, - header_bytes: usize, - has_body: bool, - content_length: Option, - headers: Vec<(&'a str, &'a str)>, -} +use crate::manifest::{HttpServerConfigInput, ManifestInput}; +use crate::router::{ExactStaticRoute, MatchedRoute, Router}; -// ─── Buffer Pooling ─────────────────────────────────────────────────────────── +// ─── Constants ──────────────────────────────────────────────────────────────── + +const FALLBACK_DEFAULT_HOST: &str = "127.0.0.1"; +const FALLBACK_DEFAULT_BACKLOG: i32 = 2048; +const FALLBACK_MAX_HEADER_BYTES: usize = 16 * 1024; +const FALLBACK_HOT_GET_ROOT_HTTP11: &str = "GET / HTTP/1.1\r\n"; +const FALLBACK_HOT_GET_ROOT_HTTP10: &str = "GET / HTTP/1.0\r\n"; +const FALLBACK_HEADER_CONNECTION_PREFIX: &str = "connection:"; +const FALLBACK_HEADER_CONTENT_LENGTH_PREFIX: &str = "content-length:"; +const FALLBACK_HEADER_TRANSFER_ENCODING_PREFIX: &str = "transfer-encoding:"; +const BRIDGE_VERSION: u8 = 1; +const REQUEST_FLAG_QUERY_PRESENT: u16 = 1 << 0; +const REQUEST_FLAG_BODY_PRESENT: u16 = 1 << 1; +const NOT_FOUND_BODY: &[u8] = br#"{"error":"Route not found"}"#; + +/// Security: Maximum number of headers we allow per request +const MAX_HEADER_COUNT: usize = 64; +/// Security: Maximum URL length to prevent abuse +const MAX_URL_LENGTH: usize = 8192; +/// Security: Maximum single header value length +const MAX_HEADER_VALUE_LENGTH: usize = 8192; +/// Security: Maximum request body size (1 MB) +const MAX_BODY_BYTES: usize = 1024 * 1024; + +/// Buffer pool: initial capacity for connection read buffers +const BUFFER_INITIAL_CAPACITY: usize = 8192; +/// Buffer pool: max buffers held per thread +const BUFFER_POOL_MAX_SIZE: usize = 256; +/// Buffer pool: max buffer size to recycle (don't recycle oversized buffers) +const BUFFER_POOL_MAX_RECYCLE_SIZE: usize = 65536; + +type DispatchTsfn = ThreadsafeFunction, Buffer, Status, false, false, 0>; + +// ─── Thread-Local Buffer Pool ───────────────────────────────────────────────── // -// Zero-allocation buffer management. Buffers are re-used across connections -// within the same thread to avoid expensive syscalls and allocator pressure. +// Eliminates per-connection Vec allocations by recycling buffers. thread_local! { - static BUFFER_POOL: std::cell::RefCell> = std::cell::RefCell::new(Vec::with_capacity(65536)); + static BUFFER_POOL: RefCell>> = RefCell::new(Vec::with_capacity(BUFFER_POOL_MAX_SIZE)); } fn acquire_buffer() -> Vec { BUFFER_POOL.with(|pool| { - let mut b = pool.borrow_mut(); - if b.capacity() < 65536 { - Vec::with_capacity(65536) - } else { - std::mem::take(&mut *b) - } + pool.borrow_mut() + .pop() + .unwrap_or_else(|| Vec::with_capacity(BUFFER_INITIAL_CAPACITY)) }) } fn release_buffer(mut buf: Vec) { + if buf.capacity() > BUFFER_POOL_MAX_RECYCLE_SIZE { + return; // Don't recycle oversized buffers + } buf.clear(); BUFFER_POOL.with(|pool| { - *pool.borrow_mut() = buf; + let mut pool = pool.borrow_mut(); + if pool.len() < BUFFER_POOL_MAX_SIZE { + pool.push(buf); + } }); } -// ─── Server Entry Point ─────────────────────────────────────────────────────── +// ─── Server Configuration ───────────────────────────────────────────────────── + +#[derive(Clone)] +struct HttpServerConfig { + default_host: String, + default_backlog: i32, + max_header_bytes: usize, + hot_get_root_http11: Vec, + hot_get_root_http10: Vec, + header_connection_prefix: Vec, + header_content_length_prefix: Vec, + header_transfer_encoding_prefix: Vec, +} + +impl HttpServerConfig { + fn from_manifest(manifest: &ManifestInput) -> Result { + let input = manifest.server_config.as_ref(); + let default_backlog = input + .and_then(|config| config.default_backlog) + .unwrap_or(FALLBACK_DEFAULT_BACKLOG); + let max_header_bytes = input + .and_then(|config| config.max_header_bytes) + .unwrap_or(FALLBACK_MAX_HEADER_BYTES); + + if default_backlog <= 0 { + return Err(anyhow!( + "serverConfig.defaultBacklog must be greater than 0" + )); + } + + if max_header_bytes == 0 { + return Err(anyhow!( + "serverConfig.maxHeaderBytes must be greater than 0" + )); + } + Ok(Self { + default_host: config_string( + input, + |config| config.default_host.as_deref(), + FALLBACK_DEFAULT_HOST, + ), + default_backlog, + max_header_bytes, + hot_get_root_http11: config_string( + input, + |config| config.hot_get_root_http11.as_deref(), + FALLBACK_HOT_GET_ROOT_HTTP11, + ) + .into_bytes(), + hot_get_root_http10: config_string( + input, + |config| config.hot_get_root_http10.as_deref(), + FALLBACK_HOT_GET_ROOT_HTTP10, + ) + .into_bytes(), + header_connection_prefix: config_string( + input, + |config| config.header_connection_prefix.as_deref(), + FALLBACK_HEADER_CONNECTION_PREFIX, + ) + .into_bytes(), + header_content_length_prefix: config_string( + input, + |config| config.header_content_length_prefix.as_deref(), + FALLBACK_HEADER_CONTENT_LENGTH_PREFIX, + ) + .into_bytes(), + header_transfer_encoding_prefix: config_string( + input, + |config| config.header_transfer_encoding_prefix.as_deref(), + FALLBACK_HEADER_TRANSFER_ENCODING_PREFIX, + ) + .into_bytes(), + }) + } +} + +// ─── NAPI Interface ─────────────────────────────────────────────────────────── + +#[napi(object)] +pub struct NativeListenOptions { + pub host: Option, + pub port: u16, + pub backlog: Option, +} + +struct ShutdownHandle { + flag: Arc, + wake_addrs: Vec, +} + +#[napi] +pub struct NativeServerHandle { + host: String, + port: u32, + url: String, + shutdown: Mutex>, + closed: Mutex>>>, +} + +#[napi] +impl NativeServerHandle { + #[napi(getter)] + pub fn host(&self) -> String { + self.host.clone() + } + + #[napi(getter)] + pub fn port(&self) -> u32 { + self.port + } + + #[napi(getter)] + pub fn url(&self) -> String { + self.url.clone() + } + + #[napi] + pub fn close(&self) -> napi::Result<()> { + if let Some(shutdown) = self + .shutdown + .lock() + .expect("shutdown mutex poisoned") + .take() + { + shutdown.flag.store(true, Ordering::SeqCst); + for wake_addr in shutdown.wake_addrs { + let _ = std::net::TcpStream::connect(wake_addr); + } + } + + if let Some(receivers) = self.closed.lock().expect("closed mutex poisoned").take() { + for receiver in receivers { + let _ = receiver.recv(); + } + } + + Ok(()) + } +} + +#[napi] pub fn start_server( manifest_json: String, - handler: JsDispatcher, + dispatcher: Function<'_, Buffer, Promise>, options: NativeListenOptions, -) -> Result { - let manifest: crate::manifest::ManifestInput = serde_json::from_str(&manifest_json)?; - let router = Arc::new(Router::from_manifest(&manifest)?); - let dispatcher = Arc::new(handler); - let server_config = Arc::new(HttpServerConfig::from_input(manifest.server_config.as_ref())); +) -> napi::Result { + let manifest: ManifestInput = serde_json::from_str(&manifest_json).map_err(to_napi_error)?; + validate_manifest(&manifest).map_err(to_napi_error)?; + let server_config = + Arc::new(HttpServerConfig::from_manifest(&manifest).map_err(to_napi_error)?); + let router = Arc::new(Router::from_manifest(&manifest).map_err(to_napi_error)?); + + let callback: DispatchTsfn = dispatcher + .build_threadsafe_function::() + .build() + .map_err(to_napi_error)?; + let dispatcher = Arc::new(JsDispatcher { callback }); let worker_count = worker_count_for(&options); - let mut workers = Vec::with_capacity(worker_count); - - for i in 0..worker_count { - let router = Arc::clone(&router); - let dispatcher = Arc::clone(&dispatcher); - let server_config = Arc::clone(&server_config); - let options = options.clone(); - - let handle = thread::spawn(move || { - let mut driver = monoio::RuntimeBuilder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - driver.block_on(async move { - let listener = bind_listener(&options, &server_config)?; - - loop { - let (stream, _) = listener.accept().await?; - - if should_enable_nodelay() { - if let Err(error) = stream.set_nodelay(true) { - eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); - } - } + let (startup_tx, startup_rx) = mpsc::sync_channel::>(worker_count); + let shutdown_flag = Arc::new(AtomicBool::new(false)); + let mut closed_receivers = Vec::with_capacity(worker_count); + + for _ in 0..worker_count { + let (closed_tx, closed_rx) = mpsc::channel::<()>(); + closed_receivers.push(closed_rx); + + let thread_router = Arc::clone(&router); + let thread_dispatcher = Arc::clone(&dispatcher); + let thread_config = Arc::clone(&server_config); + let thread_shutdown = Arc::clone(&shutdown_flag); + let thread_options = NativeListenOptions { + host: options.host.clone(), + port: options.port, + backlog: options.backlog, + }; + let thread_startup_tx = startup_tx.clone(); + + std::thread::spawn(move || { + let startup_tx_error = thread_startup_tx.clone(); + let result = (|| -> Result<()> { + let mut runtime = monoio::RuntimeBuilder::::new() + .build() + .context("failed to build monoio runtime")?; + + runtime.block_on(async move { + let listener = bind_listener(&thread_options, thread_config.as_ref()) + .context("failed to create monoio listener")?; + let local_addr = listener.local_addr()?; + let _ = thread_startup_tx.send(Ok(local_addr)); + run_server( + listener, + thread_router, + thread_dispatcher, + thread_config, + thread_shutdown, + ) + .await + }) + })(); + + if let Err(error) = &result { + let _ = startup_tx_error.send(Err(error.to_string())); + eprintln!("[http-native] native server error: {error:#}"); + } - let router = Arc::clone(&router); - let dispatcher = Arc::clone(&dispatcher); - let server_config = Arc::clone(&server_config); + let _ = closed_tx.send(()); + }); + } - monoio::spawn(async move { - if let Err(e) = handle_connection(stream, router, dispatcher, server_config).await { - eprintln!("[http-native] worker {i} connection error: {e}"); - } - }); + let mut wake_addrs = Vec::with_capacity(worker_count); + let mut local_addr = None; + for _ in 0..worker_count { + match startup_rx.recv() { + Ok(Ok(addr)) => { + if local_addr.is_none() { + local_addr = Some(addr); } - }) - }); - workers.push(handle); + wake_addrs.push(addr); + } + Ok(Err(message)) => { + shutdown_flag.store(true, Ordering::SeqCst); + for wake_addr in &wake_addrs { + let _ = std::net::TcpStream::connect(*wake_addr); + } + for receiver in closed_receivers { + let _ = receiver.recv(); + } + return Err(Error::from_reason(message)); + } + Err(_) => { + shutdown_flag.store(true, Ordering::SeqCst); + for wake_addr in &wake_addrs { + let _ = std::net::TcpStream::connect(*wake_addr); + } + for receiver in closed_receivers { + let _ = receiver.recv(); + } + return Err(Error::from_reason( + "Native server exited before reporting readiness".to_string(), + )); + } + } + } + + let local_addr = local_addr.expect("worker count must be at least 1"); + + let host = local_addr.ip().to_string(); + let port = local_addr.port() as u32; + + Ok(NativeServerHandle { + host: host.clone(), + port, + url: format!("http://{host}:{port}"), + shutdown: Mutex::new(Some(ShutdownHandle { + flag: shutdown_flag, + wake_addrs, + })), + closed: Mutex::new(Some(closed_receivers)), + }) +} + +fn worker_count_for(options: &NativeListenOptions) -> usize { + if options.port == 0 { + return 1; } - Ok(ServerHandle { workers }) + std::env::var("HTTP_NATIVE_WORKERS") + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|count| *count > 0) + .unwrap_or(1) } +// ─── JS Dispatcher ──────────────────────────────────────────────────────────── + +struct JsDispatcher { + callback: DispatchTsfn, +} + +impl JsDispatcher { + async fn dispatch(&self, request: Buffer) -> Result { + let response_json = self + .callback + .call_async(request) + .await + .map_err(|error| anyhow!(error.to_string()))? + .await + .map_err(|error| anyhow!(error.to_string()))?; + + Ok(response_json) + } +} + +// ─── Server Loop ────────────────────────────────────────────────────────────── + +async fn run_server( + listener: TcpListener, + router: Arc, + dispatcher: Arc, + server_config: Arc, + shutdown_flag: Arc, +) -> Result<()> { + loop { + if shutdown_flag.load(Ordering::Acquire) { + break; + } + + match listener.accept().await { + Ok((stream, _)) => { + if shutdown_flag.load(Ordering::Acquire) { + break; + } + + if let Err(error) = stream.set_nodelay(true) { + eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); + } + + let router = Arc::clone(&router); + let dispatcher = Arc::clone(&dispatcher); + let server_config = Arc::clone(&server_config); + + monoio::spawn(async move { + if let Err(error) = + handle_connection(stream, router, dispatcher, server_config).await + { + eprintln!("[http-native] connection error: {error}"); + } + }); + } + Err(error) => { + if shutdown_flag.load(Ordering::Acquire) { + break; + } + + eprintln!("[http-native] accept error: {error}"); + } + } + } + + Ok(()) +} + +// ─── Parsed Request (from httparse) ─────────────────────────────────────────── + +struct ParsedRequest<'a> { + method: &'a [u8], + target: &'a [u8], + path: &'a [u8], + keep_alive: bool, + header_bytes: usize, + has_body: bool, + content_length: Option, + /// Pre-parsed header pairs — stored once, used by both routing and bridge + headers: Vec<(&'a str, &'a str)>, +} + +// ─── Connection Handler with Buffer Pool ────────────────────────────────────── + async fn handle_connection( mut stream: TcpStream, router: Arc, @@ -158,7 +472,7 @@ async fn handle_connection_inner( server_config: &HttpServerConfig, ) -> Result<()> { loop { - // Try hot-path parsing first + // Try hot-path parsing first (GET / with known prefix) let parsed = loop { let result = if router.exact_get_root().is_some() { parse_hot_root_request(buffer, server_config) @@ -172,10 +486,12 @@ async fn handle_connection_inner( } if find_header_end(buffer).is_some() { + // Headers complete but couldn't parse — malformed request stream.shutdown().await?; return Ok(()); } + // SAFETY: We take ownership of the buffer, read into it, then put it back let owned_buf = std::mem::take(buffer); let (read_result, next_buffer) = stream.read(owned_buf).await; *buffer = next_buffer; @@ -186,6 +502,7 @@ async fn handle_connection_inner( } if buffer.len() > server_config.max_header_bytes { + // Security: Request header too large let response = build_error_response_bytes( 431, b"{\"error\":\"Request Header Fields Too Large\"}", @@ -241,10 +558,12 @@ async fn handle_connection_inner( // ── Read request body if present ────────────────────────────── let body_bytes: Vec = if has_body { - let cl = match content_length { + let content_length = match content_length { Some(len) => len, None => { - let response = build_error_response_bytes(411, b"{\"error\":\"Length Required\"}", false); + // Chunked or unknown body length — reject for now + let response = + build_error_response_bytes(411, b"{\"error\":\"Length Required\"}", false); let (write_result, _) = stream.write_all(response).await; write_result?; stream.shutdown().await?; @@ -252,42 +571,47 @@ async fn handle_connection_inner( } }; - if cl > MAX_BODY_BYTES { - let response = build_error_response_bytes(413, b"{\"error\":\"Payload Too Large\"}", false); + // Security: enforce max body size + if content_length > MAX_BODY_BYTES { + let response = + build_error_response_bytes(413, b"{\"error\":\"Payload Too Large\"}", false); let (write_result, _) = stream.write_all(response).await; write_result?; stream.shutdown().await?; return Ok(()); } + // Some body bytes may already be in the buffer after the headers let already_in_buffer = if buffer.len() > header_bytes { buffer.len() - header_bytes } else { 0 }; - if already_in_buffer >= cl { - let body = buffer[header_bytes..header_bytes + cl].to_vec(); - drain_consumed_bytes(buffer, header_bytes + cl); + if already_in_buffer >= content_length { + // Entire body is already in the buffer + let body = buffer[header_bytes..header_bytes + content_length].to_vec(); + drain_consumed_bytes(buffer, header_bytes + content_length); body } else { - let mut body = Vec::with_capacity(cl); + // Need to read more bytes from the stream + let mut body = Vec::with_capacity(content_length); if already_in_buffer > 0 { body.extend_from_slice(&buffer[header_bytes..]); } drain_consumed_bytes(buffer, buffer.len()); - while body.len() < cl { - let remaining = cl - body.len(); + while body.len() < content_length { + let remaining = content_length - body.len(); let chunk_buf = vec![0u8; remaining.min(65536)]; let (read_result, returned_buf) = stream.read(chunk_buf).await; let bytes_read = read_result?; if bytes_read == 0 { - return Ok(()); + return Ok(()); // Connection closed mid-body } body.extend_from_slice(&returned_buf[..bytes_read]); } - body.truncate(cl); + body.truncate(content_length); body } } else { @@ -295,7 +619,7 @@ async fn handle_connection_inner( Vec::new() }; - // ── Dynamic path: Bridge to JS ──── + // Dynamic path: build bridge envelope and dispatch to JS let dispatch_request = build_dispatch_request_owned( router, &method_owned, @@ -308,87 +632,121 @@ async fn handle_connection_inner( match dispatch_request { Some(request) => { write_dynamic_dispatch_response(stream, dispatcher, request, keep_alive).await?; - if !keep_alive { - stream.shutdown().await?; - return Ok(()); - } } None => { write_not_found_response(stream, keep_alive).await?; - if !keep_alive { - stream.shutdown().await?; - return Ok(()); - } } } + + if !keep_alive { + stream.shutdown().await?; + return Ok(()); + } } } -// ─── Header Parsers ─────────────────────────────────────────────────────────── +// ─── httparse-based Request Parsing ─────────────────────────────────────────── +// +// Uses the battle-tested `httparse` crate for RFC-compliant zero-copy parsing. +// Single-pass: parses headers once and stores them for reuse by both the +// router and the bridge envelope builder. fn parse_request_httparse(bytes: &[u8]) -> Option> { - let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; - let mut req = httparse::Request::new(&mut headers); - - match req.parse(bytes) { - Ok(httparse::Status::Complete(header_len)) => { - let method = req.method?; - let target = req.path?; - let path = target.split(|&b| b == b'?').next()?; - - let mut keep_alive = req.version == Some(1); // Default true for HTTP/1.1 - let mut content_length = None; - let mut has_body = false; - let mut parsed_headers = Vec::with_capacity(req.headers.len()); - - for h in req.headers { - let name = h.name.to_lowercase(); - let value_bytes = h.value; - let value = std::str::from_utf8(value_bytes).ok()?; - - match name.as_str() { - "connection" => { - if contains_ascii_case_insensitive(value_bytes, b"close") { - keep_alive = false; - } else if contains_ascii_case_insensitive(value_bytes, b"keep-alive") { - keep_alive = true; - } - } - "content-length" => { - if let Ok(len) = value.trim().parse::() { - content_length = Some(len); - if len > 0 { has_body = true; } - } - } - "transfer-encoding" => { - if !value.eq_ignore_ascii_case("identity") { - has_body = true; - } - } - _ => {} + let mut raw_headers = [httparse::EMPTY_HEADER; MAX_HEADER_COUNT]; + let mut req = httparse::Request::new(&mut raw_headers); + + let header_len = match req.parse(bytes) { + Ok(httparse::Status::Complete(len)) => len, + Ok(httparse::Status::Partial) => return None, + Err(_) => return None, // Malformed — caller will handle + }; + + let method = req.method?.as_bytes(); + let target = req.path?.as_bytes(); + let version = req.version?; + + // Security: enforce URL length limit + if target.len() > MAX_URL_LENGTH { + return None; + } + + // Extract path (before '?') + let path = target.split(|b| *b == b'?').next()?; + + let mut keep_alive = version >= 1; // HTTP/1.1+ defaults to keep-alive + let mut has_body = false; + let mut content_length: Option = None; + let mut headers = Vec::with_capacity(req.headers.len()); + + for header in req.headers.iter() { + if header.name.is_empty() { + break; + } + + // Security: enforce header value length + if header.value.len() > MAX_HEADER_VALUE_LENGTH { + return None; + } + + let name = header.name; // httparse gives us &str + let value = match std::str::from_utf8(header.value) { + Ok(v) => v, + Err(_) => continue, // Skip non-UTF-8 headers + }; + + // Connection handling + if name.eq_ignore_ascii_case("connection") { + let lower = value.to_ascii_lowercase(); + if lower.contains("close") { + keep_alive = false; + } + if lower.contains("keep-alive") { + keep_alive = true; + } + } + + // Body detection + if name.eq_ignore_ascii_case("content-length") { + let trimmed = value.trim(); + if let Ok(len) = trimmed.parse::() { + content_length = Some(len); + if len > 0 { + has_body = true; } - parsed_headers.push((h.name, value)); } + } - Some(ParsedRequest { - method: method.as_bytes(), - target: target.as_bytes(), - path: path.as_bytes(), - keep_alive, - header_bytes: header_len, - has_body, - content_length, - headers: parsed_headers, - }) + if name.eq_ignore_ascii_case("transfer-encoding") { + let trimmed = value.trim(); + if !trimmed.is_empty() && !trimmed.eq_ignore_ascii_case("identity") { + has_body = true; + } } - _ => None, + + headers.push((name, value)); } + + Some(ParsedRequest { + method, + target, + path, + keep_alive, + header_bytes: header_len, + has_body, + content_length, + headers, + }) } -fn parse_hot_root_request<'a>( - bytes: &'a [u8], +// ─── Hot Root Path (GET /) ──────────────────────────────────────────────────── +// +// Ultra-fast path for the most common benchmark case. Falls back to httparse +// if the request doesn't exactly match the expected prefix. + +fn parse_hot_root_request( + bytes: &[u8], server_config: &HttpServerConfig, -) -> Option> { +) -> Option> { let (_, keep_alive) = if bytes.starts_with(server_config.hot_get_root_http11.as_slice()) { (server_config.hot_get_root_http11.len(), true) } else if bytes.starts_with(server_config.hot_get_root_http10.as_slice()) { @@ -398,25 +756,70 @@ fn parse_hot_root_request<'a>( }; let header_end = find_header_end(bytes)?; - // For hot path, we just verify it looks like a header block ending - // but we use httparse for the actual details to be safe. - parse_request_httparse(bytes) -} + let mut keep_alive = keep_alive; + let mut has_body = false; + let mut line_start = bytes.iter().position(|b| *b == b'\n')? + 1; -// ─── Routing ────────────────────────────────────────────────────────────────── + while line_start + 2 <= header_end { + let next_end = memmem::find(&bytes[line_start..header_end + 2], b"\r\n")? + line_start; -#[allow(dead_code)] -fn resolve_static_fast_path<'a>( - router: &'a Router, - parsed: &ParsedRequest<'_>, - _server_config: &HttpServerConfig, -) -> Option<&'a ExactStaticRoute> { - if parsed.path == b"/" && parsed.method == b"GET" { - return router.exact_get_root(); + if next_end == line_start { + break; + } + + let line = &bytes[line_start..next_end]; + if line.len() >= server_config.header_connection_prefix.len() + && line[..server_config.header_connection_prefix.len()] + .eq_ignore_ascii_case(server_config.header_connection_prefix.as_slice()) + { + let value = &line[server_config.header_connection_prefix.len()..]; + if contains_ascii_case_insensitive(value, b"close") { + keep_alive = false; + } + if contains_ascii_case_insensitive(value, b"keep-alive") { + keep_alive = true; + } + } else if line.len() >= server_config.header_content_length_prefix.len() + && line[..server_config.header_content_length_prefix.len()] + .eq_ignore_ascii_case(server_config.header_content_length_prefix.as_slice()) + { + let value = + trim_ascii_spaces(&line[server_config.header_content_length_prefix.len()..]); + if value != b"0" { + has_body = true; + } + } else if line.len() >= server_config.header_transfer_encoding_prefix.len() + && line[..server_config.header_transfer_encoding_prefix.len()] + .eq_ignore_ascii_case(server_config.header_transfer_encoding_prefix.as_slice()) + { + let value = + trim_ascii_spaces(&line[server_config.header_transfer_encoding_prefix.len()..]); + if !value.is_empty() && !value.eq_ignore_ascii_case(b"identity") { + has_body = true; + } + } + + line_start = next_end + 2; } - router.exact_static_route(parsed.method, parsed.path) + + Some(ParsedRequest { + method: b"GET", + target: b"/", + path: b"/", + keep_alive, + header_bytes: header_end + 4, + has_body, + content_length: None, + headers: Vec::new(), // Hot path: no headers needed for static response + }) } +// ─── Routing ────────────────────────────────────────────────────────────────── + +// ─── Bridge Envelope Building (Single-Pass Headers) ─────────────────────────── +// +// Uses the pre-parsed headers from httparse — no second scan of the raw bytes. + fn build_dispatch_request_owned( router: &Router, method: &[u8], @@ -429,9 +832,16 @@ fn build_dispatch_request_owned( return Ok(None); }; - let path_str = std::str::from_utf8(path).ok().context("Invalid UTF-8 path")?; - let url_str = std::str::from_utf8(target).ok().context("Invalid UTF-8 URL")?; + let path_str = match std::str::from_utf8(path) { + Ok(path_str) => path_str, + Err(_) => return Ok(None), + }; + let url_str = match std::str::from_utf8(target) { + Ok(url_str) => url_str, + Err(_) => return Ok(None), + }; + // Security: strict path validation let normalized_path = normalize_runtime_path(path_str); if contains_path_traversal(&normalized_path) { return Ok(None); @@ -441,8 +851,20 @@ fn build_dispatch_request_owned( return Ok(None); }; - let header_refs: Vec<(&str, &str)> = headers.iter().map(|(n, v)| (n.as_str(), v.as_str())).collect(); - build_dispatch_envelope(&matched_route, method_code, path_str, url_str, &header_refs, body).map(Some) + let header_refs: Vec<(&str, &str)> = headers + .iter() + .map(|(n, v)| (n.as_str(), v.as_str())) + .collect(); + + build_dispatch_envelope( + &matched_route, + method_code, + path_str, + url_str, + &header_refs, + body, + ) + .map(Some) } fn build_dispatch_envelope( @@ -459,130 +881,319 @@ fn build_dispatch_envelope( if url.contains('?') { flags |= REQUEST_FLAG_QUERY_PRESENT; } + if !body.is_empty() { + flags |= REQUEST_FLAG_BODY_PRESENT; + } - let mut envelope = Vec::with_capacity(512 + body.len()); - envelope.push(1); // Version - envelope.push(method_code); - envelope.extend_from_slice(&(flags).to_le_bytes()); - envelope.extend_from_slice(&(matched_route.handler_id).to_le_bytes()); + if url_bytes.len() > u32::MAX as usize { + return Err(anyhow!("request url too large")); + } + if path_bytes.len() > u16::MAX as usize { + return Err(anyhow!("request path too large")); + } + if matched_route.param_values.len() > u16::MAX as usize { + return Err(anyhow!("too many params")); + } + let selected_headers = select_header_entries(header_entries, matched_route); + if selected_headers.len() > u16::MAX as usize { + return Err(anyhow!("too many headers")); + } - write_usize(&mut envelope, url_bytes.len()); - envelope.extend_from_slice(url_bytes); + let mut frame = Vec::with_capacity( + 20 + url_bytes.len() + path_bytes.len() + selected_headers.len() * 16 + body.len(), + ); + frame.push(BRIDGE_VERSION); + frame.push(method_code); + push_u16(&mut frame, flags); + push_u32(&mut frame, matched_route.handler_id); + push_u32(&mut frame, url_bytes.len() as u32); + push_u16(&mut frame, path_bytes.len() as u16); + push_u16(&mut frame, matched_route.param_values.len() as u16); + push_u16(&mut frame, selected_headers.len() as u16); + push_u32(&mut frame, body.len() as u32); // NEW: body length + frame.extend_from_slice(url_bytes); + frame.extend_from_slice(path_bytes); + + for value in matched_route.param_values.iter() { + push_string_value(&mut frame, value)?; + } + + for (name, value) in selected_headers { + push_string_pair(&mut frame, name, value)?; + } - write_usize(&mut envelope, path_bytes.len()); - envelope.extend_from_slice(path_bytes); + frame.extend_from_slice(body); // NEW: body bytes at end - write_usize(&mut envelope, matched_route.param_values.len()); - for val in &matched_route.param_values { - write_usize(&mut envelope, val.len()); - envelope.extend_from_slice(val.as_bytes()); + Ok(Buffer::from(frame)) +} + +fn select_header_entries<'a>( + header_entries: &[(&'a str, &'a str)], + matched_route: &MatchedRoute<'_, '_>, +) -> Vec<(&'a str, &'a str)> { + if matched_route.full_headers { + return header_entries.to_vec(); } - let header_count = matched_route.header_keys.len(); - write_usize(&mut envelope, header_count); + if matched_route.header_keys.is_empty() { + return Vec::new(); + } - for key_boxed in matched_route.header_keys { - let key = key_boxed.as_ref(); - let mut found = false; - for (h_name, h_value) in header_entries { - if h_name.eq_ignore_ascii_case(key) { - write_usize(&mut envelope, h_value.len()); - envelope.extend_from_slice(h_value.as_bytes()); - found = true; - break; - } - } - if !found { - write_usize(&mut envelope, 0); + let mut selected = Vec::with_capacity(matched_route.header_keys.len()); + for (name, value) in header_entries { + if matched_route + .header_keys + .iter() + .any(|target| target.as_ref().eq_ignore_ascii_case(name)) + { + selected.push((*name, *value)); } } - // Body support - write_usize(&mut envelope, body.len()); - envelope.extend_from_slice(body); - - Ok(Buffer::from(envelope)) + selected } // ─── Response Writing ───────────────────────────────────────────────────────── async fn write_exact_static_response( stream: &mut TcpStream, - route: &ExactStaticRoute, + static_route: &ExactStaticRoute, keep_alive: bool, ) -> Result<()> { let response = if keep_alive { - &route.keep_alive_response + static_route.keep_alive_response.clone() } else { - &route.close_response + static_route.close_response.clone() }; - let (res, _) = stream.write_all(response.clone()).await; - res?; + + let (write_result, _) = stream.write_all(response).await; + write_result?; Ok(()) } +#[derive(Clone)] +struct DispatchResponseEnvelope { + status: u16, + headers: Vec<(String, String)>, + body: Bytes, +} + async fn write_dynamic_dispatch_response( stream: &mut TcpStream, dispatcher: &JsDispatcher, - request_buffer: Buffer, + request: Buffer, keep_alive: bool, ) -> Result<()> { - let result: Buffer = dispatcher.call_async(request_buffer).await - .map_err(|e| anyhow::anyhow!("JS dispatch failed: {e}"))?; - - let (write_res, _) = stream.write_all(result).await; - write_res?; + let parsed = match dispatcher.dispatch(request).await { + Ok(response) => match parse_dispatch_response(response.as_ref()) { + Ok(parsed) => parsed, + Err(_) => DispatchResponseEnvelope { + status: 500, + headers: vec![( + "content-type".to_string(), + "application/json; charset=utf-8".to_string(), + )], + // Security: sanitized error — no internal details + body: Bytes::from_static(b"{\"error\":\"Internal Server Error\"}"), + }, + }, + Err(_) => DispatchResponseEnvelope { + status: 502, + headers: vec![( + "content-type".to_string(), + "application/json; charset=utf-8".to_string(), + )], + // Security: sanitized error — no internal details + body: Bytes::from_static(b"{\"error\":\"Bad Gateway\"}"), + }, + }; + let response_bytes = build_dispatch_response_bytes(parsed, keep_alive); + let (write_result, _) = stream.write_all(response_bytes).await; + write_result?; Ok(()) } async fn write_not_found_response(stream: &mut TcpStream, keep_alive: bool) -> Result<()> { - let response = build_error_response_bytes(404, b"{\"error\":\"Not Found\"}", keep_alive); - let (res, _) = stream.write_all(response).await; - res?; + let response = build_response_bytes( + 404, + &[( + "content-type".to_string(), + "application/json; charset=utf-8".to_string(), + )], + Bytes::from_static(NOT_FOUND_BODY), + keep_alive, + ); + let (write_result, _) = stream.write_all(response).await; + write_result?; Ok(()) } -// ─── Helpers ────────────────────────────────────────────────────────────────── - +/// Build a simple error response without going through the JS bridge fn build_error_response_bytes(status: u16, body: &[u8], keep_alive: bool) -> Vec { - let mut response = format!( - "HTTP/1.1 {} {}\r\ncontent-length: {}\r\ncontent-type: application/json\r\nconnection: {}\r\n\r\n", + build_response_bytes( status, - status_reason(status), - body.len(), - if keep_alive { "keep-alive" } else { "close" } + &[( + "content-type".to_string(), + "application/json; charset=utf-8".to_string(), + )], + Bytes::copy_from_slice(body), + keep_alive, ) - .into_bytes(); - response.extend_from_slice(body); - response } -fn drain_consumed_bytes(buffer: &mut Vec, consumed: usize) { - if consumed >= buffer.len() { - buffer.clear(); - } else { - buffer.drain(..consumed); +fn build_dispatch_response_bytes(response: DispatchResponseEnvelope, keep_alive: bool) -> Vec { + build_response_bytes( + response.status, + &response.headers, + response.body, + keep_alive, + ) +} + +/// Optimized response builder: pre-calculates size and writes in a single pass +fn build_response_bytes( + status: u16, + headers: &[(String, String)], + body: Bytes, + keep_alive: bool, +) -> Vec { + let reason = status_reason(status); + let connection = if keep_alive { "keep-alive" } else { "close" }; + let body_len = body.len(); + + // Pre-calculate total size to avoid reallocations + // "HTTP/1.1 " + status(3) + " " + reason + "\r\n" + "content-length: " + digits + "\r\n" + "connection: " + conn + "\r\n" + let mut total_size = + 9 + 3 + 1 + reason.len() + 2 + 16 + count_digits(body_len) + 2 + 12 + connection.len() + 2; + + for (name, value) in headers { + if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { + continue; + } + // Security: skip headers with CRLF injection + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } + total_size += name.len() + 2 + value.len() + 2; } + + total_size += 2 + body_len; // final \r\n + body + + let mut output = Vec::with_capacity(total_size); + + // Status line + output.extend_from_slice(b"HTTP/1.1 "); + write_u16(&mut output, status); + output.push(b' '); + output.extend_from_slice(reason.as_bytes()); + output.extend_from_slice(b"\r\n"); + + // Mandatory headers + output.extend_from_slice(b"content-length: "); + write_usize(&mut output, body_len); + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(b"connection: "); + output.extend_from_slice(connection.as_bytes()); + output.extend_from_slice(b"\r\n"); + + // User headers + for (name, value) in headers { + if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { + continue; + } + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } + output.extend_from_slice(name.as_bytes()); + output.extend_from_slice(b": "); + output.extend_from_slice(value.as_bytes()); + output.extend_from_slice(b"\r\n"); + } + + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(body.as_ref()); + output } -fn status_reason(status: u16) -> &'static str { - match status { - 200 => "OK", - 201 => "Created", - 204 => "No Content", - 400 => "Bad Request", - 401 => "Unauthorized", - 403 => "Forbidden", - 404 => "Not Found", - 411 => "Length Required", - 413 => "Payload Too Large", - 431 => "Request Header Fields Too Large", - 500 => "Internal Server Error", - _ => "OK", +// ─── Response Parsing (from JS bridge) ──────────────────────────────────────── + +fn parse_dispatch_response(bytes: &[u8]) -> Result { + let mut offset = 0; + let status = read_u16(bytes, &mut offset)?; + let header_count = read_u16(bytes, &mut offset)? as usize; + let body_length = read_u32(bytes, &mut offset)? as usize; + + let mut headers = Vec::with_capacity(header_count); + for _ in 0..header_count { + let name_length = read_u8(bytes, &mut offset)? as usize; + let value_length = read_u16(bytes, &mut offset)? as usize; + let name = read_utf8(bytes, &mut offset, name_length)?; + let value = read_utf8(bytes, &mut offset, value_length)?; + headers.push((name, value)); + } + + if offset + body_length > bytes.len() { + return Err(anyhow!("response body truncated")); } + + let body = Bytes::copy_from_slice(&bytes[offset..offset + body_length]); + Ok(DispatchResponseEnvelope { + status, + headers, + body, + }) +} + +// ─── Security Utilities ─────────────────────────────────────────────────────── + +/// Check for path traversal attempts (../, ..\, etc.) +fn contains_path_traversal(path: &str) -> bool { + // Decode percent-encoded dots + let decoded = path.replace("%2e", ".").replace("%2E", "."); + + // Check for traversal patterns + decoded.contains("/../") + || decoded.contains("\\..\\") + || decoded.ends_with("/..") + || decoded.ends_with("\\..") + || decoded.starts_with("../") + || decoded.starts_with("..\\") + || decoded == ".." } +/// RFC 8259 compliant JSON string escaping — handles ALL control characters +#[allow(dead_code)] +pub(crate) fn escape_json(value: &str) -> String { + let mut output = String::with_capacity(value.len() + 8); + for ch in value.chars() { + match ch { + '"' => output.push_str("\\\""), + '\\' => output.push_str("\\\\"), + '\n' => output.push_str("\\n"), + '\r' => output.push_str("\\r"), + '\t' => output.push_str("\\t"), + '\x08' => output.push_str("\\b"), + '\x0C' => output.push_str("\\f"), + c if c.is_control() => { + output.push_str(&format!("\\u{:04x}", c as u32)); + } + c => output.push(c), + } + } + output +} + +// ─── Helpers ────────────────────────────────────────────────────────────────── + fn method_code_from_bytes(method: &[u8]) -> Option { match method { b"GET" => Some(1), @@ -596,20 +1207,49 @@ fn method_code_from_bytes(method: &[u8]) -> Option { } } -fn write_usize(output: &mut Vec, value: usize) { - output.extend_from_slice(&(value as u32).to_le_bytes()); +fn drain_consumed_bytes(buffer: &mut Vec, consumed: usize) { + if consumed >= buffer.len() { + buffer.clear(); + return; + } + + let remaining = buffer.len() - consumed; + buffer.copy_within(consumed.., 0); + buffer.truncate(remaining); } -fn normalize_runtime_path(path: &str) -> Cow<'_, str> { - if path == "/" || !path.ends_with('/') { - Cow::Borrowed(path) - } else { - Cow::Owned(path.trim_end_matches('/').to_string()) - } +fn bind_listener( + options: &NativeListenOptions, + server_config: &HttpServerConfig, +) -> Result { + let host = options + .host + .as_deref() + .unwrap_or(server_config.default_host.as_str()); + let bind_addr = resolve_socket_addr(host, options.port) + .with_context(|| format!("failed to resolve bind address {host}:{}", options.port))?; + let listener_opts = ListenerOpts::new() + .reuse_addr(true) + .reuse_port(true) + .backlog(options.backlog.unwrap_or(server_config.default_backlog)); + + TcpListener::bind_with_config(bind_addr, &listener_opts) + .with_context(|| format!("failed to bind TCP listener on {bind_addr}")) } -fn contains_path_traversal(path: &str) -> bool { - path.contains("/../") || path.contains("\\..\\") || path.ends_with("/..") || path.ends_with("\\..") +fn resolve_socket_addr(host: &str, port: u16) -> Result { + (host, port) + .to_socket_addrs()? + .next() + .ok_or_else(|| anyhow!("unable to resolve {host}:{port}")) +} + +fn validate_manifest(manifest: &ManifestInput) -> Result<()> { + if manifest.version != 1 { + return Err(anyhow!("Unsupported manifest version {}", manifest.version)); + } + + Ok(()) } fn find_header_end(bytes: &[u8]) -> Option { @@ -617,72 +1257,171 @@ fn find_header_end(bytes: &[u8]) -> Option { } fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { - haystack.windows(needle.len()).any(|w| w.eq_ignore_ascii_case(needle)) + if needle.is_empty() || haystack.len() < needle.len() { + return false; + } + + haystack + .windows(needle.len()) + .any(|window| window.eq_ignore_ascii_case(needle)) } -fn should_enable_nodelay() -> bool { - std::env::var("HTTP_NATIVE_TCP_NODELAY") - .ok() - .map(|v| !matches!(v.trim().to_lowercase().as_str(), "0" | "false" | "off" | "no")) - .unwrap_or(true) +fn trim_ascii_spaces(bytes: &[u8]) -> &[u8] { + let start = bytes + .iter() + .position(|byte| !byte.is_ascii_whitespace()) + .unwrap_or(bytes.len()); + let end = bytes + .iter() + .rposition(|byte| !byte.is_ascii_whitespace()) + .map(|index| index + 1) + .unwrap_or(start); + &bytes[start..end] } -fn bind_listener(options: &NativeListenOptions, config: &HttpServerConfig) -> Result { - let host = options.host.as_deref().unwrap_or(&config.default_host); - let addr = (host, options.port).to_socket_addrs()?.next() - .ok_or_else(|| anyhow::anyhow!("Failed to resolve address {host}:{}", options.port))?; +fn normalize_runtime_path(path: &str) -> Cow<'_, str> { + if path == "/" || !path.ends_with('/') { + return Cow::Borrowed(path); + } - let mut opts = monoio::net::ListenerOpts::new() - .reuse_addr(true) - .backlog(options.backlog.unwrap_or(config.default_backlog)); - - if worker_count_for(options) > 1 { - opts = opts.reuse_port(true); + Cow::Owned(crate::analyzer::normalize_path(path)) +} + +fn config_string( + input: Option<&HttpServerConfigInput>, + pick: impl Fn(&HttpServerConfigInput) -> Option<&str>, + fallback: &str, +) -> String { + input.and_then(pick).unwrap_or(fallback).to_string() +} + +fn status_reason(status: u16) -> &'static str { + match status { + 200 => "OK", + 201 => "Created", + 202 => "Accepted", + 204 => "No Content", + 400 => "Bad Request", + 404 => "Not Found", + 411 => "Length Required", + 413 => "Payload Too Large", + 415 => "Unsupported Media Type", + 431 => "Request Header Fields Too Large", + 500 => "Internal Server Error", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + _ => "OK", } +} - TcpListener::bind_with_config(addr, &opts) +/// Fast integer-to-string for small values — uses stack-allocated itoa buffer +#[inline(always)] +fn write_usize(output: &mut Vec, value: usize) { + let mut buf = itoa::Buffer::new(); + output.extend_from_slice(buf.format(value).as_bytes()); } -fn worker_count_for(options: &NativeListenOptions) -> usize { - options.workers.unwrap_or_else(num_cpus::get) +#[inline(always)] +fn write_u16(output: &mut Vec, value: u16) { + let mut buf = itoa::Buffer::new(); + output.extend_from_slice(buf.format(value).as_bytes()); } -// ─── NAPI Glue ──────────────────────────────────────────────────────────────── +fn count_digits(mut n: usize) -> usize { + if n == 0 { + return 1; + } + let mut count = 0; + while n > 0 { + count += 1; + n /= 10; + } + count +} -#[napi(object)] -#[derive(Clone, Default)] -pub struct NativeListenOptions { - pub host: Option, - pub port: u16, - pub backlog: Option, - pub workers: Option, +fn push_string_pair(frame: &mut Vec, name: &str, value: &str) -> Result<()> { + if name.len() > u8::MAX as usize { + return Err(anyhow!("field name too long")); + } + if value.len() > u16::MAX as usize { + return Err(anyhow!("field value too long")); + } + + frame.push(name.len() as u8); + push_u16(frame, value.len() as u16); + frame.extend_from_slice(name.as_bytes()); + frame.extend_from_slice(value.as_bytes()); + Ok(()) } -#[napi] -pub struct ServerHandle { - workers: Vec>, +fn push_string_value(frame: &mut Vec, value: &str) -> Result<()> { + if value.len() > u16::MAX as usize { + return Err(anyhow!("field value too long")); + } + + push_u16(frame, value.len() as u16); + frame.extend_from_slice(value.as_bytes()); + Ok(()) } -#[napi] -impl ServerHandle { - #[napi] - pub fn close(&mut self) { - // In a real implementation, we'd send a shutdown signal. - // For now, we just let them drop or kill the process. +fn push_u16(frame: &mut Vec, value: u16) { + frame.extend_from_slice(&value.to_le_bytes()); +} + +fn push_u32(frame: &mut Vec, value: u32) { + frame.extend_from_slice(&value.to_le_bytes()); +} + +fn read_u8(bytes: &[u8], offset: &mut usize) -> Result { + if *offset + 1 > bytes.len() { + return Err(anyhow!("response envelope truncated")); } + + let value = bytes[*offset]; + *offset += 1; + Ok(value) } -impl HttpServerConfig { - fn from_input(input: Option<&HttpServerConfigInput>) -> Self { - Self { - default_host: input.and_then(|i| i.default_host.clone()).unwrap_or_else(|| "127.0.0.1".to_string()), - default_backlog: input.and_then(|i| i.default_backlog).unwrap_or(2048), - max_header_bytes: input.and_then(|i| i.max_header_bytes).unwrap_or(8192), - hot_get_root_http11: b"GET / HTTP/1.1\r\n".to_vec(), - hot_get_root_http10: b"GET / HTTP/1.0\r\n".to_vec(), - header_connection_prefix: b"connection:".to_vec(), - header_content_length_prefix: b"content-length:".to_vec(), - header_transfer_encoding_prefix: b"transfer-encoding:".to_vec(), - } +fn read_u16(bytes: &[u8], offset: &mut usize) -> Result { + if *offset + 2 > bytes.len() { + return Err(anyhow!("response envelope truncated")); } + + let value = u16::from_le_bytes([bytes[*offset], bytes[*offset + 1]]); + *offset += 2; + Ok(value) +} + +fn read_u32(bytes: &[u8], offset: &mut usize) -> Result { + if *offset + 4 > bytes.len() { + return Err(anyhow!("response envelope truncated")); + } + + let value = u32::from_le_bytes([ + bytes[*offset], + bytes[*offset + 1], + bytes[*offset + 2], + bytes[*offset + 3], + ]); + *offset += 4; + Ok(value) +} + +fn read_utf8(bytes: &[u8], offset: &mut usize, length: usize) -> Result { + if *offset + length > bytes.len() { + return Err(anyhow!("response envelope truncated")); + } + + let value = std::str::from_utf8(&bytes[*offset..*offset + length]) + .context("response envelope contained invalid utf-8")? + .to_string(); + *offset += length; + Ok(value) +} + +fn to_napi_error(error: E) -> Error +where + E: std::fmt::Display, +{ + Error::from_reason(error.to_string()) } From 00dafe1c10b1b915286578c30cb91e73124faef4 Mon Sep 17 00:00:00 2001 From: Rishi Yadav Date: Fri, 27 Mar 2026 20:54:38 +0530 Subject: [PATCH 4/4] fix --- rust-native/src/manifest.rs | 1 + rust-native/src/router.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/rust-native/src/manifest.rs b/rust-native/src/manifest.rs index 6176c42..c0d45f0 100644 --- a/rust-native/src/manifest.rs +++ b/rust-native/src/manifest.rs @@ -38,6 +38,7 @@ pub struct RouteInput { pub handler_id: u32, pub handler_source: String, pub param_names: Vec, + #[allow(dead_code)] pub segment_count: u16, pub header_keys: Vec, pub full_headers: bool, diff --git a/rust-native/src/router.rs b/rust-native/src/router.rs index d893391..070d839 100644 --- a/rust-native/src/router.rs +++ b/rust-native/src/router.rs @@ -41,6 +41,7 @@ pub struct MatchedRoute<'a, 'b> { #[derive(Clone)] struct DynamicRouteSpec { handler_id: u32, + #[allow(dead_code)] param_names: Box<[Box]>, header_keys: Box<[Box]>, full_headers: bool,