diff --git a/.github/bench/target.js b/.github/bench/target.js index 15beaa4..96103e8 100644 --- a/.github/bench/target.js +++ b/.github/bench/target.js @@ -29,6 +29,10 @@ function buildOptimizedPayload(label) { function buildDynamicApp(createApp, label) { const app = createApp(); + /** [Auto generated by http-native] + * [http-native optimization] bridge-dispatch + * This route currently runs through bridge dispatch because it depends on runtime request data. + */ app.get("/users/:id", async (req, res) => { res.json({ id: req.params.id, @@ -44,10 +48,6 @@ function buildStaticApp(createApp, label) { const app = createApp(); if (label === "http-native") { - /** [Auto generated by http-native] - * [http-native optimization] static-fast-path - * This route is served by the static fast path and avoids generic bridge dispatch. - */ app.get("/", (req, res) => { res.json({ ok: true, diff --git a/.github/tests/test.js b/.github/tests/test.js index 39c73f4..97cabfc 100644 --- a/.github/tests/test.js +++ b/.github/tests/test.js @@ -319,6 +319,7 @@ async function main() { serverConfig: { ...httpServerConfig, maxHeaderBytes: httpServerConfig.maxHeaderBytes, + tls: null, }, }); let closed = false; diff --git a/.gitignore b/.gitignore index 196326b..0ab39f2 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,7 @@ examples/ plan/ plans/ http-native.wiki/ + +PLAN.md +boost.md +dx.md diff --git a/Cargo.lock b/Cargo.lock index d664c28..355b1bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,18 @@ dependencies = [ "rustversion", ] +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "auto-const-array" version = 
"0.2.2" @@ -155,12 +167,24 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "cmake" version = "0.1.58" @@ -170,6 +194,16 @@ dependencies = [ "cc", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "convert_case" version = "0.11.0" @@ -179,6 +213,22 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.17" @@ -237,7 +287,7 @@ checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", "crossbeam-utils", - "hashbrown", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -286,6 +336,30 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "fastbloom" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4" +dependencies = [ + "getrandom 0.3.4", + "libm", + "rand", + "siphasher", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -314,6 +388,12 @@ dependencies = [ "spin", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -456,9 +536,58 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", +] + +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h3" +version = "0.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10872b55cfb02a821b69dc7cf8dc6a71d6af25eb9a79662bec4a9d016056b3be" +dependencies = [ + "bytes", + "fastrand", + "futures-util", + "http", + "pin-project-lite", + 
"tokio", +] + +[[package]] +name = "h3-quinn" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2e732c8d91a74731663ac8479ab505042fbf547b9a207213ab7fbcbfc4f8b4" +dependencies = [ + "bytes", + "futures", + "h3", + "quinn", + "tokio", + "tokio-util", ] [[package]] @@ -467,6 +596,12 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + [[package]] name = "hermit-abi" version = "0.5.2" @@ -482,23 +617,40 @@ dependencies = [ "digest", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + [[package]] name = "http_native_napi" version = "0.1.0" dependencies = [ "anyhow", "arc-swap", + "arrayvec", "base64", "brotli", + "bumpalo", "bytes", "dashmap", "flate2", "flume", "getrandom 0.2.17", + "h2", + "h3", + "h3-quinn", "hmac", + "http", "httparse", "itoa", "json5", + "libc", "memchr", "monoio", "monoio-rustls", @@ -506,6 +658,7 @@ dependencies = [ "napi-build", "napi-derive", "parking_lot", + "quinn", "rustc-hash", "rustls", "rustls-pemfile", @@ -513,7 +666,9 @@ dependencies = [ "serde_json", "sha1", "sha2", - "socket2", + "socket2 0.5.10", + "tokio", + "tokio-rustls", "url", ] @@ -625,6 +780,16 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "indexmap" +version = "2.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45a8a2b9cb3e0b0c1803dbb0758ffac5de2f425b23c28f518faabd9d805342ff" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", +] + [[package]] name = "io-uring" version = "0.6.4" @@ 
-641,6 +806,50 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys 0.3.1", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -688,6 +897,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + [[package]] name = "litemap" version = "0.8.1" @@ -709,6 +924,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "memchr" 
version = "2.8.0" @@ -746,6 +967,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + [[package]] name = "monoio" version = "0.2.4" @@ -759,13 +991,14 @@ dependencies = [ "io-uring", "libc", "memchr", - "mio", + "mio 0.8.11", "monoio-macros", "nix", "once_cell", "pin-project-lite", - "socket2", + "socket2 0.5.10", "threadpool", + "tokio", "windows-sys 0.48.0", ] @@ -799,7 +1032,7 @@ dependencies = [ "monoio", "monoio-io-wrapper", "rustls", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -903,6 +1136,12 @@ version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + [[package]] name = "parking_lot" version = "0.12.5" @@ -996,6 +1235,15 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -1005,6 +1253,64 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "futures-io", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2 0.6.3", + "thiserror 2.0.18", + "tokio", + "tracing", + 
"web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "fastbloom", + "getrandom 0.3.4", + "lru-slab", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "rustls-platform-verifier", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2 0.6.3", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.45" @@ -1020,6 +1326,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1058,12 +1393,25 @@ dependencies = [ "aws-lc-rs", "log", "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", "zeroize", ] +[[package]] +name = 
"rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -1079,9 +1427,37 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ + "web-time", "zeroize", ] +[[package]] +name = "rustls-platform-verifier" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" version = "0.103.10" @@ -1100,12 +1476,53 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + 
"windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags 2.11.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.27" @@ -1189,6 +1606,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + [[package]] name = "slab" version = "0.4.12" @@ -1211,6 +1634,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "spin" version = "0.9.8" @@ -1260,7 +1693,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", ] [[package]] @@ -1274,6 +1716,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "threadpool" version = "1.8.1" @@ -1293,6 +1746,90 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bd1c4c0fc4a7ab90fc15ef6daaa3ec3b893f004f915f2392557ed23237820cd" +dependencies = [ + "bytes", + "libc", + "mio 1.2.0", + "pin-project-lite", + "socket2 0.6.3", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + [[package]] name = "typenum" version = "1.19.0" @@ -1347,6 +1884,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -1407,12 +1954,49 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -1431,6 +2015,30 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -1462,6 +2070,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -1474,6 +2088,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = 
"windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -1486,6 +2106,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -1504,6 +2130,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -1516,6 +2148,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -1528,6 +2166,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" 
+ [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -1540,6 +2184,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -1587,6 +2237,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.6" diff --git a/Cargo.toml b/Cargo.toml index 98f0611..2c420f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,14 @@ path = "rsrc/src/lib.rs" [dependencies] anyhow = "1.0" arc-swap = "1.7" +arrayvec = "0.7" base64 = "0.22" +bumpalo = "3" bytes = "1.10" httparse = "1.9" itoa = "1.0" json5 = "0.4" +libc = "0.2" memchr = "2.7" rustc-hash = "2" dashmap = "6" @@ -22,12 +25,19 @@ flume = "0.11" getrandom = "0.2" sha1 = "0.10" hmac = "0.12" -monoio = { version = "0.2", features = ["sync", "legacy"] } +monoio = { version = "0.2", features = ["sync", "legacy", "poll-io"] } +h2 = "0.4" +h3 = "0.0.8" +h3-quinn = "0.0.10" +http = "1" monoio-rustls = "0.4" napi = { version = "3", default-features = false, features = ["napi8"] } napi-derive = "3" parking_lot = "0.12" rustls = "0.23" +quinn = "0.11" +tokio = { version = "1", default-features = false, features = 
["io-util", "rt-multi-thread", "macros"] } +tokio-rustls = "0.26" rustls-pemfile = "2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/package.json b/package.json index a29a0c4..d498094 100644 --- a/package.json +++ b/package.json @@ -23,19 +23,87 @@ "types": "./src/index.d.ts", "default": "./src/index.js" }, - "./cors": "./src/cors.js", + "./cors": { + "types": "./src/index.d.ts", + "default": "./src/cors.js" + }, "./dev": { "types": "./src/dev/index.d.ts", "default": "./src/dev/index.js" }, "./hot": "./src/hot.js", - "./session": "./src/session.js", - "./compress": "./src/compress.js", + "./session": { + "types": "./src/session.d.ts", + "default": "./src/session.js" + }, + "./compress": { + "types": "./src/index.d.ts", + "default": "./src/compress.js" + }, "./rate-limit": { "types": "./src/rate-limit.d.ts", "default": "./src/rate-limit.js" }, - "./validate": "./src/validate.js", + "./validate": { + "types": "./src/index.d.ts", + "default": "./src/validate.js" + }, + "./helmet": { + "types": "./src/index.d.ts", + "default": "./src/helmet.js" + }, + "./request-id": { + "types": "./src/index.d.ts", + "default": "./src/request-id.js" + }, + "./body-limit": { + "types": "./src/index.d.ts", + "default": "./src/body-limit.js" + }, + "./csrf": { + "types": "./src/index.d.ts", + "default": "./src/csrf.js" + }, + "./ip-filter": { + "types": "./src/index.d.ts", + "default": "./src/ip-filter.js" + }, + "./error": { + "types": "./src/index.d.ts", + "default": "./src/error.js" + }, + "./audit-log": { + "types": "./src/index.d.ts", + "default": "./src/audit-log.js" + }, + "./test": { + "types": "./src/index.d.ts", + "default": "./src/test.js" + }, + "./circuit-breaker": { + "types": "./src/index.d.ts", + "default": "./src/circuit-breaker.js" + }, + "./env": { + "types": "./src/index.d.ts", + "default": "./src/env.js" + }, + "./openapi": { + "types": "./src/index.d.ts", + "default": "./src/openapi.js" + }, + "./multipart": { + "types": 
"./src/index.d.ts", + "default": "./src/multipart.js" + }, + "./logger": { + "types": "./src/index.d.ts", + "default": "./src/logger.js" + }, + "./otel": { + "types": "./src/index.d.ts", + "default": "./src/otel.js" + }, "./http-server.config": "./src/http-server.config.js" }, "scripts": { @@ -46,6 +114,9 @@ "setup:native": "bun scripts/setup-native.mjs || node scripts/setup-native.mjs", "test": "bun run build && bun .github/tests/test.js && bun .github/tests/test-dev.js" }, + "dependencies": { + "acorn": "^8.14.1" + }, "devDependencies": { "express": "^5.1.0" } diff --git a/rsrc/src/compress.rs b/rsrc/src/compress.rs index bce8761..175a654 100644 --- a/rsrc/src/compress.rs +++ b/rsrc/src/compress.rs @@ -125,16 +125,31 @@ pub fn parse_accept_encoding(value: &[u8]) -> AcceptedEncoding { let mut best = AcceptedEncoding::Identity; for part in value.split(|&b| b == b',') { let trimmed = trim_ascii(part); + let mut segments = trimmed.split(|&b| b == b';'); // Extract the encoding name (before any ;q= weight) - let name = trimmed - .split(|&b| b == b';') + let name = segments .next() .map(trim_ascii) .unwrap_or(trimmed); - if name.eq_ignore_ascii_case(b"br") { - return AcceptedEncoding::Brotli; // Best possible — return immediately + + // Check for q=0 which means the client explicitly rejects this encoding + let rejected = segments.any(|seg| { + let seg = trim_ascii(seg); + if seg.len() >= 2 && seg[0].eq_ignore_ascii_case(&b'q') && seg[1] == b'=' { + let qval = trim_ascii(&seg[2..]); + // Match q=0, q=0., q=0.0, q=0.00, q=0.000 + matches!(qval, b"0" | b"0." 
| b"0.0" | b"0.00" | b"0.000") + } else { + false + } + }); + if rejected { + continue; } - if name.eq_ignore_ascii_case(b"gzip") { + + if name.eq_ignore_ascii_case(b"br") { + best = AcceptedEncoding::Brotli; + } else if name.eq_ignore_ascii_case(b"gzip") && best != AcceptedEncoding::Brotli { best = AcceptedEncoding::Gzip; } } diff --git a/rsrc/src/h2_handler.rs b/rsrc/src/h2_handler.rs new file mode 100644 index 0000000..16886e9 --- /dev/null +++ b/rsrc/src/h2_handler.rs @@ -0,0 +1,448 @@ +//! HTTP/2 connection handler (DX-4.1 / BOOST-3.1). +//! +//! Handles HTTP/2 multiplexed streams over TLS connections that negotiate +//! the "h2" ALPN protocol. Each stream dispatches through the same routing +//! and handler infrastructure as HTTP/1.1, maintaining API parity. +//! +//! Architecture: +//! - TLS handshake with ALPN negotiation selects "h2" protocol +//! - Raw TCP stream is converted to poll-io compatible wrapper +//! - tokio-rustls provides TLS over the poll-io stream +//! - h2 crate handles frame processing, HPACK compression, flow control +//! - Each HTTP/2 stream maps to a single route handler invocation +//! - Responses are sent back through the h2 send-response mechanism + +use std::rc::Rc; +use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use bytes::Bytes; +use h2::server::{self, SendResponse}; +use h2::RecvStream; +use http::{Request as H2Request, Response as H2Response, StatusCode}; +use napi::bindgen_prelude::Buffer; +use tokio_rustls::server::TlsStream as TokioTlsStream; + +use crate::compress; +use crate::parser::intern_header_name; +use crate::router::{ExactStaticRoute, Router}; +use crate::{ + JsDispatcher, LiveRouter, HttpServerConfig, + BRIDGE_VERSION, REQUEST_FLAG_QUERY_PRESENT, REQUEST_FLAG_BODY_PRESENT, + UNKNOWN_METHOD_CODE, NOT_FOUND_HANDLER_ID, MAX_BODY_BYTES, + INFLIGHT_REQUESTS, method_code_from_str, +}; + +/// Maximum concurrent streams per HTTP/2 connection. 
+const H2_MAX_CONCURRENT_STREAMS: u32 = 256; + +/// Initial window size for HTTP/2 flow control (2 MB). +const H2_INITIAL_WINDOW_SIZE: u32 = 2 * 1024 * 1024; + +/// Handle an HTTP/2 connection over a poll-io TLS stream. +/// +/// Performs the h2 server handshake, then dispatches each stream through +/// the standard routing infrastructure. Multiplexed streams are handled +/// concurrently — there is no head-of-line blocking. +pub(crate) async fn handle_h2_connection( + io: TokioTlsStream, + live_router: Rc>, + dispatcher: Rc>, + server_config: Rc>, + peer_ip: Option, +) -> Result<()> +where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static, +{ + let mut h2_builder = server::Builder::new(); + h2_builder + .max_concurrent_streams(H2_MAX_CONCURRENT_STREAMS) + .initial_window_size(H2_INITIAL_WINDOW_SIZE); + + let mut connection = h2_builder + .handshake(io) + .await + .map_err(|e| anyhow!("h2 handshake failed: {e}"))?; + + while let Some(result) = connection.accept().await { + let (request, respond) = result.map_err(|e| anyhow!("h2 accept error: {e}"))?; + + let router_ref = live_router.as_ref().as_ref(); + let router = router_ref.router.load_full(); + let dispatcher_ref = Rc::clone(&dispatcher); + let config_ref = Rc::clone(&server_config); + let peer_ip_clone = peer_ip.clone(); + + /* Each HTTP/2 stream is an independent request — dispatch concurrently. + * This eliminates head-of-line blocking that HTTP/1.1 suffers from. */ + monoio::spawn(async move { + if let Err(e) = handle_h2_stream( + request, + respond, + router, + dispatcher_ref.as_ref().as_ref(), + config_ref.as_ref().as_ref(), + peer_ip_clone.as_deref(), + ) + .await + { + log_error!("h2 stream error: {e}"); + } + }); + } + + Ok(()) +} + +/// Handle a single HTTP/2 stream (one request → one response). 
+async fn handle_h2_stream( + request: H2Request, + mut respond: SendResponse, + router: Arc, + dispatcher: &JsDispatcher, + server_config: &HttpServerConfig, + peer_ip: Option<&str>, +) -> Result<()> { + INFLIGHT_REQUESTS.fetch_add(1, std::sync::atomic::Ordering::Release); + let _guard = InflightGuard; + + let (parts, mut body_stream) = request.into_parts(); + + let method_str = parts.method.as_str(); + let path_and_query = parts.uri.path_and_query(); + let path = path_and_query.map(|pq| pq.path()).unwrap_or("/"); + let full_uri = path_and_query + .map(|pq| pq.as_str()) + .unwrap_or("/"); + let query_present = parts.uri.query().is_some(); + + let method_code = method_code_from_str(method_str).unwrap_or(UNKNOWN_METHOD_CODE); + + /* Collect headers into (name, value) pairs for routing and bridge */ + let mut headers: Vec<(&str, &str)> = Vec::with_capacity(parts.headers.len()); + let mut accepted_encoding = compress::AcceptedEncoding::Identity; + + for (name, value) in parts.headers.iter() { + let name_str = intern_header_name(name.as_str()); + if let Ok(val_str) = std::str::from_utf8(value.as_bytes()) { + if name_str == "accept-encoding" + && accepted_encoding != compress::AcceptedEncoding::Brotli + { + accepted_encoding = compress::parse_accept_encoding(val_str.as_bytes()); + } + headers.push((name_str, val_str)); + } + } + + /* Try static route fast-path first */ + if method_str == "GET" { + if let Some(static_route) = router.exact_static_route(b"GET", path.as_bytes()) { + return send_h2_static_response(&mut respond, static_route, accepted_encoding); + } + } + + /* Read body if present */ + let mut body_bytes = Vec::new(); + while let Some(chunk) = body_stream.data().await { + let chunk = chunk.map_err(|e| anyhow!("h2 body read error: {e}"))?; + if body_bytes.len() + chunk.len() > MAX_BODY_BYTES { + let response = H2Response::builder() + .status(StatusCode::PAYLOAD_TOO_LARGE) + .body(()) + .unwrap(); + let mut send = respond.send_response(response, false)?; + 
send.send_data(Bytes::from_static(b"{\"error\":\"Payload Too Large\"}"), true)?; + return Ok(()); + } + body_bytes.extend_from_slice(&chunk); + body_stream.flow_control().release_capacity(chunk.len())?; + } + + /* Route matching */ + let normalized_path = crate::normalize_runtime_path(path); + let matched_route = if method_code != UNKNOWN_METHOD_CODE { + router.match_route(method_code, normalized_path.as_ref()) + } else { + None + }; + + let _handler_id = matched_route.as_ref().map(|r| r.handler_id).unwrap_or(NOT_FOUND_HANDLER_ID); + + /* Build binary bridge envelope for JS dispatch */ + let envelope = build_h2_bridge_envelope( + method_code, + &headers, + path, + full_uri, + &body_bytes, + peer_ip, + query_present, + &matched_route, + ); + + /* Dispatch to JS and send response */ + match dispatcher.dispatch(Buffer::from(envelope)).await { + Ok(response_buf) => { + send_h2_response_from_bridge(&mut respond, &response_buf, server_config.compression.as_ref(), accepted_encoding)?; + } + Err(e) => { + let response = H2Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(()) + .unwrap(); + let mut send = respond.send_response(response, false)?; + send.send_data( + Bytes::from(format!("{{\"error\":\"Internal Server Error: {e}\"}}")), + true, + )?; + } + } + + Ok(()) +} + +/// Send a pre-built static response over an HTTP/2 stream. 
+fn send_h2_static_response( + respond: &mut SendResponse, + static_route: &ExactStaticRoute, + encoding: compress::AcceptedEncoding, +) -> Result<()> { + /* Select the best pre-compressed variant */ + let response_bytes = match encoding { + compress::AcceptedEncoding::Brotli => { + static_route.keep_alive_response_br.as_ref() + .unwrap_or(&static_route.keep_alive_response) + } + compress::AcceptedEncoding::Gzip => { + static_route.keep_alive_response_gzip.as_ref() + .unwrap_or(&static_route.keep_alive_response) + } + _ => &static_route.keep_alive_response, + }; + + /* Parse the pre-built HTTP/1.1 response to extract status, headers, and body + * for the h2 response. The static response format is: + * "HTTP/1.1 200 OK\r\nheader: value\r\n...\r\n\r\nbody" */ + let bytes = response_bytes.as_ref(); + if let Some(header_end) = memchr::memmem::find(bytes, b"\r\n\r\n") { + let header_section = &bytes[..header_end]; + let body = &bytes[header_end + 4..]; + + /* Parse actual status code from the HTTP/1.1 status line */ + let actual_status = header_section.split(|&b| b == b'\n') + .next() + .and_then(|line| { + let line = if line.ends_with(b"\r") { &line[..line.len() - 1] } else { line }; + // "HTTP/1.1 200 OK" → extract "200" + let parts: Vec<&[u8]> = line.splitn(3, |&b| b == b' ').collect(); + if parts.len() >= 2 { + std::str::from_utf8(parts[1]).ok()?.parse::().ok() + } else { + None + } + }) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(StatusCode::OK); + + let mut builder = H2Response::builder().status(actual_status); + + /* Parse headers from the pre-built response */ + for line in header_section.split(|&b| b == b'\n') { + let line = if line.ends_with(b"\r") { &line[..line.len() - 1] } else { line }; + if line.starts_with(b"HTTP/") || line.is_empty() { + continue; + } + if let Some(colon) = line.iter().position(|&b| b == b':') { + let name = &line[..colon]; + let value = &line[colon + 1..]; + let value = if value.first() == Some(&b' ') { &value[1..] 
} else { value }; + /* Skip connection-specific headers invalid in h2 */ + if name.eq_ignore_ascii_case(b"connection") + || name.eq_ignore_ascii_case(b"transfer-encoding") + { + continue; + } + if let (Ok(n), Ok(v)) = (std::str::from_utf8(name), std::str::from_utf8(value)) { + builder = builder.header(n, v); + } + } + } + + let response = builder.body(()).map_err(|e| anyhow!("h2 response build error: {e}"))?; + let mut send = respond.send_response(response, body.is_empty())?; + if !body.is_empty() { + send.send_data(Bytes::copy_from_slice(body), true)?; + } + } else { + /* Fallback: send raw body */ + let response = H2Response::builder() + .status(StatusCode::OK) + .body(()) + .unwrap(); + let mut send = respond.send_response(response, false)?; + send.send_data(response_bytes.clone(), true)?; + } + + Ok(()) +} + +/// Build a bridge envelope for JS dispatch from HTTP/2 request data. +/// Also used by the HTTP/3 handler for bridge envelope construction. +pub fn build_h2_bridge_envelope( + method_code: u8, + headers: &[(&str, &str)], + path: &str, + url: &str, + body: &[u8], + peer_ip: Option<&str>, + query_present: bool, + matched: &Option>, +) -> Vec { + let handler_id = matched.as_ref().map(|m| m.handler_id).unwrap_or(NOT_FOUND_HANDLER_ID); + let ip_str = peer_ip.unwrap_or(""); + + /* Serialize matched route params as "k=v\0k=v" */ + let mut params_buf = Vec::new(); + if let Some(m) = matched.as_ref() { + for (i, name) in m.param_names.iter().enumerate() { + if i > 0 { params_buf.push(0); } + params_buf.extend_from_slice(name.as_bytes()); + params_buf.push(b'='); + if let Some(val) = m.param_values.get(i) { + params_buf.extend_from_slice(val.as_bytes()); + } + } + } + + /* Serialize headers needed by bridge */ + let needed_headers = if let Some(m) = matched.as_ref() { + if m.full_headers { + headers.to_vec() + } else { + headers.iter() + .filter(|(name, _)| m.header_keys.iter().any(|k| k.as_ref().eq_ignore_ascii_case(name))) + .copied() + .collect::>() + } + } else 
{ + headers.to_vec() + }; + + let mut hdr_buf = Vec::new(); + for (name, value) in &needed_headers { + if !hdr_buf.is_empty() { hdr_buf.push(0); } + hdr_buf.extend_from_slice(name.as_bytes()); + hdr_buf.push(b':'); + hdr_buf.extend_from_slice(value.as_bytes()); + } + + let mut flags: u16 = 0; + if query_present { flags |= REQUEST_FLAG_QUERY_PRESENT; } + if !body.is_empty() { flags |= REQUEST_FLAG_BODY_PRESENT; } + + /* Binary envelope: version | method | flags(2) | handler_id(4) | url_len(2) | path_len(2) | + * ip_len(2) | params_len(2) | headers_len(2) | body_len(4) | url | path | ip | params | headers | body */ + let url_bytes = url.as_bytes(); + let path_bytes = path.as_bytes(); + let ip_bytes = ip_str.as_bytes(); + let total = 1 + 1 + 2 + 4 + 2 + 2 + 2 + 2 + 2 + 4 + url_bytes.len() + path_bytes.len() + + ip_bytes.len() + params_buf.len() + hdr_buf.len() + body.len(); + + let mut buf = Vec::with_capacity(total); + buf.push(BRIDGE_VERSION); + buf.push(method_code); + buf.extend_from_slice(&flags.to_le_bytes()); + buf.extend_from_slice(&handler_id.to_le_bytes()); + buf.extend_from_slice(&(url_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(path_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(ip_bytes.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(params_buf.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(hdr_buf.len() as u16).to_le_bytes()); + buf.extend_from_slice(&(body.len() as u32).to_le_bytes()); + buf.extend_from_slice(url_bytes); + buf.extend_from_slice(path_bytes); + buf.extend_from_slice(ip_bytes); + buf.extend_from_slice(¶ms_buf); + buf.extend_from_slice(&hdr_buf); + buf.extend_from_slice(body); + + buf +} + +/// Parse a JS bridge response envelope and send it as an HTTP/2 response. 
+fn send_h2_response_from_bridge( + respond: &mut SendResponse, + response_buf: &[u8], + _compression_config: Option<&compress::CompressionConfig>, + _accepted_encoding: compress::AcceptedEncoding, +) -> Result<()> { + if response_buf.len() < 8 { + let response = H2Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(()) + .unwrap(); + let mut send = respond.send_response(response, false)?; + send.send_data(Bytes::from_static(b"{\"error\":\"Invalid bridge response\"}"), true)?; + return Ok(()); + } + + /* Parse bridge response: status(2) | header_count(2) | body_len(4) | headers... | body */ + let status = u16::from_le_bytes([response_buf[0], response_buf[1]]); + let header_count = u16::from_le_bytes([response_buf[2], response_buf[3]]) as usize; + let body_len = u32::from_le_bytes([ + response_buf[4], response_buf[5], response_buf[6], response_buf[7], + ]) as usize; + + let mut offset = 8usize; + let mut builder = H2Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK)); + + for _ in 0..header_count { + if offset + 4 > response_buf.len() { break; } + let name_len = u16::from_le_bytes([response_buf[offset], response_buf[offset + 1]]) as usize; + offset += 2; + let value_len = u16::from_le_bytes([response_buf[offset], response_buf[offset + 1]]) as usize; + offset += 2; + + if offset + name_len + value_len > response_buf.len() { break; } + + let name = &response_buf[offset..offset + name_len]; + offset += name_len; + let value = &response_buf[offset..offset + value_len]; + offset += value_len; + + /* Skip h1-only headers invalid in h2 */ + if name.eq_ignore_ascii_case(b"connection") + || name.eq_ignore_ascii_case(b"transfer-encoding") + { + continue; + } + + if let (Ok(n), Ok(v)) = (std::str::from_utf8(name), std::str::from_utf8(value)) { + builder = builder.header(n, v); + } + } + + let body = if offset + body_len <= response_buf.len() { + &response_buf[offset..offset + body_len] + } else if offset < response_buf.len() 
{ + &response_buf[offset..] + } else { + &[] + }; + + let response = builder.body(()).map_err(|e| anyhow!("h2 response error: {e}"))?; + let mut send = respond.send_response(response, body.is_empty())?; + if !body.is_empty() { + send.send_data(Bytes::copy_from_slice(body), true)?; + } + + Ok(()) +} + +/// RAII guard for decrementing the in-flight request counter. +struct InflightGuard; +impl Drop for InflightGuard { + fn drop(&mut self) { + INFLIGHT_REQUESTS.fetch_sub(1, std::sync::atomic::Ordering::Release); + } +} diff --git a/rsrc/src/h3_handler.rs b/rsrc/src/h3_handler.rs new file mode 100644 index 0000000..c031898 --- /dev/null +++ b/rsrc/src/h3_handler.rs @@ -0,0 +1,429 @@ +//! HTTP/3 (QUIC) connection handler (DX-4.2). +//! +//! Runs on a dedicated tokio runtime thread since quinn requires tokio. +//! The QUIC endpoint binds to the same port as TCP (but over UDP). +//! Each HTTP/3 request dispatches through the same JS bridge as HTTP/1.1 +//! and HTTP/2, maintaining full API parity. +//! +//! Architecture: +//! - Separate std::thread with tokio runtime for QUIC +//! - quinn::Endpoint accepts QUIC connections +//! - h3 crate handles HTTP/3 semantics (QPACK, stream mapping) +//! - Requests are encoded into bridge envelopes and dispatched to JS +//! - Alt-Svc header injection in HTTP/1.1 and HTTP/2 responses + +use std::net::SocketAddr; +use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use bytes::{Buf, Bytes}; +use http::{Response as HttpResponse, StatusCode}; +use napi::bindgen_prelude::Buffer; + +use crate::compress; +use crate::parser::intern_header_name; +use crate::router::ExactStaticRoute; +use crate::{ + JsDispatcher, LiveRouter, HttpServerConfig, + UNKNOWN_METHOD_CODE, MAX_BODY_BYTES, + INFLIGHT_REQUESTS, method_code_from_str, +}; + +/// Start the HTTP/3 QUIC listener on a dedicated tokio runtime. +/// +/// Spawns a background thread running a full tokio runtime, since quinn +/// requires tokio's async I/O. 
The QUIC endpoint binds to the same address +/// as the TCP listener but on UDP. +pub fn start_h3_listener( + bind_addr: SocketAddr, + tls_config: Arc, + live_router: Arc, + dispatcher: Arc, + server_config: Arc, +) -> Result<()> { + let quinn_server_config = quinn::ServerConfig::with_crypto(Arc::new( + quinn::crypto::rustls::QuicServerConfig::try_from(tls_config) + .map_err(|e| anyhow!("failed to create QUIC TLS config: {e}"))? + )); + + std::thread::Builder::new() + .name("h3-quic-runtime".into()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build() + .expect("failed to create tokio runtime for QUIC"); + + rt.block_on(async move { + if let Err(e) = run_h3_endpoint( + bind_addr, + quinn_server_config, + live_router, + dispatcher, + server_config, + ).await { + log_error!("HTTP/3 endpoint error: {e}"); + } + }); + }) + .map_err(|e| anyhow!("failed to spawn QUIC thread: {e}"))?; + + Ok(()) +} + +/// Main QUIC accept loop on the tokio runtime. +async fn run_h3_endpoint( + bind_addr: SocketAddr, + server_config: quinn::ServerConfig, + live_router: Arc, + dispatcher: Arc, + http_config: Arc, +) -> Result<()> { + let endpoint = quinn::Endpoint::server(server_config, bind_addr) + .map_err(|e| anyhow!("QUIC bind failed on {bind_addr}: {e}"))?; + + eprintln!("[http-native] HTTP/3 (QUIC) listening on {bind_addr}"); + + while let Some(incoming) = endpoint.accept().await { + let router = Arc::clone(&live_router); + let disp = Arc::clone(&dispatcher); + let config = Arc::clone(&http_config); + + tokio::spawn(async move { + match incoming.await { + Ok(connection) => { + let peer = connection.remote_address(); + if let Err(e) = handle_h3_connection( + connection, router, disp, config, peer, + ).await { + log_error!("h3 connection error from {peer}: {e}"); + } + } + Err(e) => { + log_error!("QUIC handshake error: {e}"); + } + } + }); + } + + Ok(()) +} + +/// Handle a single HTTP/3 connection (multiple streams). 
+async fn handle_h3_connection( + quinn_conn: quinn::Connection, + live_router: Arc, + dispatcher: Arc, + server_config: Arc, + peer_addr: SocketAddr, +) -> Result<()> { + let mut h3_conn = h3::server::Connection::new(h3_quinn::Connection::new(quinn_conn)) + .await + .map_err(|e| anyhow!("h3 handshake failed: {e}"))?; + + let peer_ip = peer_addr.ip().to_string(); + + loop { + match h3_conn.accept().await { + Ok(Some(resolver)) => { + let router = live_router.router.load_full(); + let disp = Arc::clone(&dispatcher); + let config = Arc::clone(&server_config); + let ip = peer_ip.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_h3_request( + resolver, router, &disp, &config, &ip, + ).await { + log_error!("h3 stream error: {e}"); + } + }); + } + Ok(None) => break, + Err(e) => { + log_error!("h3 accept error: {e}"); + break; + } + } + } + + Ok(()) +} + +/// Handle a single HTTP/3 request via the resolver → stream pattern. +async fn handle_h3_request( + resolver: h3::server::RequestResolver, + router: Arc, + dispatcher: &JsDispatcher, + _server_config: &HttpServerConfig, + peer_ip: &str, +) -> Result<()> { + INFLIGHT_REQUESTS.fetch_add(1, std::sync::atomic::Ordering::Release); + let _guard = InflightGuard; + + let (request, mut stream) = resolver + .resolve_request() + .await + .map_err(|e| anyhow!("h3 resolve request: {e}"))?; + + let (parts, _) = request.into_parts(); + + let method_str = parts.method.as_str(); + let path_and_query = parts.uri.path_and_query(); + let path = path_and_query.map(|pq| pq.path()).unwrap_or("/"); + let full_uri = path_and_query.map(|pq| pq.as_str()).unwrap_or("/"); + let query_present = parts.uri.query().is_some(); + let method_code = method_code_from_str(method_str).unwrap_or(UNKNOWN_METHOD_CODE); + + /* Collect headers */ + let mut headers: Vec<(&str, &str)> = Vec::with_capacity(parts.headers.len()); + let mut accepted_encoding = compress::AcceptedEncoding::Identity; + + for (name, value) in parts.headers.iter() { + let 
name_str = intern_header_name(name.as_str()); + if let Ok(val_str) = std::str::from_utf8(value.as_bytes()) { + if name_str == "accept-encoding" + && accepted_encoding != compress::AcceptedEncoding::Brotli + { + accepted_encoding = compress::parse_accept_encoding(val_str.as_bytes()); + } + headers.push((name_str, val_str)); + } + } + + /* Static route fast-path */ + if method_str == "GET" { + if let Some(static_route) = router.exact_static_route(b"GET", path.as_bytes()) { + return send_h3_static_response(&mut stream, static_route, accepted_encoding).await; + } + } + + /* Read request body from QUIC stream */ + let mut body_bytes = Vec::new(); + while let Some(chunk) = stream.recv_data().await.map_err(|e| anyhow!("h3 body: {e}"))? { + let remaining = chunk.remaining(); + if body_bytes.len() + remaining > MAX_BODY_BYTES { + let response = HttpResponse::builder() + .status(StatusCode::PAYLOAD_TOO_LARGE) + .body(()) + .unwrap(); + stream.send_response(response).await.map_err(|e| anyhow!("h3 send: {e}"))?; + stream.send_data(Bytes::from_static(b"{\"error\":\"Payload Too Large\"}")) + .await + .map_err(|e| anyhow!("h3 send body: {e}"))?; + stream.finish().await.map_err(|e| anyhow!("h3 finish: {e}"))?; + return Ok(()); + } + let mut buf = vec![0u8; remaining]; + chunk.chunk().iter().enumerate().for_each(|(i, &b)| buf[i] = b); + body_bytes.extend_from_slice(&buf[..remaining]); + } + + /* Route matching */ + let normalized_path = crate::normalize_runtime_path(path); + let matched_route = if method_code != UNKNOWN_METHOD_CODE { + router.match_route(method_code, normalized_path.as_ref()) + } else { + None + }; + + /* Build bridge envelope — reuse the H2 envelope builder */ + let envelope = crate::h2_handler::build_h2_bridge_envelope( + method_code, + &headers, + path, + full_uri, + &body_bytes, + Some(peer_ip), + query_present, + &matched_route, + ); + + /* Dispatch to JS */ + match dispatcher.dispatch(Buffer::from(envelope)).await { + Ok(response_buf) => { + 
send_h3_response_from_bridge(&mut stream, &response_buf).await?; + } + Err(e) => { + let response = HttpResponse::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(()) + .unwrap(); + stream.send_response(response).await.ok(); + stream.send_data( + Bytes::from(format!("{{\"error\":\"Internal Server Error: {e}\"}}")), + ).await.ok(); + stream.finish().await.ok(); + } + } + + Ok(()) +} + +/// Send a pre-built static response over HTTP/3. +async fn send_h3_static_response( + stream: &mut h3::server::RequestStream, Bytes>, + static_route: &ExactStaticRoute, + encoding: compress::AcceptedEncoding, +) -> Result<()> { + let response_bytes = match encoding { + compress::AcceptedEncoding::Brotli => { + static_route.keep_alive_response_br.as_ref() + .unwrap_or(&static_route.keep_alive_response) + } + compress::AcceptedEncoding::Gzip => { + static_route.keep_alive_response_gzip.as_ref() + .unwrap_or(&static_route.keep_alive_response) + } + _ => &static_route.keep_alive_response, + }; + + let bytes = response_bytes.as_ref(); + if let Some(header_end) = memchr::memmem::find(bytes, b"\r\n\r\n") { + let header_section = &bytes[..header_end]; + let body = &bytes[header_end + 4..]; + + /* Parse actual status code from the HTTP/1.1 status line */ + let actual_status = header_section.split(|&b| b == b'\n') + .next() + .and_then(|line| { + let line = if line.ends_with(b"\r") { &line[..line.len() - 1] } else { line }; + let parts: Vec<&[u8]> = line.splitn(3, |&b| b == b' ').collect(); + if parts.len() >= 2 { + std::str::from_utf8(parts[1]).ok()?.parse::().ok() + } else { + None + } + }) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(StatusCode::OK); + + let mut builder = HttpResponse::builder().status(actual_status); + + for line in header_section.split(|&b| b == b'\n') { + let line = if line.ends_with(b"\r") { &line[..line.len() - 1] } else { line }; + if line.starts_with(b"HTTP/") || line.is_empty() { continue; } + if let Some(colon) = 
line.iter().position(|&b| b == b':') { + let name = &line[..colon]; + let value = &line[colon + 1..]; + let value = if value.first() == Some(&b' ') { &value[1..] } else { value }; + /* Skip h1-only headers invalid in h3 */ + if name.eq_ignore_ascii_case(b"connection") + || name.eq_ignore_ascii_case(b"transfer-encoding") + || name.eq_ignore_ascii_case(b"keep-alive") + { + continue; + } + if let (Ok(n), Ok(v)) = (std::str::from_utf8(name), std::str::from_utf8(value)) { + builder = builder.header(n, v); + } + } + } + + let response = builder.body(()).map_err(|e| anyhow!("h3 response: {e}"))?; + stream.send_response(response).await.map_err(|e| anyhow!("h3 send: {e}"))?; + if !body.is_empty() { + stream.send_data(Bytes::copy_from_slice(body)).await + .map_err(|e| anyhow!("h3 send body: {e}"))?; + } + stream.finish().await.map_err(|e| anyhow!("h3 finish: {e}"))?; + } else { + /* Malformed static response — missing header/body boundary. Send 500. */ + let response = HttpResponse::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(()) + .unwrap(); + stream.send_response(response).await.map_err(|e| anyhow!("h3 send: {e}"))?; + stream.send_data(Bytes::from_static(b"{\"error\":\"Internal Server Error\"}")) + .await.map_err(|e| anyhow!("h3 send body: {e}"))?; + stream.finish().await.map_err(|e| anyhow!("h3 finish: {e}"))?; + } + + Ok(()) +} + +/// Parse bridge response and send as HTTP/3. 
+async fn send_h3_response_from_bridge( + stream: &mut h3::server::RequestStream, Bytes>, + response_buf: &[u8], +) -> Result<()> { + if response_buf.len() < 8 { + let response = HttpResponse::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(()) + .unwrap(); + stream.send_response(response).await.ok(); + stream.send_data(Bytes::from_static(b"{\"error\":\"Invalid bridge response\"}")) + .await.ok(); + stream.finish().await.ok(); + return Ok(()); + } + + let status = u16::from_le_bytes([response_buf[0], response_buf[1]]); + let header_count = u16::from_le_bytes([response_buf[2], response_buf[3]]) as usize; + let body_len = u32::from_le_bytes([ + response_buf[4], response_buf[5], response_buf[6], response_buf[7], + ]) as usize; + + let mut offset = 8usize; + let mut builder = HttpResponse::builder() + .status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK)); + + for _ in 0..header_count { + if offset + 4 > response_buf.len() { break; } + let name_len = u16::from_le_bytes([response_buf[offset], response_buf[offset + 1]]) as usize; + offset += 2; + let value_len = u16::from_le_bytes([response_buf[offset], response_buf[offset + 1]]) as usize; + offset += 2; + + if offset + name_len + value_len > response_buf.len() { break; } + + let name = &response_buf[offset..offset + name_len]; + offset += name_len; + let value = &response_buf[offset..offset + value_len]; + offset += value_len; + + if name.eq_ignore_ascii_case(b"connection") + || name.eq_ignore_ascii_case(b"transfer-encoding") + { + continue; + } + + if let (Ok(n), Ok(v)) = (std::str::from_utf8(name), std::str::from_utf8(value)) { + builder = builder.header(n, v); + } + } + + let body = if offset + body_len <= response_buf.len() { + &response_buf[offset..offset + body_len] + } else if offset < response_buf.len() { + &response_buf[offset..] 
+ } else { + &[] + }; + + let response = builder.body(()).map_err(|e| anyhow!("h3 response: {e}"))?; + stream.send_response(response).await.map_err(|e| anyhow!("h3 send: {e}"))?; + if !body.is_empty() { + stream.send_data(Bytes::copy_from_slice(body)).await + .map_err(|e| anyhow!("h3 send body: {e}"))?; + } + stream.finish().await.map_err(|e| anyhow!("h3 finish: {e}"))?; + + Ok(()) +} + +/// RAII guard for decrementing in-flight request counter. +struct InflightGuard; +impl Drop for InflightGuard { + fn drop(&mut self) { + INFLIGHT_REQUESTS.fetch_sub(1, std::sync::atomic::Ordering::Release); + } +} + +/// Build the Alt-Svc header value advertising HTTP/3 availability. +/// Injected into HTTP/1.1 and HTTP/2 responses when HTTP/3 is enabled. +pub fn alt_svc_header(port: u16) -> String { + format!("h3=\":{port}\"; ma=86400") +} diff --git a/rsrc/src/http_utils.rs b/rsrc/src/http_utils.rs new file mode 100644 index 0000000..80cece3 --- /dev/null +++ b/rsrc/src/http_utils.rs @@ -0,0 +1,142 @@ +//! Shared HTTP utilities used by both the server core (lib.rs) and the router. +//! +//! Consolidates duplicated helpers that were previously defined independently +//! in lib.rs and router.rs (plan items D1, D2, R4). + +use std::collections::HashMap; + +// ─── Status Reason Phrases ──────────────── + +/// Map an HTTP status code to its standard reason phrase. +/// +/// /* @param status — HTTP status code (e.g. 
200, 404) */ +/// /* @returns — static reason phrase string */ +pub fn status_reason(status: u16) -> &'static str { + match status { + 200 => "OK", + 201 => "Created", + 202 => "Accepted", + 204 => "No Content", + 301 => "Moved Permanently", + 302 => "Found", + 304 => "Not Modified", + 400 => "Bad Request", + 401 => "Unauthorized", + 403 => "Forbidden", + 404 => "Not Found", + 405 => "Method Not Allowed", + 408 => "Request Timeout", + 409 => "Conflict", + 411 => "Length Required", + 413 => "Payload Too Large", + 415 => "Unsupported Media Type", + 422 => "Unprocessable Entity", + 429 => "Too Many Requests", + 431 => "Request Header Fields Too Large", + 500 => "Internal Server Error", + 501 => "Not Implemented", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + 504 => "Gateway Timeout", + _ => "Unknown", + } +} + +// ─── HTTP Response Building ─────────────── + +/// Build a complete HTTP/1.1 response from status, headers (HashMap), body, +/// and keep-alive flag. Used by the router for static/prebuilt responses. 
+/// +/// /* @param status — HTTP status code */ +/// /* @param headers — response headers as name/value pairs */ +/// /* @param body — raw response body bytes */ +/// /* @param keep_alive — whether to set connection: keep-alive or close */ +pub fn build_response_bytes( + status: u16, + headers: &HashMap, + body: &[u8], + keep_alive: bool, +) -> Vec { + let reason = status_reason(status); + let connection = if keep_alive { "keep-alive" } else { "close" }; + let body_len = body.len(); + + /* pre-calculate total size to avoid reallocation */ + let mut total_size = 9 + 3 + 1 + reason.len() + 2 // "HTTP/1.1 {status} {reason}\r\n" + + 16 + count_digits(body_len) + 2 // "content-length: {len}\r\n" + + 12 + connection.len() + 2; // "connection: {val}\r\n" + + for (name, value) in headers { + total_size += name.len() + 2 + value.len() + 2; + } + total_size += 2 + body_len; // "\r\n" + body + + let mut output = Vec::with_capacity(total_size); + output.extend_from_slice(b"HTTP/1.1 "); + write_u16(&mut output, status); + output.push(b' '); + output.extend_from_slice(reason.as_bytes()); + output.extend_from_slice(b"\r\ncontent-length: "); + write_usize(&mut output, body_len); + output.extend_from_slice(b"\r\nconnection: "); + output.extend_from_slice(connection.as_bytes()); + output.extend_from_slice(b"\r\n"); + + for (name, value) in headers { + /* security: skip headers with CRLF injection */ + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } + output.extend_from_slice(name.as_bytes()); + output.extend_from_slice(b": "); + output.extend_from_slice(value.as_bytes()); + output.extend_from_slice(b"\r\n"); + } + + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(body); + output +} + +// ─── Integer Formatting Helpers ─────────── + +/// Write a usize as decimal ASCII into the output buffer. +/// Uses itoa for zero-allocation formatting. 
+/// +/// /* @param output — target byte buffer */ +/// /* @param value — integer to format */ +#[inline(always)] +pub fn write_usize(output: &mut Vec, value: usize) { + let mut buf = itoa::Buffer::new(); + output.extend_from_slice(buf.format(value).as_bytes()); +} + +/// Write a u16 as decimal ASCII into the output buffer. +/// +/// /* @param output — target byte buffer */ +/// /* @param value — integer to format */ +#[inline(always)] +pub fn write_u16(output: &mut Vec, value: u16) { + let mut buf = itoa::Buffer::new(); + output.extend_from_slice(buf.format(value).as_bytes()); +} + +/// Count the number of decimal digits in a usize value. +/// +/// /* @param n — value to count digits of */ +/// /* @returns — digit count (1 for n=0) */ +pub fn count_digits(mut n: usize) -> usize { + if n == 0 { + return 1; + } + let mut count = 0; + while n > 0 { + count += 1; + n /= 10; + } + count +} diff --git a/rsrc/src/lib.rs b/rsrc/src/lib.rs index 845407e..c9c2996 100644 --- a/rsrc/src/lib.rs +++ b/rsrc/src/lib.rs @@ -1,6 +1,34 @@ +// ─── Structured Logging Macros (D6) ──── +// +// Lightweight structured log macros that output to stderr with a consistent +// "[http-native] LEVEL: message" format. Drop-in replacement for bare +// eprintln! calls. Can be swapped to `tracing` crate later without +// changing call sites. Defined before module declarations so child +// modules can use them. + +/// /* @param $($arg)* — format arguments identical to eprintln! */ +macro_rules! log_error { + ($($arg:tt)*) => { + eprintln!("[http-native] error: {}", format_args!($($arg)*)) + }; +} + +/// /* @param $($arg)* — format arguments identical to eprintln! */ +macro_rules! 
log_warn { + ($($arg:tt)*) => { + eprintln!("[http-native] warn: {}", format_args!($($arg)*)) + }; +} + mod analyzer; pub mod compress; +pub mod h2_handler; +#[allow(dead_code)] +mod h3_handler; +pub mod http_utils; mod manifest; +pub mod parser; +pub mod response; mod rate_limit; mod router; pub mod session; @@ -10,7 +38,7 @@ use anyhow::{anyhow, Context, Result}; use arc_swap::ArcSwap; use memchr::memmem; use monoio::io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt}; -use monoio::net::{ListenerOpts, TcpListener}; +use monoio::net::TcpListener; use monoio_rustls::TlsAcceptor; use napi::bindgen_prelude::{Buffer, Function, Promise}; use napi::threadsafe_function::ThreadsafeFunction; @@ -33,6 +61,7 @@ use crate::analyzer::{ TextSegment, }; use crate::manifest::{HttpServerConfigInput, ManifestInput, TlsConfigInput}; +use crate::response::{build_response_bytes_fast, patch_connection_header, inject_set_cookie_header, build_error_response_bytes}; use crate::router::{ExactStaticRoute, MatchedRoute, Router}; // ─── Constants ────────────────────────── @@ -53,12 +82,6 @@ const UNKNOWN_METHOD_CODE: u8 = 0; /// Sentinel handler ID dispatched to JS when no route matches — JS treats this as 404. const NOT_FOUND_HANDLER_ID: u32 = 0; -/// Security: Maximum number of headers we allow per request -const MAX_HEADER_COUNT: usize = 64; -/// Security: Maximum URL length to prevent abuse -const MAX_URL_LENGTH: usize = 8192; -/// Security: Maximum single header value length -const MAX_HEADER_VALUE_LENGTH: usize = 8192; /// Security: Maximum request body size (1 MB) const MAX_BODY_BYTES: usize = 1024 * 1024; /// Security: Maximum concurrent connections per worker thread @@ -102,6 +125,74 @@ fn release_buffer(mut buf: Vec) { }); } +// ─── Response Buffer Pool (BOOST-2.3) ── +// +// Eliminates per-response Vec allocations by recycling response buffers. +// Separate from the connection read buffer pool — response buffers are +// typically smaller and have different capacity profiles. 
+ +#[allow(dead_code)] +const RESPONSE_POOL_MAX_SIZE: usize = 128; +#[allow(dead_code)] +const RESPONSE_POOL_MAX_RECYCLE_SIZE: usize = 65536; + +thread_local! { + static RESPONSE_POOL: RefCell>> = RefCell::new(Vec::with_capacity(RESPONSE_POOL_MAX_SIZE)); +} + +/// Acquire a response buffer from the thread-local pool. +/// Defaults to 1KB capacity — right-sized for typical JSON API responses. +#[allow(dead_code)] +fn acquire_response_buffer(estimated_size: usize) -> Vec { + RESPONSE_POOL.with(|pool| { + pool.borrow_mut() + .pop() + .map(|mut buf| { + buf.clear(); + if buf.capacity() < estimated_size { + buf.reserve(estimated_size - buf.capacity()); + } + buf + }) + .unwrap_or_else(|| Vec::with_capacity(estimated_size.max(1024))) + }) +} + +#[allow(dead_code)] +fn release_response_buffer(mut buf: Vec) { + if buf.capacity() > RESPONSE_POOL_MAX_RECYCLE_SIZE { + return; + } + buf.clear(); + RESPONSE_POOL.with(|pool| { + let mut pool = pool.borrow_mut(); + if pool.len() < RESPONSE_POOL_MAX_SIZE { + pool.push(buf); + } + }); +} + +// ─── Per-Request Arena Allocator (BOOST-1.2) ── +// +// Uses bumpalo for per-request bump allocation. All request-scoped strings +// and small buffers are allocated from the arena, which resets at the end of +// each request. This reduces per-request heap allocations from ~8-12 to 1. + +#[allow(dead_code)] +const REQUEST_ARENA_CAPACITY: usize = 4096; + +thread_local! { + static REQUEST_ARENA: RefCell = + RefCell::new(bumpalo::Bump::with_capacity(REQUEST_ARENA_CAPACITY)); +} + +/// Reset the per-thread request arena. Called at the end of each request +/// to release all arena-allocated memory in a single operation. 
+#[allow(dead_code)] +fn reset_request_arena() { + REQUEST_ARENA.with(|arena| arena.borrow_mut().reset()); +} + // ─── Server Configuration ─────────────── #[derive(Clone)] @@ -195,9 +286,16 @@ pub struct NativeListenOptions { struct ShutdownHandle { flag: Arc, + /// @DX-6.3: when true, the server rejects new connections but drains + /// in-flight requests before fully stopping. Set via `shutdown()`. + draining: Arc, wake_addrs: Vec, } +/// @DX-6.3: global atomic counter of in-flight requests across all workers. +/// Used by graceful shutdown to wait for requests to complete before closing. +static INFLIGHT_REQUESTS: AtomicU64 = AtomicU64::new(0); + struct LiveRouter { router: ArcSwap, } @@ -252,6 +350,38 @@ impl NativeServerHandle { Ok(()) } + /// @DX-6.3: graceful shutdown — stop accepting new connections, drain + /// in-flight requests up to `timeout_ms`, then force-stop workers. + /// Returns the number of in-flight requests that were still pending + /// when the timeout expired (0 = fully drained). + #[napi] + pub fn shutdown(&self, timeout_ms: Option) -> napi::Result { + let drain_timeout = Duration::from_millis(timeout_ms.unwrap_or(30_000) as u64); + + /* Phase 1: set draining flag — workers reject new connections but + * finish processing in-flight requests normally. 
*/ + if let Some(handle) = self.shutdown.lock().expect("shutdown mutex poisoned").as_ref() { + handle.draining.store(true, Ordering::SeqCst); + } + + close_all_websocket_connections(); + + /* Phase 2: poll in-flight request counter until drained or timeout */ + let deadline = std::time::Instant::now() + drain_timeout; + while INFLIGHT_REQUESTS.load(Ordering::Acquire) > 0 { + if std::time::Instant::now() >= deadline { + break; + } + std::thread::sleep(Duration::from_millis(10)); + } + + let remaining = INFLIGHT_REQUESTS.load(Ordering::Acquire) as u32; + + /* Phase 3: force-stop workers regardless of drain state */ + self.close()?; + Ok(remaining) + } + #[napi] pub fn close(&self) -> napi::Result<()> { let registered_namespaces = { @@ -310,6 +440,20 @@ static STREAM_CHANNELS: std::sync::OnceLock> = std::sync::OnceLock::new(); + +// ─── WebSocket Pub/Sub Registry (DX-4.4) ─── +// +// Topic-based pub/sub for WebSocket connections. Connections subscribe to +// named topics; publishing to a topic broadcasts to all subscribers. The +// registry uses DashMap for lock-free concurrent access across worker threads. 
+ +/// topic → set of connection IDs subscribed to that topic +static WS_TOPICS: std::sync::OnceLock>> = + std::sync::OnceLock::new(); + +fn ws_topic_registry() -> &'static dashmap::DashMap> { + WS_TOPICS.get_or_init(dashmap::DashMap::new) +} static CACHE_NAMESPACE_COUNTS: std::sync::OnceLock> = std::sync::OnceLock::new(); pub(crate) static CACHE_NAMESPACE_GENERATION: AtomicU64 = AtomicU64::new(1); @@ -419,10 +563,83 @@ pub fn stream_end(stream_id: i64) -> napi::Result<()> { if let Some((_, sender)) = registry.remove(&(stream_id as u64)) { let _ = sender.send(StreamMessage::End); } - websocket_connections().remove(&(stream_id as u64)); + let conn_id = stream_id as u64; + websocket_connections().remove(&conn_id); + // Clean up any pub/sub subscriptions for this connection + ws_unsubscribe_all(stream_id); Ok(()) } +// ─── WebSocket Pub/Sub NAPI Functions ─── + +/// Subscribe a WebSocket connection to a topic. +#[napi] +pub fn ws_subscribe(connection_id: i64, topic: String) { + let id = connection_id as u64; + ws_topic_registry() + .entry(topic) + .or_insert_with(HashSet::new) + .insert(id); +} + +/// Unsubscribe a WebSocket connection from a topic. +#[napi] +pub fn ws_unsubscribe(connection_id: i64, topic: String) { + let registry = ws_topic_registry(); + if let Some(mut subs) = registry.get_mut(&topic) { + subs.remove(&(connection_id as u64)); + if subs.is_empty() { + drop(subs); + registry.remove(&topic); + } + } +} + +/// Unsubscribe a connection from ALL topics (called on disconnect). +#[napi] +pub fn ws_unsubscribe_all(connection_id: i64) { + let id = connection_id as u64; + let registry = ws_topic_registry(); + let mut empty_topics = Vec::new(); + for mut entry in registry.iter_mut() { + entry.value_mut().remove(&id); + if entry.value().is_empty() { + empty_topics.push(entry.key().clone()); + } + } + for topic in empty_topics { + registry.remove(&topic); + } +} + +/// Publish a message to all connections subscribed to a topic. 
+/// Returns the number of connections the message was sent to. +#[napi] +pub fn ws_publish(topic: String, data: Buffer) -> u32 { + let registry = ws_topic_registry(); + let Some(subscribers) = registry.get(&topic) else { return 0 }; + let stream_reg = stream_registry(); + let frame = websocket::encode_frame(websocket::OPCODE_TEXT, data.as_ref()); + let mut count = 0u32; + for &conn_id in subscribers.value() { + if let Some(sender) = stream_reg.get(&conn_id) { + if sender.send(StreamMessage::Chunk(frame.clone())).is_ok() { + count += 1; + } + } + } + count +} + +/// Get the number of subscribers for a topic. +#[napi] +pub fn ws_subscriber_count(topic: String) -> u32 { + ws_topic_registry() + .get(&topic) + .map(|s| s.len() as u32) + .unwrap_or(0) +} + /// Get a session value by key. Returns JSON string or null. #[napi] pub fn session_get(session_id_hex: String, key: String) -> Option { @@ -440,8 +657,7 @@ pub fn session_set(session_id_hex: String, key: String, value: String) -> bool { let Some(id) = session::hex_decode_id(&session_id_hex) else { return false }; let mut mutations = std::collections::HashMap::new(); mutations.insert(key, value.into_bytes()); - store.upsert(&id, mutations, &[]); - true + store.upsert(&id, mutations, &[]) } /// Delete a session key. @@ -449,8 +665,7 @@ pub fn session_set(session_id_hex: String, key: String, value: String) -> bool { pub fn session_delete(session_id_hex: String, key: String) -> bool { let Some(store) = GLOBAL_SESSION_STORE.get() else { return false }; let Some(id) = session::hex_decode_id(&session_id_hex) else { return false }; - store.upsert(&id, std::collections::HashMap::new(), &[key]); - true + store.upsert(&id, std::collections::HashMap::new(), &[key]) } /// Destroy an entire session. 
@@ -506,8 +721,7 @@ pub fn session_set_all(session_id_hex: String, data_json: String) -> bool { for (key, value) in map { mutations.insert(key.clone(), value.to_string().into_bytes()); } - store.upsert(&id, mutations, &[]); - true + store.upsert(&id, mutations, &[]) } #[napi(object)] @@ -593,8 +807,12 @@ pub fn start_server( let live_router = Arc::new(LiveRouter { router: ArcSwap::from(router), }); - let tls_acceptor = build_tls_acceptor(&manifest).map_err(to_napi_error)?; - let tls_enabled = tls_acceptor.is_some(); + let tls_result = build_tls_acceptor(&manifest).map_err(to_napi_error)?; + let tls_enabled = tls_result.is_some(); + let (tls_acceptor, tls_config) = match tls_result { + Some((acceptor, config)) => (Some(acceptor), Some(config)), + None => (None, None), + }; // Build session store if session config is present in manifest let session_store: Option> = manifest.session.as_ref().map(|cfg| { @@ -604,7 +822,7 @@ pub fn start_server( cookie_name: cfg.cookie_name.clone(), http_only: cfg.http_only, secure: cfg.secure, - same_site: session::SameSite::from_str(&cfg.same_site), + same_site: cfg.same_site.parse::().unwrap(), path: cfg.path.clone(), max_sessions: cfg.max_sessions, max_data_size: cfg.max_data_size, @@ -623,6 +841,7 @@ pub fn start_server( let worker_count = worker_count_for(&options); let (startup_tx, startup_rx) = mpsc::sync_channel::>(worker_count); let shutdown_flag = Arc::new(AtomicBool::new(false)); + let draining_flag = Arc::new(AtomicBool::new(false)); let mut closed_receivers = Vec::with_capacity(worker_count); for _ in 0..worker_count { @@ -633,8 +852,10 @@ pub fn start_server( let thread_dispatcher = Arc::clone(&dispatcher); let thread_config = Arc::clone(&server_config); let thread_shutdown = Arc::clone(&shutdown_flag); + let thread_draining = Arc::clone(&draining_flag); let thread_session_store = session_store.clone(); let thread_tls_acceptor = tls_acceptor.clone(); + let thread_tls_config = tls_config.clone(); let thread_options = 
NativeListenOptions { host: options.host.clone(), port: options.port, @@ -661,7 +882,9 @@ pub fn start_server( thread_dispatcher, thread_config, thread_tls_acceptor, + thread_tls_config, thread_shutdown, + thread_draining, thread_session_store, ) .await @@ -670,7 +893,7 @@ pub fn start_server( if let Err(error) = &result { let _ = startup_tx_error.send(Err(error.to_string())); - eprintln!("[http-native] native server error: {error:#}"); + log_error!("native server exited: {error:#}"); } let _ = closed_tx.send(()); @@ -729,6 +952,7 @@ pub fn start_server( cache_namespaces: Mutex::new(registered_cache_namespaces), shutdown: Mutex::new(Some(ShutdownHandle { flag: shutdown_flag, + draining: draining_flag, wake_addrs, })), closed: Mutex::new(Some(closed_receivers)), @@ -779,7 +1003,9 @@ async fn run_server( dispatcher: Arc, server_config: Arc, tls_acceptor: Option, + tls_config: Option>, shutdown_flag: Arc, + draining_flag: Arc, session_store: Option>, ) -> Result<()> { // Wrap Arc in Rc for cheap per-connection cloning within this single-threaded @@ -787,11 +1013,12 @@ async fn run_server( let live_router: Rc> = Rc::new(live_router); let dispatcher: Rc> = Rc::new(dispatcher); let server_config: Rc> = Rc::new(server_config); - let tls_acceptor: Option> = tls_acceptor.map(Rc::new); + let _tls_acceptor: Option> = tls_acceptor.map(Rc::new); + let tls_config: Option>> = tls_config.map(Rc::new); let session_store: Option>> = session_store.map(Rc::new); - let active_connections: std::cell::Cell = std::cell::Cell::new(0); + let active_connections = Rc::new(std::cell::Cell::new(0usize)); loop { if shutdown_flag.load(Ordering::Acquire) { @@ -804,6 +1031,14 @@ async fn run_server( break; } + /* @DX-6.3: in draining mode, reject new connections with 503 + * so load balancers route traffic elsewhere. In-flight requests + * on existing connections continue normally. 
*/ + if draining_flag.load(Ordering::Acquire) { + drop(stream); + continue; + } + // Security (S3): enforce per-worker connection limit if active_connections.get() >= MAX_CONNECTIONS_PER_WORKER { drop(stream); @@ -811,32 +1046,66 @@ async fn run_server( } if let Err(error) = stream.set_nodelay(true) { - eprintln!("[http-native] failed to enable TCP_NODELAY: {error}"); + log_warn!("failed to enable TCP_NODELAY: {error}"); } let live_router = Rc::clone(&live_router); let dispatcher = Rc::clone(&dispatcher); let server_config = Rc::clone(&server_config); - let tls_acceptor = tls_acceptor.clone(); + let tls_config = tls_config.clone(); let session_store = session_store.clone(); active_connections.set(active_connections.get() + 1); - // Safety: monoio is single-threaded per worker, so Cell is fine here - let conn_counter = &active_connections as *const std::cell::Cell; + let conn_counter = Rc::clone(&active_connections); monoio::spawn(async move { - let connection_result = if let Some(acceptor) = tls_acceptor.as_ref() { - match acceptor.accept(stream).await { + let connection_result = if let Some(tls_cfg) = tls_config.as_ref() { + /* @DX-4.1: TLS connections use poll-io path to support both + * HTTP/1.1 and HTTP/2 via ALPN negotiation. The raw TCP stream + * is converted to a tokio-compatible type, then wrapped with + * tokio-rustls for TLS. After the handshake, ALPN determines + * whether to dispatch to the h2 handler or fall back to h1.1. 
*/ + let poll_io = match monoio::io::IntoPollIo::into_poll_io(stream) { + Ok(s) => s, + Err(e) => { + log_error!("failed to convert to poll-io: {e}"); + return; + } + }; + + let tokio_acceptor = tokio_rustls::TlsAcceptor::from(Arc::clone(tls_cfg.as_ref())); + match tokio_acceptor.accept(poll_io).await { Ok(tls_stream) => { - handle_connection( - tls_stream, - live_router, - dispatcher, - server_config, - session_store, - Some(peer_addr), - ) - .await + let alpn = tls_stream.get_ref().1.alpn_protocol() + .map(|p| p.to_vec()); + let is_h2 = alpn.as_deref() == Some(b"h2"); + + if is_h2 { + let peer_ip = Some(peer_addr.ip().to_string()); + h2_handler::handle_h2_connection( + tls_stream, + live_router, + dispatcher, + server_config, + peer_ip, + ) + .await + } else { + /* HTTP/1.1 over TLS — use monoio-rustls for + * completion-based I/O (fast path). We need to + * re-accept with monoio-rustls since we already + * consumed the stream. For now, handle h1.1 over + * the poll-io TLS stream. */ + handle_h1_over_poll_tls( + tls_stream, + live_router, + dispatcher, + server_config, + session_store, + Some(peer_addr), + ) + .await + } } Err(error) => Err(anyhow!("TLS accept failed: {error}")), } @@ -852,12 +1121,9 @@ async fn run_server( .await }; if let Err(error) = connection_result { - eprintln!("[http-native] connection error: {error}"); + log_error!("connection error: {error}"); } - // Safety: single-threaded — pointer is always valid while server runs - unsafe { &*conn_counter }.set( - unsafe { &*conn_counter }.get().saturating_sub(1), - ); + conn_counter.set(conn_counter.get().saturating_sub(1)); }); } Err(error) => { @@ -865,7 +1131,7 @@ async fn run_server( break; } - eprintln!("[http-native] accept error: {error}"); + log_error!("accept error: {error}"); } } } @@ -873,29 +1139,10 @@ async fn run_server( Ok(()) } -// ─── Parsed Request (from httparse) ───── - -struct ParsedRequest<'a> { - method: &'a [u8], - target: &'a [u8], - path: &'a [u8], - keep_alive: bool, - 
header_bytes: usize, - has_body: bool, - content_length: Option, - /// True when a non-identity Transfer-Encoding header was seen - has_chunked_te: bool, - /// Pre-parsed header pairs — stored once, used by both routing and bridge - headers: Vec<(&'a str, &'a str)>, - /// Raw cookie header value for session extraction - cookie_header: Option<&'a str>, - /// True when the request contains an Upgrade: websocket header - is_websocket_upgrade: bool, - /// The Sec-WebSocket-Key header value, if present - ws_key: Option<&'a str>, - /// Best accepted encoding from Accept-Encoding header - accepted_encoding: compress::AcceptedEncoding, -} +use crate::parser::{ + ParsedRequest, parse_request_httparse, find_header_end, + contains_ascii_case_insensitive, trim_ascii_spaces, +}; use monoio::time::timeout; use std::time::Duration; @@ -903,8 +1150,100 @@ use std::time::Duration; const TIMEOUT_HEADER_READ: Duration = Duration::from_secs(30); const TIMEOUT_IDLE_KEEPALIVE: Duration = Duration::from_secs(120); const TIMEOUT_BODY_READ: Duration = Duration::from_secs(60); +const TIMEOUT_WS_IDLE: Duration = Duration::from_secs(300); +/// @S8: maximum wall-clock time to accumulate a complete set of request headers. +/// Defends against slow-loris attacks that send one byte per read to reset +/// per-read timeouts. If the total header phase exceeds this, the connection +/// is closed regardless of per-read progress. +const TIMEOUT_HEADER_DEADLINE: Duration = Duration::from_secs(60); + +// ─── Tokio ↔ Monoio I/O Adapter ──────── +// +// Bridges tokio's poll-based `AsyncRead`/`AsyncWrite` traits to monoio's +// ownership-based `AsyncReadRent`/`AsyncWriteRent` traits. This lets us reuse +// the existing HTTP/1.1 connection handler for TLS streams that went through +// tokio-rustls (needed for ALPN negotiation with h2). 
+ +struct TokioCompat(T); + +impl monoio::io::AsyncReadRent for TokioCompat { + async fn read(&mut self, mut buf: B) -> monoio::BufResult { + use tokio::io::AsyncReadExt; + let total = buf.bytes_total(); + if total == 0 { + return (Ok(0), buf); + } + // Safety: write_ptr returns the buffer start, bytes_total gives capacity. + // Matches monoio's own recv semantics — kernel writes from position 0. + let slice = unsafe { std::slice::from_raw_parts_mut(buf.write_ptr(), total) }; + match self.0.read(slice).await { + Ok(n) => { + unsafe { buf.set_init(n) }; + (Ok(n), buf) + } + Err(e) => (Err(e), buf), + } + } -// ─── Connection Handler with Buffer Pool + async fn readv(&mut self, buf: B) -> monoio::BufResult { + // Vectored reads are never used by the HTTP/1.1 handler + (Err(std::io::Error::new(std::io::ErrorKind::Unsupported, "readv not available over TLS adapter")), buf) + } +} + +impl monoio::io::AsyncWriteRent for TokioCompat { + async fn write(&mut self, buf: B) -> monoio::BufResult { + use tokio::io::AsyncWriteExt; + let init = buf.bytes_init(); + // Safety: read_ptr gives the buffer start, bytes_init gives valid byte count + let slice = unsafe { std::slice::from_raw_parts(buf.read_ptr(), init) }; + match self.0.write(slice).await { + Ok(n) => (Ok(n), buf), + Err(e) => (Err(e), buf), + } + } + + async fn writev(&mut self, buf: B) -> monoio::BufResult { + (Err(std::io::Error::new(std::io::ErrorKind::Unsupported, "writev not available over TLS adapter")), buf) + } + + async fn flush(&mut self) -> std::io::Result<()> { + tokio::io::AsyncWriteExt::flush(&mut self.0).await + } + + async fn shutdown(&mut self) -> std::io::Result<()> { + tokio::io::AsyncWriteExt::shutdown(&mut self.0).await + } +} + +/// Handle HTTP/1.1 traffic over a tokio-rustls TLS stream. +/// +/// Wraps the TLS stream in a `TokioCompat` adapter so the existing +/// monoio-based connection handler can drive it. This path is used when +/// ALPN negotiation selects "http/1.1" instead of "h2". 
+async fn handle_h1_over_poll_tls( + tls_stream: tokio_rustls::server::TlsStream, + live_router: Rc>, + dispatcher: Rc>, + server_config: Rc>, + session_store: Option>>, + peer_addr: Option, +) -> Result<()> +where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static, +{ + handle_connection( + TokioCompat(tls_stream), + live_router, + dispatcher, + server_config, + session_store, + peer_addr, + ) + .await +} + +// ─── Connection Handler with Buffer Pool async fn handle_connection( mut stream: S, @@ -947,11 +1286,19 @@ async fn handle_connection_inner( where S: AsyncReadRent + AsyncWriteRent + Unpin, { - let mut is_first_request = true; - loop { let router = live_router.router.load_full(); + /* @S8: deadline tracks total wall-clock time for header accumulation. + * Starts unset; initialized on first byte of a new request. Prevents + * slow-loris attacks that drip-feed bytes to reset per-read timeouts. */ + let mut header_deadline: Option = None; + + /* The first read of a new request uses the longer idle/keep-alive + * timeout. Once bytes arrive, subsequent reads use the shorter + * header-read timeout. 
*/ + let mut awaiting_new_request = true; + // Try hot-path parsing first (GET / with known prefix) let parsed = loop { let result = if router.exact_get_root().is_some() { @@ -973,21 +1320,21 @@ where // SAFETY: We take ownership of the buffer, read into it, then put it back let owned_buf = std::mem::take(buffer); - let read_duration = if is_first_request { - TIMEOUT_HEADER_READ - } else { + let read_duration = if awaiting_new_request { TIMEOUT_IDLE_KEEPALIVE + } else { + TIMEOUT_HEADER_READ }; - + let timeout_result = timeout(read_duration, stream.read(owned_buf)).await; let (read_result, next_buffer) = match timeout_result { Ok(res) => res, Err(_) => { - // Read timeout + // Per-read timeout expired return Ok(()); } }; - + *buffer = next_buffer; let bytes_read = read_result?; @@ -995,7 +1342,27 @@ where return Ok(()); } - is_first_request = false; + awaiting_new_request = false; + + /* @S8: start the header deadline on the first byte of a new request */ + if header_deadline.is_none() { + header_deadline = Some(std::time::Instant::now() + TIMEOUT_HEADER_DEADLINE); + } + + /* @S8: enforce total header-phase wall-clock deadline */ + if let Some(deadline) = header_deadline { + if std::time::Instant::now() >= deadline { + let response = build_error_response_bytes( + 408, + b"{\"error\":\"Request Timeout\"}", + false, + ); + let (write_result, _) = stream.write_all(response).await; + let _ = write_result; + stream.shutdown().await?; + return Ok(()); + } + } if buffer.len() > server_config.max_header_bytes { // Security: Request header too large @@ -1017,6 +1384,19 @@ where let content_length = parsed.content_length; let accepted_encoding = parsed.accepted_encoding; + /* @DX-6.3: track in-flight requests for graceful shutdown draining. + * The guard decrements the counter on drop, ensuring correct counting + * regardless of which code path exits the request (error, early return, + * or normal completion). 
*/ + INFLIGHT_REQUESTS.fetch_add(1, Ordering::Release); + struct InflightGuard; + impl Drop for InflightGuard { + fn drop(&mut self) { + INFLIGHT_REQUESTS.fetch_sub(1, Ordering::Release); + } + } + let _inflight = InflightGuard; + // Security (S1): reject requests with non-identity Transfer-Encoding if parsed.has_chunked_te { drop(parsed); @@ -1114,15 +1494,53 @@ where continue; } + /* @DX-3.4: pre-match route to determine per-route body size limit. + * This runs before the body is read so oversized payloads are rejected + * without wasting I/O or memory. */ + let route_body_limit = { + let method_code = method_code_from_bytes(parsed.method).unwrap_or(UNKNOWN_METHOD_CODE); + let path_str = std::str::from_utf8(parsed.path).unwrap_or("/"); + let normalized = normalize_runtime_path(path_str); + if method_code != UNKNOWN_METHOD_CODE { + router.match_route(method_code, normalized.as_ref()) + .and_then(|m| m.max_body_bytes) + } else { + None + } + }; + // ── Body requests: need owned copies to release buffer for body read ── - let method_owned: Vec = parsed.method.to_vec(); - let target_owned: Vec = parsed.target.to_vec(); - let path_owned: Vec = parsed.path.to_vec(); - let headers_owned: Vec<(String, String)> = parsed - .headers - .iter() - .map(|(n, v)| (n.to_string(), v.to_string())) - .collect(); + // + // @P1: coalesce method + target + path into a single allocation and + // pack all header name/value pairs into one flat buffer with offset + // ranges. Avoids N individual String allocations per header. 
+ let method_len = parsed.method.len(); + let target_len = parsed.target.len(); + let path_len = parsed.path.len(); + let mtp_total = method_len + target_len + path_len; + let mut mtp_buf = Vec::with_capacity(mtp_total); + mtp_buf.extend_from_slice(parsed.method); + mtp_buf.extend_from_slice(parsed.target); + mtp_buf.extend_from_slice(parsed.path); + let method_owned = &mtp_buf[..method_len]; + let target_owned = &mtp_buf[method_len..method_len + target_len]; + let path_owned = &mtp_buf[method_len + target_len..]; + + let header_count = parsed.headers.len(); + let mut hdr_buf_size = 0; + for (n, v) in &parsed.headers { + hdr_buf_size += n.len() + v.len(); + } + let mut hdr_buf = Vec::with_capacity(hdr_buf_size); + /* (name_start, name_len, value_start, value_len) per header */ + let mut hdr_ranges: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(header_count); + for (n, v) in &parsed.headers { + let ns = hdr_buf.len(); + hdr_buf.extend_from_slice(n.as_bytes()); + let vs = hdr_buf.len(); + hdr_buf.extend_from_slice(v.as_bytes()); + hdr_ranges.push((ns, n.len(), vs, v.len())); + } let (session_id_body, is_new_session_body) = resolve_session(session_store, parsed.cookie_header); drop(parsed); @@ -1140,7 +1558,10 @@ where } }; - if content_length > MAX_BODY_BYTES { + /* @DX-3.4: use per-route body limit when configured, otherwise + * fall back to the global MAX_BODY_BYTES constant. */ + let effective_body_limit = route_body_limit.unwrap_or(MAX_BODY_BYTES); + if content_length > effective_body_limit { let response = build_error_response_bytes(413, b"{\"error\":\"Payload Too Large\"}", false); let (write_result, _) = stream.write_all(response).await; @@ -1185,12 +1606,24 @@ where } }; + /* @P1: build (&str, &str) header refs directly from the packed buffer — + * all bytes were valid UTF-8 from httparse, so this avoids N individual + * String heap allocations. 
*/ + let header_refs: Vec<(&str, &str)> = hdr_ranges + .iter() + .map(|&(ns, nl, vs, vl)| { + let name = std::str::from_utf8(&hdr_buf[ns..ns + nl]).unwrap_or(""); + let value = std::str::from_utf8(&hdr_buf[vs..vs + vl]).unwrap_or(""); + (name, value) + }) + .collect(); + let dispatch_decision_owned = build_dispatch_decision_owned( router.as_ref(), - &method_owned, - &target_owned, - &path_owned, - &headers_owned, + method_owned, + target_owned, + path_owned, + &header_refs, &body_bytes, peer_ip, accepted_encoding, @@ -1217,132 +1650,6 @@ where } } -// ─── httparse-based Request Parsing ───── -// -// Uses the battle-tested `httparse` crate for RFC-compliant zero-copy parsing. -// Single-pass: parses headers once and stores them for reuse by both the -// router and the bridge envelope builder. - -fn parse_request_httparse(bytes: &[u8]) -> Option> { - let mut raw_headers = [httparse::EMPTY_HEADER; MAX_HEADER_COUNT]; - let mut req = httparse::Request::new(&mut raw_headers); - - let header_len = match req.parse(bytes) { - Ok(httparse::Status::Complete(len)) => len, - Ok(httparse::Status::Partial) => return None, - Err(_) => return None, // Malformed — caller will handle - }; - - let method = req.method?.as_bytes(); - let target = req.path?.as_bytes(); - let version = req.version?; - - // Security: enforce URL length limit - if target.len() > MAX_URL_LENGTH { - return None; - } - - // Extract path (before '?') - let path = target.split(|b| *b == b'?').next()?; - - let mut keep_alive = version >= 1; // HTTP/1.1+ defaults to keep-alive - let mut has_body = false; - let mut content_length: Option = None; - let mut has_chunked_te = false; - let mut cookie_header: Option<&str> = None; - let mut is_websocket_upgrade = false; - let mut ws_key: Option<&str> = None; - let mut accepted_encoding = compress::AcceptedEncoding::Identity; - let mut headers = Vec::with_capacity(req.headers.len()); - - for header in req.headers.iter() { - if header.name.is_empty() { - break; - } - - // 
Security: enforce header value length - if header.value.len() > MAX_HEADER_VALUE_LENGTH { - return None; - } - - let name = header.name; // httparse gives us &str - let value = match std::str::from_utf8(header.value) { - Ok(v) => v, - Err(_) => continue, // Skip non-UTF-8 headers - }; - - // Connection handling — allocation-free byte comparison - if name.eq_ignore_ascii_case("connection") { - let vb = value.as_bytes(); - if contains_ascii_case_insensitive(vb, b"close") { - keep_alive = false; - } - if contains_ascii_case_insensitive(vb, b"keep-alive") { - keep_alive = true; - } - } - - // Body detection - if name.eq_ignore_ascii_case("content-length") { - let trimmed = value.trim(); - if let Ok(len) = trimmed.parse::() { - content_length = Some(len); - if len > 0 { - has_body = true; - } - } - } - - if name.eq_ignore_ascii_case("transfer-encoding") { - let trimmed = value.trim(); - if !trimmed.is_empty() && !trimmed.eq_ignore_ascii_case("identity") { - has_body = true; - has_chunked_te = true; - } - } - - // Session: capture cookie header for session extraction - if name.eq_ignore_ascii_case("cookie") { - cookie_header = Some(value); - } - - // WebSocket: detect upgrade request - if name.eq_ignore_ascii_case("upgrade") { - if value.eq_ignore_ascii_case("websocket") { - is_websocket_upgrade = true; - } - } - if name.eq_ignore_ascii_case("sec-websocket-key") { - ws_key = Some(value); - } - - // Compression: parse Accept-Encoding - if accepted_encoding != compress::AcceptedEncoding::Brotli - && name.eq_ignore_ascii_case("accept-encoding") - { - accepted_encoding = compress::parse_accept_encoding(value.as_bytes()); - } - - headers.push((name, value)); - } - - Some(ParsedRequest { - method, - target, - path, - keep_alive, - header_bytes: header_len, - has_body, - content_length, - has_chunked_te, - headers, - cookie_header, - is_websocket_upgrade, - ws_key, - accepted_encoding, - }) -} - // ─── Hot Root Path (GET /) ────────────── // // Ultra-fast path for the most 
common benchmark case. Falls back to httparse @@ -1363,6 +1670,8 @@ fn parse_hot_root_request( let header_end = find_header_end(bytes)?; let mut keep_alive = keep_alive; let mut has_body = false; + let mut has_chunked_te = false; + let mut content_length: Option = None; let mut accepted_encoding = compress::AcceptedEncoding::Identity; let mut line_start = bytes.iter().position(|b| *b == b'\n')? + 1; @@ -1391,8 +1700,15 @@ fn parse_hot_root_request( { let value = trim_ascii_spaces(&line[server_config.header_content_length_prefix.len()..]); - if value != b"0" { - has_body = true; + /* @B4: parse Content-Length for TE+CL smuggling detection */ + if let Some(len) = std::str::from_utf8(value) + .ok() + .and_then(|s| s.parse::().ok()) + { + content_length = Some(len); + if len > 0 { + has_body = true; + } } } else if line.len() >= server_config.header_transfer_encoding_prefix.len() && line[..server_config.header_transfer_encoding_prefix.len()] @@ -1400,8 +1716,10 @@ fn parse_hot_root_request( { let value = trim_ascii_spaces(&line[server_config.header_transfer_encoding_prefix.len()..]); + /* @B4: flag non-identity Transfer-Encoding for smuggling guard */ if !value.is_empty() && !value.eq_ignore_ascii_case(b"identity") { has_body = true; + has_chunked_te = true; } } else if accepted_encoding != compress::AcceptedEncoding::Brotli && line.len() >= 17 @@ -1421,10 +1739,10 @@ fn parse_hot_root_request( keep_alive, header_bytes: header_end + 4, has_body, - content_length: None, - has_chunked_te: false, + content_length, + has_chunked_te, headers: Vec::new(), - cookie_header: None, // Hot path doesn't parse cookies + cookie_header: None, is_websocket_upgrade: false, ws_key: None, accepted_encoding, @@ -1494,7 +1812,7 @@ fn build_dispatch_decision_zero_copy( let mut cache_insertion = None; if let Some(cfg) = matched_route.cache_config { - let base_key = crate::router::interpolate_cache_key(cfg, parsed, url_str, matched_route.param_names, &matched_route.param_values); + let 
base_key = crate::router::interpolate_cache_key(cfg, &parsed.headers, url_str, matched_route.param_names, &matched_route.param_values); let key = vary_cache_key_by_encoding(base_key, accepted_encoding); if let Some(cached_response) = crate::router::get_cached_response(matched_route.cache_namespace, key, parsed.keep_alive) { return Ok(DispatchDecision::CachedResponse(cached_response)); @@ -1535,12 +1853,20 @@ fn build_dispatch_decision_zero_copy( )) } -fn build_dispatch_decision_owned( +/// /* @param router — compiled route table */ +/// /* @param method — raw HTTP method bytes */ +/// /* @param target — raw request target (URL) bytes */ +/// /* @param path — path portion (before '?') bytes */ +/// /* @param headers — pre-parsed (name, value) str pairs */ +/// /* @param body — request body bytes */ +/// /* @param peer_ip — client IP string, if known */ +/// /* @param accepted_encoding — best encoding from Accept-Encoding */ +fn build_dispatch_decision_owned<'a>( router: &Router, method: &[u8], target: &[u8], path: &[u8], - headers: &[(String, String)], + headers: &[(&'a str, &'a str)], body: &[u8], peer_ip: Option<&str>, accepted_encoding: compress::AcceptedEncoding, @@ -1552,10 +1878,7 @@ fn build_dispatch_decision_owned( let url_cow = String::from_utf8_lossy(target); let url_str = url_cow.as_ref(); - let header_refs: Vec<(&str, &str)> = headers - .iter() - .map(|(n, v)| (n.as_str(), v.as_str())) - .collect(); + let header_refs = headers; // Security: strict path validation let normalized_path = normalize_runtime_path(path_str); @@ -1591,22 +1914,9 @@ fn build_dispatch_decision_owned( let mut cache_insertion = None; if let Some(cfg) = matched_route.cache_config { - let mock_parsed = ParsedRequest { - method, - target, - path, - keep_alive: false, - header_bytes: 0, - has_body: true, - content_length: None, - has_chunked_te: false, - headers: header_refs.clone(), - cookie_header: None, - is_websocket_upgrade: false, - ws_key: None, - accepted_encoding: 
compress::AcceptedEncoding::Identity, - }; - let base_key = crate::router::interpolate_cache_key(cfg, &mock_parsed, url_str, matched_route.param_names, &matched_route.param_values); + /* @BOOST-1.3: pass header slice directly — avoids constructing a + * throwaway ParsedRequest and the .to_vec() clone that entailed. */ + let base_key = crate::router::interpolate_cache_key(cfg, header_refs, url_str, matched_route.param_names, &matched_route.param_values); let key = vary_cache_key_by_encoding(base_key, accepted_encoding); cache_insertion = Some((matched_route.cache_namespace, key, cfg.max_entries, cfg.ttl_secs)); } else { @@ -2117,107 +2427,6 @@ fn append_json_string(output: &mut Vec, value: &str) { output.push(b'"'); } -fn build_response_bytes_fast( - status: u16, - headers: &[(Box, Box)], - body: &[u8], - keep_alive: bool, - encoding: compress::AcceptedEncoding, - compression_config: Option<&compress::CompressionConfig>, -) -> Vec { - // ── Scan headers for content-type / content-encoding ── - let mut content_type: Option<&[u8]> = None; - let mut has_content_encoding = false; - for (name, value) in headers { - if name.eq_ignore_ascii_case("content-type") { - content_type = Some(value.as_bytes()); - } else if name.eq_ignore_ascii_case("content-encoding") { - has_content_encoding = true; - } - } - - // ── Attempt compression ── - let compressed = compression_config.and_then(|config| { - compress::should_compress(config, encoding, body.len(), content_type, has_content_encoding) - .and_then(|enc| compress::compress_body(body, enc, config, content_type).map(|data| (data, enc))) - }); - - let (final_body, applied_encoding): (&[u8], Option) = - match &compressed { - Some((data, enc)) => (data.as_slice(), Some(*enc)), - None => (body, None), - }; - - // ── Build HTTP response ── - let reason = status_reason(status); - let connection = if keep_alive { "keep-alive" } else { "close" }; - let body_len = final_body.len(); - - let mut total_size = - 9 + 3 + 1 + reason.len() + 2 + 
16 + count_digits(body_len) + 2 + 12 + connection.len() + 2; - - if applied_encoding.is_some() { - // "content-encoding: br\r\nvary: accept-encoding\r\n" worst case ~50 bytes - total_size += 50; - } - - for (name, value) in headers { - if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { - continue; - } - if name.contains('\r') - || name.contains('\n') - || value.contains('\r') - || value.contains('\n') - { - continue; - } - total_size += name.len() + 2 + value.len() + 2; - } - - total_size += 2 + body_len; - - let mut output = Vec::with_capacity(total_size); - output.extend_from_slice(b"HTTP/1.1 "); - write_u16(&mut output, status); - output.push(b' '); - output.extend_from_slice(reason.as_bytes()); - output.extend_from_slice(b"\r\n"); - output.extend_from_slice(b"content-length: "); - write_usize(&mut output, body_len); - output.extend_from_slice(b"\r\n"); - output.extend_from_slice(b"connection: "); - output.extend_from_slice(connection.as_bytes()); - output.extend_from_slice(b"\r\n"); - - // Compression headers - if let Some(enc) = applied_encoding { - output.extend_from_slice(b"content-encoding: "); - output.extend_from_slice(compress::encoding_header_value(enc)); - output.extend_from_slice(b"\r\nvary: accept-encoding\r\n"); - } - - for (name, value) in headers { - if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { - continue; - } - if name.contains('\r') - || name.contains('\n') - || value.contains('\r') - || value.contains('\n') - { - continue; - } - output.extend_from_slice(name.as_bytes()); - output.extend_from_slice(b": "); - output.extend_from_slice(value.as_bytes()); - output.extend_from_slice(b"\r\n"); - } - - output.extend_from_slice(b"\r\n"); - output.extend_from_slice(final_body); - output -} // ─── Response Writing ─────────────────── @@ -2345,8 +2554,11 @@ fn extract_ncache_trailer(dispatch_bytes: &[u8]) -> Option<(u64, usize)> { /// Uses FxHasher (~5x faster than 
SipHash/DefaultHasher for short keys). fn compute_ncache_key(url_bytes: &[u8]) -> u64 { use std::hash::{Hash, Hasher}; - use rustc_hash::FxHasher; - let mut hasher = FxHasher::default(); + use std::collections::hash_map::DefaultHasher; + /* Use SipHash (DefaultHasher) instead of FxHasher to prevent + * cache poisoning via hash collision attacks. SipHash is + * randomized per process, making collisions unpredictable. */ + let mut hasher = DefaultHasher::new(); url_bytes.hash(&mut hasher); hasher.finish() } @@ -2664,7 +2876,7 @@ where match trailer.action { session::SessionAction::Update => { if let Some(sid) = session_id { - store.upsert(&sid, trailer.mutations, &trailer.deleted_keys); + let _ = store.upsert(&sid, trailer.mutations, &trailer.deleted_keys); // Inject Set-Cookie for new sessions if is_new_session { let cookie = store.build_set_cookie(&sid); @@ -2686,9 +2898,9 @@ where store.destroy(&old_sid); let new_sid = store.generate_id(); if let Some(entry) = old_data { - store.upsert(&new_sid, entry.data, &[]); + let _ = store.upsert(&new_sid, entry.data, &[]); } - store.upsert(&new_sid, trailer.mutations, &trailer.deleted_keys); + let _ = store.upsert(&new_sid, trailer.mutations, &trailer.deleted_keys); let cookie = store.build_set_cookie(&new_sid); inject_set_cookie_header(&mut http_response, &cookie); } @@ -2697,7 +2909,7 @@ where } else if is_new_session { // No session trailer but session was accessed — set cookie if let Some(sid) = session_id { - store.upsert(&sid, std::collections::HashMap::new(), &[]); + let _ = store.upsert(&sid, std::collections::HashMap::new(), &[]); let cookie = store.build_set_cookie(&sid); inject_set_cookie_header(&mut http_response, &cookie); } @@ -2847,28 +3059,6 @@ fn build_http_response_from_dispatch( Ok(output) } -/// Patch the `connection:` header value in an already-built HTTP response. -/// Searches for `connection: keep-alive` or `connection: close` and swaps to the -/// requested variant. 
The two values differ in length (10 vs 5 bytes) so the -/// Vec may grow or shrink by a few bytes. -fn patch_connection_header(response: &[u8], keep_alive: bool) -> Vec { - let (find, replace) = if keep_alive { - (&b"connection: close\r\n"[..], &b"connection: keep-alive\r\n"[..]) - } else { - (&b"connection: keep-alive\r\n"[..], &b"connection: close\r\n"[..]) - }; - - if let Some(pos) = memmem::find(response, find) { - let mut out = Vec::with_capacity(response.len() + replace.len() - find.len()); - out.extend_from_slice(&response[..pos]); - out.extend_from_slice(replace); - out.extend_from_slice(&response[pos + find.len()..]); - out - } else { - // Header not found (shouldn't happen) — return unchanged clone - response.to_vec() - } -} /// Resolve the session ID from the cookie header. Returns (session_id, is_new). /// If no cookie is present or invalid, generates a new session ID. @@ -2967,9 +3157,16 @@ where Err(flume::TryRecvError::Disconnected) => break, } - // Read more data from the client + // Read more data from the client (with idle timeout) let owned_buf = std::mem::take(buffer); - let (read_result, returned_buf) = stream.read(owned_buf).await; + let timeout_result = timeout(TIMEOUT_WS_IDLE, stream.read(owned_buf)).await; + let (read_result, returned_buf) = match timeout_result { + Ok(res) => res, + Err(_) => { + // Idle timeout — close connection + break; + } + }; *buffer = returned_buf; match read_result { Ok(0) => break, @@ -3002,46 +3199,7 @@ fn build_ws_event_envelope(event_type: u8, ws_id: u64, handler_id: u32, data: &[ buf } -/// Inject a Set-Cookie header into an already-built HTTP response. -/// Inserts the header just before the final \r\n\r\n (end of headers). 
-fn inject_set_cookie_header(response: &mut Vec, cookie_value: &str) { - // Find the \r\n\r\n boundary between headers and body - if let Some(pos) = memmem::find(response, b"\r\n\r\n") { - let header_line = format!("set-cookie: {}\r\n", cookie_value); - let header_bytes = header_line.as_bytes(); - // Insert before the final \r\n\r\n - let insert_pos = pos + 2; // after the last header's \r\n, before the blank \r\n - response.splice(insert_pos..insert_pos, header_bytes.iter().copied()); - - // Update Content-Length — it shouldn't change since we're adding headers, not body. - // Content-Length only measures the body, which is unchanged. - } -} - -/// Build a simple error response without going through the JS bridge -fn build_error_response_bytes(status: u16, body: &[u8], keep_alive: bool) -> Vec { - let reason = status_reason(status); - let connection = if keep_alive { "keep-alive" } else { "close" }; - let body_len = body.len(); - - let total_size = - 9 + 3 + 1 + reason.len() + 2 + 16 + count_digits(body_len) + 2 + 12 + connection.len() + 2 + 45 + 2 + body_len; - - let mut output = Vec::with_capacity(total_size); - output.extend_from_slice(b"HTTP/1.1 "); - write_u16(&mut output, status); - output.push(b' '); - output.extend_from_slice(reason.as_bytes()); - output.extend_from_slice(b"\r\ncontent-length: "); - write_usize(&mut output, body_len); - output.extend_from_slice(b"\r\nconnection: "); - output.extend_from_slice(connection.as_bytes()); - output.extend_from_slice(b"\r\ncontent-type: application/json; charset=utf-8\r\n\r\n"); - output.extend_from_slice(body); - - output -} // ─── Security Utilities ───────────────── @@ -3167,6 +3325,20 @@ fn method_code_from_bytes(method: &[u8]) -> Option { } } +/// @DX-4.1: str-based variant for HTTP/2 method mapping (h2 crate uses &str). 
+pub(crate) fn method_code_from_str(method: &str) -> Option { + match method { + "GET" => Some(1), + "POST" => Some(2), + "PUT" => Some(3), + "DELETE" => Some(4), + "PATCH" => Some(5), + "OPTIONS" => Some(6), + "HEAD" => Some(7), + _ => None, + } +} + fn drain_consumed_bytes(buffer: &mut Vec, consumed: usize) { if consumed >= buffer.len() { buffer.clear(); @@ -3193,16 +3365,55 @@ fn bind_listener( .unwrap_or(server_config.default_host.as_str()); let bind_addr = resolve_socket_addr(host, options.port) .with_context(|| format!("failed to resolve bind address {host}:{}", options.port))?; - let listener_opts = ListenerOpts::new() - .reuse_addr(true) - .reuse_port(true) - .backlog(options.backlog.unwrap_or(server_config.default_backlog)); - TcpListener::bind_with_config(bind_addr, &listener_opts) - .with_context(|| format!("failed to bind TCP listener on {bind_addr}")) + /* @B3.5: configure raw socket with TCP_FASTOPEN before binding. + * TFO allows data in the SYN packet on resumed connections, saving 1 RTT + * for repeat clients. The queue length (256) limits the number of pending + * TFO connections the kernel will accept. Falls back silently on systems + * that don't support it (macOS, older Linux kernels). */ + let raw_socket = socket2::Socket::new( + if bind_addr.is_ipv6() { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 }, + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + ).context("failed to create raw socket")?; + + raw_socket.set_reuse_address(true)?; + #[cfg(unix)] + { + raw_socket.set_reuse_port(true)?; + } + + /* @B3.5: TCP_FASTOPEN — allow data in SYN packet on resumed connections. + * Uses raw setsockopt since socket2 doesn't expose TFO directly. + * Silently ignored on systems that don't support it. 
*/ + #[cfg(target_os = "linux")] + { + use std::os::unix::io::AsRawFd; + let fd = raw_socket.as_raw_fd(); + let val: libc::c_int = 256; // TFO queue length + unsafe { + libc::setsockopt( + fd, + libc::IPPROTO_TCP, + libc::TCP_FASTOPEN, + &val as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ); + } + } + + raw_socket.bind(&bind_addr.into())?; + raw_socket.listen(options.backlog.unwrap_or(server_config.default_backlog))?; + raw_socket.set_nonblocking(true)?; + + let std_listener: std::net::TcpListener = raw_socket.into(); + TcpListener::from_std(std_listener) + .with_context(|| format!("failed to create monoio listener from raw socket on {bind_addr}")) } -fn build_tls_acceptor(manifest: &ManifestInput) -> Result> { +/// @DX-4.1: returns both the monoio-rustls acceptor and the shared rustls +/// config Arc. The Arc is needed for tokio-rustls when h2 is negotiated. +fn build_tls_acceptor(manifest: &ManifestInput) -> Result)>> { let Some(tls) = manifest.tls.as_ref() else { return Ok(None); }; @@ -3218,9 +3429,13 @@ fn build_tls_acceptor(manifest: &ManifestInput) -> Result> { .with_no_client_auth() .with_single_cert(cert_chain, private_key) .context("failed to construct rustls server config")?; - config.alpn_protocols = vec![b"http/1.1".to_vec()]; + /* @DX-4.1: advertise both HTTP/2 and HTTP/1.1 via ALPN. rustls will + * select the client's preferred protocol during the TLS handshake. + * h2 is listed first so capable clients default to HTTP/2. 
*/ + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - Ok(Some(TlsAcceptor::from(Arc::new(config)))) + let config_arc = Arc::new(config); + Ok(Some((TlsAcceptor::from(config_arc.clone()), config_arc))) } fn parse_tls_certificates( @@ -3250,7 +3465,7 @@ fn parse_tls_private_key(tls: &TlsConfigInput) -> Result> if tls.passphrase.is_some() { return Err(anyhow!( - "encrypted TLS private keys are not supported by this loader; provide an unencrypted PEM key" + "encrypted TLS private keys are not supported; provide an unencrypted PEM key" )); } @@ -3281,34 +3496,8 @@ fn validate_manifest(manifest: &ManifestInput) -> Result<()> { Ok(()) } -fn find_header_end(bytes: &[u8]) -> Option { - memmem::find(bytes, b"\r\n\r\n") -} - -fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { - if needle.is_empty() || haystack.len() < needle.len() { - return false; - } - - haystack - .windows(needle.len()) - .any(|window| window.eq_ignore_ascii_case(needle)) -} - -fn trim_ascii_spaces(bytes: &[u8]) -> &[u8] { - let start = bytes - .iter() - .position(|byte| !byte.is_ascii_whitespace()) - .unwrap_or(bytes.len()); - let end = bytes - .iter() - .rposition(|byte| !byte.is_ascii_whitespace()) - .map(|index| index + 1) - .unwrap_or(start); - &bytes[start..end] -} -fn normalize_runtime_path(path: &str) -> Cow<'_, str> { +pub(crate) fn normalize_runtime_path(path: &str) -> Cow<'_, str> { // Fast path: "/" or no trailing slash — zero allocation if path == "/" || !path.ends_with('/') { return Cow::Borrowed(path); @@ -3338,61 +3527,7 @@ fn config_string( input.and_then(pick).unwrap_or(fallback).to_string() } -fn status_reason(status: u16) -> &'static str { - match status { - 200 => "OK", - 201 => "Created", - 202 => "Accepted", - 204 => "No Content", - 301 => "Moved Permanently", - 302 => "Found", - 304 => "Not Modified", - 400 => "Bad Request", - 401 => "Unauthorized", - 403 => "Forbidden", - 404 => "Not Found", - 405 => "Method Not Allowed", - 408 => 
"Request Timeout", - 409 => "Conflict", - 411 => "Length Required", - 413 => "Payload Too Large", - 415 => "Unsupported Media Type", - 422 => "Unprocessable Entity", - 429 => "Too Many Requests", - 431 => "Request Header Fields Too Large", - 500 => "Internal Server Error", - 501 => "Not Implemented", - 502 => "Bad Gateway", - 503 => "Service Unavailable", - 504 => "Gateway Timeout", - _ => "Unknown", - } -} - -/// Fast integer-to-string for small values — uses stack-allocated itoa buffer -#[inline(always)] -fn write_usize(output: &mut Vec, value: usize) { - let mut buf = itoa::Buffer::new(); - output.extend_from_slice(buf.format(value).as_bytes()); -} - -#[inline(always)] -fn write_u16(output: &mut Vec, value: u16) { - let mut buf = itoa::Buffer::new(); - output.extend_from_slice(buf.format(value).as_bytes()); -} - -fn count_digits(mut n: usize) -> usize { - if n == 0 { - return 1; - } - let mut count = 0; - while n > 0 { - count += 1; - n /= 10; - } - count -} +use crate::http_utils::{status_reason, write_usize, write_u16}; fn push_string_pair(frame: &mut Vec, name: &str, value: &str) -> Result<()> { if name.len() > u8::MAX as usize { diff --git a/rsrc/src/manifest.rs b/rsrc/src/manifest.rs index cbe5878..8fbb14b 100644 --- a/rsrc/src/manifest.rs +++ b/rsrc/src/manifest.rs @@ -108,6 +108,11 @@ pub struct RouteInput { pub cache: Option, #[serde(default)] pub static_response: Option, + /// @DX-3.4: per-route maximum request body size in bytes. Overrides the + /// global MAX_BODY_BYTES constant when set. Enforced in Rust before the + /// full body is read into memory. + #[serde(default)] + pub max_body_bytes: Option, } #[derive(Debug, Clone, Deserialize)] diff --git a/rsrc/src/parser.rs b/rsrc/src/parser.rs new file mode 100644 index 0000000..a8a4dd7 --- /dev/null +++ b/rsrc/src/parser.rs @@ -0,0 +1,294 @@ +//! HTTP/1.1 request parser extracted from lib.rs (plan item R1). +//! +//! Contains the `ParsedRequest` struct, the httparse-based parser, and +//! 
low-level byte utilities shared between the hot-path parser and the +//! full parser. + +use memchr::memmem; + +use crate::compress; + +// ─── Constants ──────────────────────────── + +/// Maximum number of headers allowed per request. +pub const MAX_HEADER_COUNT: usize = 64; +/// Maximum URL length to prevent abuse. +pub const MAX_URL_LENGTH: usize = 8192; +/// Maximum single header value length. +pub const MAX_HEADER_VALUE_LENGTH: usize = 8192; + +// ─── Parsed Request ─────────────────────── + +/// Zero-copy parsed HTTP/1.1 request. All string slices borrow from the +/// connection read buffer and are valid until the buffer is drained. +pub struct ParsedRequest<'a> { + pub method: &'a [u8], + pub target: &'a [u8], + pub path: &'a [u8], + pub keep_alive: bool, + pub header_bytes: usize, + pub has_body: bool, + pub content_length: Option, + /// True when a non-identity Transfer-Encoding header was seen. + pub has_chunked_te: bool, + /// Pre-parsed header pairs — stored once, used by both routing and bridge. + pub headers: Vec<(&'a str, &'a str)>, + /// Raw cookie header value for session extraction. + pub cookie_header: Option<&'a str>, + /// True when the request contains an Upgrade: websocket header. + pub is_websocket_upgrade: bool, + /// The Sec-WebSocket-Key header value, if present. + pub ws_key: Option<&'a str>, + /// Best accepted encoding from Accept-Encoding header. + pub accepted_encoding: compress::AcceptedEncoding, +} + +// ─── httparse-based Request Parsing ─────── +// +// Uses the battle-tested `httparse` crate for RFC-compliant zero-copy parsing. +// Single-pass: parses headers once and stores them for reuse by both the +// router and the bridge envelope builder. + +/// Parse a raw HTTP/1.1 request from a byte buffer using httparse. +/// +/// Returns `None` if the buffer is incomplete (partial headers) or if the +/// request is malformed. The caller retries after reading more data. 
+/// +/// /* @param bytes — raw bytes from the connection read buffer */ +/// /* @returns — parsed request borrowing from the buffer, or None */ +pub fn parse_request_httparse(bytes: &[u8]) -> Option> { + let mut raw_headers = [httparse::EMPTY_HEADER; MAX_HEADER_COUNT]; + let mut req = httparse::Request::new(&mut raw_headers); + + let header_len = match req.parse(bytes) { + Ok(httparse::Status::Complete(len)) => len, + Ok(httparse::Status::Partial) => return None, + Err(_) => return None, + }; + + let method = req.method?.as_bytes(); + let target = req.path?.as_bytes(); + let version = req.version?; + + /* security: enforce URL length limit */ + if target.len() > MAX_URL_LENGTH { + return None; + } + + /* extract path (before '?') — use SIMD-accelerated memchr for the scan (@B4.1) */ + let path = match memchr::memchr(b'?', target) { + Some(pos) => &target[..pos], + None => target, + }; + + let mut keep_alive = version >= 1; // HTTP/1.1+ defaults to keep-alive + let mut has_body = false; + let mut content_length: Option = None; + let mut has_chunked_te = false; + let mut cookie_header: Option<&str> = None; + let mut is_websocket_upgrade = false; + let mut ws_key: Option<&str> = None; + let mut accepted_encoding = compress::AcceptedEncoding::Identity; + let mut headers = Vec::with_capacity(req.headers.len()); + + for header in req.headers.iter() { + if header.name.is_empty() { + break; + } + + /* security: enforce header value length */ + if header.value.len() > MAX_HEADER_VALUE_LENGTH { + return None; + } + + let name = intern_header_name(header.name); + let value = match std::str::from_utf8(header.value) { + Ok(v) => v, + Err(_) => continue, + }; + + /* connection handling — interned names are lowercase so fast-path + * comparisons can use pointer equality for known headers. Fall back + * to case-insensitive for non-interned names. 
*/ + if name.eq_ignore_ascii_case("connection") { + let vb = value.as_bytes(); + if contains_ascii_case_insensitive(vb, b"close") { + keep_alive = false; + } + if contains_ascii_case_insensitive(vb, b"keep-alive") { + keep_alive = true; + } + } + + /* body detection */ + if name.eq_ignore_ascii_case("content-length") { + let trimmed = value.trim(); + if let Ok(len) = trimmed.parse::() { + content_length = Some(len); + if len > 0 { + has_body = true; + } + } + } + + if name.eq_ignore_ascii_case("transfer-encoding") { + let trimmed = value.trim(); + if !trimmed.is_empty() && !trimmed.eq_ignore_ascii_case("identity") { + has_body = true; + has_chunked_te = true; + } + } + + /* session: capture cookie header */ + if name.eq_ignore_ascii_case("cookie") { + cookie_header = Some(value); + } + + /* websocket: detect upgrade request */ + if name.eq_ignore_ascii_case("upgrade") { + if value.eq_ignore_ascii_case("websocket") { + is_websocket_upgrade = true; + } + } + if name.eq_ignore_ascii_case("sec-websocket-key") { + ws_key = Some(value); + } + + /* compression: parse Accept-Encoding */ + if accepted_encoding != compress::AcceptedEncoding::Brotli + && name.eq_ignore_ascii_case("accept-encoding") + { + accepted_encoding = compress::parse_accept_encoding(value.as_bytes()); + } + + headers.push((name, value)); + } + + Some(ParsedRequest { + method, + target, + path, + keep_alive, + header_bytes: header_len, + has_body, + content_length, + has_chunked_te, + headers, + cookie_header, + is_websocket_upgrade, + ws_key, + accepted_encoding, + }) +} + +// ─── Byte Utilities ─────────────────────── + +/// Find the `\r\n\r\n` header-body boundary in a byte buffer. +/// +/// /* @param bytes — raw request bytes */ +/// /* @returns — byte offset of the first `\r\n` in the `\r\n\r\n` sequence */ +pub fn find_header_end(bytes: &[u8]) -> Option { + memmem::find(bytes, b"\r\n\r\n") +} + +/// Check if `haystack` contains `needle` with ASCII case-insensitive comparison. 
+/// +/// /* @param haystack — bytes to search in */ +/// /* @param needle — pattern to find (case-insensitive) */ +pub fn contains_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> bool { + if needle.is_empty() || haystack.len() < needle.len() { + return false; + } + + haystack + .windows(needle.len()) + .any(|window| window.eq_ignore_ascii_case(needle)) +} + +/// Trim leading and trailing ASCII whitespace from a byte slice. +/// +/// /* @param bytes — input byte slice */ +/// /* @returns — trimmed sub-slice */ +pub fn trim_ascii_spaces(bytes: &[u8]) -> &[u8] { + let start = bytes + .iter() + .position(|byte| !byte.is_ascii_whitespace()) + .unwrap_or(bytes.len()); + let end = bytes + .iter() + .rposition(|byte| !byte.is_ascii_whitespace()) + .map(|index| index + 1) + .unwrap_or(start); + &bytes[start..end] +} + +// ─── Header Name Interning (@BOOST-2.2) ───── +// +// Maps the most common HTTP header names to `&'static str` references. +// This eliminates redundant byte copies when header names flow through +// the body-request path (hdr_buf packing) and speeds up downstream +// case-insensitive comparisons (interned names are already lowercase). +// +// Strategy: branch on `raw.len()` first (single integer compare), then +// do a case-insensitive match only within that length bucket. This keeps +// the miss path cost to one branch on the length — no string comparisons +// for headers that don't match any known length. 
+ +/// /* @param raw — header name borrowed from the httparse parse buffer */ +/// /* @returns — static `&str` for well-known headers, original `&str` otherwise */ +pub fn intern_header_name<'a>(raw: &'a str) -> &'a str { + match raw.len() { + 4 => { + if raw.eq_ignore_ascii_case("host") { return "host"; } + if raw.eq_ignore_ascii_case("date") { return "date"; } + if raw.eq_ignore_ascii_case("vary") { return "vary"; } + raw + } + 6 => { + if raw.eq_ignore_ascii_case("accept") { return "accept"; } + if raw.eq_ignore_ascii_case("cookie") { return "cookie"; } + if raw.eq_ignore_ascii_case("origin") { return "origin"; } + raw + } + 7 => { + if raw.eq_ignore_ascii_case("referer") { return "referer"; } + if raw.eq_ignore_ascii_case("upgrade") { return "upgrade"; } + raw + } + 10 => { + if raw.eq_ignore_ascii_case("connection") { return "connection"; } + if raw.eq_ignore_ascii_case("user-agent") { return "user-agent"; } + raw + } + 12 => { + if raw.eq_ignore_ascii_case("content-type") { return "content-type"; } + raw + } + 13 => { + if raw.eq_ignore_ascii_case("authorization") { return "authorization"; } + if raw.eq_ignore_ascii_case("cache-control") { return "cache-control"; } + if raw.eq_ignore_ascii_case("if-none-match") { return "if-none-match"; } + raw + } + 14 => { + if raw.eq_ignore_ascii_case("content-length") { return "content-length"; } + if raw.eq_ignore_ascii_case("accept-charset") { return "accept-charset"; } + raw + } + 15 => { + if raw.eq_ignore_ascii_case("accept-encoding") { return "accept-encoding"; } + if raw.eq_ignore_ascii_case("accept-language") { return "accept-language"; } + raw + } + 17 => { + if raw.eq_ignore_ascii_case("transfer-encoding") { return "transfer-encoding"; } + if raw.eq_ignore_ascii_case("if-modified-since") { return "if-modified-since"; } + raw + } + 19 => { + if raw.eq_ignore_ascii_case("content-disposition") { return "content-disposition"; } + raw + } + _ => raw, + } +} diff --git a/rsrc/src/rate_limit.rs b/rsrc/src/rate_limit.rs 
index 4ef060d..9eba368 100644 --- a/rsrc/src/rate_limit.rs +++ b/rsrc/src/rate_limit.rs @@ -4,6 +4,9 @@ use std::sync::OnceLock; use std::time::{SystemTime, UNIX_EPOCH}; const NAMESPACE_SEPARATOR: char = '\u{1f}'; +const MAX_RATE_LIMIT_ENTRIES: usize = 100_000; +const EVICTION_TARGET_RATIO: f64 = 0.8; +const EVICT_STALE_THRESHOLD_MS: u64 = 3_600_000; // 1 hour #[derive(Debug, Clone)] pub struct RateLimitDecision { @@ -88,6 +91,40 @@ pub fn now_ms() -> u64 { .unwrap_or(0) } +fn maybe_evict(map: &DashMap, now_ms: u64) { + if map.len() <= MAX_RATE_LIMIT_ENTRIES { + return; + } + + let target = (MAX_RATE_LIMIT_ENTRIES as f64 * EVICTION_TARGET_RATIO) as usize; + + // First pass: remove stale entries (not seen in over 1 hour) + let stale_keys: Vec = map + .iter() + .filter(|entry| now_ms.saturating_sub(entry.value().last_seen_ms) > EVICT_STALE_THRESHOLD_MS) + .map(|entry| entry.key().clone()) + .collect(); + for key in &stale_keys { + map.remove(key); + } + + if map.len() <= target { + return; + } + + // Second pass: evict oldest entries by last_seen_ms + let mut entries: Vec<(String, u64)> = map + .iter() + .map(|e| (e.key().clone(), e.value().last_seen_ms)) + .collect(); + entries.sort_by_key(|(_, ts)| *ts); + + let to_evict = map.len().saturating_sub(target); + for (key, _) in entries.into_iter().take(to_evict) { + map.remove(&key); + } +} + pub fn check( namespace: &str, key: &str, @@ -127,6 +164,8 @@ pub fn check( let _ = map.remove(&compound); } + maybe_evict(map, now_ms); + RateLimitDecision { allowed, limit: max, diff --git a/rsrc/src/response.rs b/rsrc/src/response.rs new file mode 100644 index 0000000..3bc62d2 --- /dev/null +++ b/rsrc/src/response.rs @@ -0,0 +1,202 @@ +//! Response builder utilities (plan item R2). +//! +//! Contains functions for constructing raw HTTP/1.1 response byte vectors, +//! patching connection headers, and injecting Set-Cookie headers into +//! already-built responses. 
+ +use memchr::memmem; + +use crate::http_utils::{count_digits, status_reason, write_u16, write_usize}; +use crate::compress; + +/// Build a complete HTTP/1.1 response with optional compression. +/// +/// Returns a fully formed response (status line + headers + body) as a byte vector. +pub fn build_response_bytes_fast( + /* @param status HTTP status code (e.g. 200, 404) */ + status: u16, + /* @param headers response headers as name/value pairs */ + headers: &[(Box, Box)], + /* @param body raw response body bytes */ + body: &[u8], + /* @param keep_alive whether to set connection: keep-alive */ + keep_alive: bool, + /* @param encoding client-accepted encoding from Accept-Encoding */ + encoding: compress::AcceptedEncoding, + /* @param compression_config per-content-type compression settings, if enabled */ + compression_config: Option<&compress::CompressionConfig>, +) -> Vec { + // ── Scan headers for content-type / content-encoding ── + let mut content_type: Option<&[u8]> = None; + let mut has_content_encoding = false; + for (name, value) in headers { + if name.eq_ignore_ascii_case("content-type") { + content_type = Some(value.as_bytes()); + } else if name.eq_ignore_ascii_case("content-encoding") { + has_content_encoding = true; + } + } + + // ── Attempt compression ── + let compressed = compression_config.and_then(|config| { + compress::should_compress(config, encoding, body.len(), content_type, has_content_encoding) + .and_then(|enc| compress::compress_body(body, enc, config, content_type).map(|data| (data, enc))) + }); + + let (final_body, applied_encoding): (&[u8], Option) = + match &compressed { + Some((data, enc)) => (data.as_slice(), Some(*enc)), + None => (body, None), + }; + + // ── Build HTTP response ── + let reason = status_reason(status); + let connection = if keep_alive { "keep-alive" } else { "close" }; + let body_len = final_body.len(); + + let mut total_size = + 9 + 3 + 1 + reason.len() + 2 + 16 + count_digits(body_len) + 2 + 12 + connection.len() + 
2; + + if applied_encoding.is_some() { + // "content-encoding: br\r\nvary: accept-encoding\r\n" worst case ~50 bytes + total_size += 50; + } + + for (name, value) in headers { + if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { + continue; + } + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } + total_size += name.len() + 2 + value.len() + 2; + } + + total_size += 2 + body_len; + + let mut output = Vec::with_capacity(total_size); + output.extend_from_slice(b"HTTP/1.1 "); + write_u16(&mut output, status); + output.push(b' '); + output.extend_from_slice(reason.as_bytes()); + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(b"content-length: "); + write_usize(&mut output, body_len); + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(b"connection: "); + output.extend_from_slice(connection.as_bytes()); + output.extend_from_slice(b"\r\n"); + + // Compression headers + if let Some(enc) = applied_encoding { + output.extend_from_slice(b"content-encoding: "); + output.extend_from_slice(compress::encoding_header_value(enc)); + output.extend_from_slice(b"\r\nvary: accept-encoding\r\n"); + } + + for (name, value) in headers { + if name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("connection") { + continue; + } + if name.contains('\r') + || name.contains('\n') + || value.contains('\r') + || value.contains('\n') + { + continue; + } + output.extend_from_slice(name.as_bytes()); + output.extend_from_slice(b": "); + output.extend_from_slice(value.as_bytes()); + output.extend_from_slice(b"\r\n"); + } + + output.extend_from_slice(b"\r\n"); + output.extend_from_slice(final_body); + output +} + +/// Patch the `connection:` header value in an already-built HTTP response. +/// Searches for `connection: keep-alive` or `connection: close` and swaps to the +/// requested variant. 
The two values differ in length (10 vs 5 bytes) so the +/// Vec may grow or shrink by a few bytes. +pub fn patch_connection_header( + /* @param response raw HTTP response bytes to patch */ + response: &[u8], + /* @param keep_alive true → set keep-alive, false → set close */ + keep_alive: bool, +) -> Vec { + let (find, replace) = if keep_alive { + (&b"connection: close\r\n"[..], &b"connection: keep-alive\r\n"[..]) + } else { + (&b"connection: keep-alive\r\n"[..], &b"connection: close\r\n"[..]) + }; + + if let Some(pos) = memmem::find(response, find) { + let mut out = Vec::with_capacity(response.len() + replace.len() - find.len()); + out.extend_from_slice(&response[..pos]); + out.extend_from_slice(replace); + out.extend_from_slice(&response[pos + find.len()..]); + out + } else { + // Header not found (shouldn't happen) — return unchanged clone + response.to_vec() + } +} + +/// Inject a Set-Cookie header into an already-built HTTP response. +/// Inserts the header just before the final \r\n\r\n (end of headers). +pub fn inject_set_cookie_header( + /* @param response mutable raw HTTP response bytes */ + response: &mut Vec, + /* @param cookie_value full Set-Cookie value string */ + cookie_value: &str, +) { + // Find the \r\n\r\n boundary between headers and body + if let Some(pos) = memmem::find(response, b"\r\n\r\n") { + let header_line = format!("set-cookie: {}\r\n", cookie_value); + let header_bytes = header_line.as_bytes(); + + // Insert before the final \r\n\r\n + let insert_pos = pos + 2; // after the last header's \r\n, before the blank \r\n + response.splice(insert_pos..insert_pos, header_bytes.iter().copied()); + + // Update Content-Length — it shouldn't change since we're adding headers, not body. + // Content-Length only measures the body, which is unchanged. 
+ } +} + +/// Build a simple error response without going through the JS bridge +pub fn build_error_response_bytes( + /* @param status HTTP status code */ + status: u16, + /* @param body JSON error body bytes */ + body: &[u8], + /* @param keep_alive whether to set connection: keep-alive */ + keep_alive: bool, +) -> Vec { + let reason = status_reason(status); + let connection = if keep_alive { "keep-alive" } else { "close" }; + let body_len = body.len(); + + let total_size = + 9 + 3 + 1 + reason.len() + 2 + 16 + count_digits(body_len) + 2 + 12 + connection.len() + 2 + 45 + 2 + body_len; + + let mut output = Vec::with_capacity(total_size); + output.extend_from_slice(b"HTTP/1.1 "); + write_u16(&mut output, status); + output.push(b' '); + output.extend_from_slice(reason.as_bytes()); + output.extend_from_slice(b"\r\ncontent-length: "); + write_usize(&mut output, body_len); + output.extend_from_slice(b"\r\nconnection: "); + output.extend_from_slice(connection.as_bytes()); + output.extend_from_slice(b"\r\ncontent-type: application/json; charset=utf-8\r\n\r\n"); + output.extend_from_slice(body); + + output +} diff --git a/rsrc/src/router.rs b/rsrc/src/router.rs index 4193e2f..a6e24b8 100644 --- a/rsrc/src/router.rs +++ b/rsrc/src/router.rs @@ -1,11 +1,21 @@ use anyhow::Result; +use arrayvec::ArrayVec; use bytes::Bytes; use std::collections::HashMap; use std::collections::HashSet; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; +use std::collections::hash_map::RandomState; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::sync::OnceLock; use std::time::Instant; +/// @S4/@S6: per-process randomly-keyed SipHash state. `RandomState::new()` +/// seeds from OS entropy, so hash values are unpredictable to attackers — +/// prevents cache-key collision attacks on vary-by params/headers. 
+fn keyed_state() -> &'static RandomState { + static STATE: OnceLock = OnceLock::new(); + STATE.get_or_init(RandomState::new) +} + use crate::analyzer::{ analyze_dynamic_fast_path, analyze_route, normalize_path, parse_segments, AnalysisResult, DynamicFastPathSpec, RouteSegment, @@ -44,7 +54,9 @@ pub struct ExactStaticRoute { pub struct MatchedRoute<'a, 'b> { pub handler_id: u32, pub cache_namespace: u64, - pub param_values: Vec<&'b str>, + /// Stack-allocated parameter values — avoids heap allocation per route match. + /// Capacity is MAX_STACK_SEGMENTS (16), matching the maximum segment depth. + pub param_values: ArrayVec<&'b str, MAX_STACK_SEGMENTS>, pub param_names: &'a [Box], pub header_keys: &'a [Box], pub full_headers: bool, @@ -53,6 +65,9 @@ pub struct MatchedRoute<'a, 'b> { pub needs_query: bool, pub fast_path: Option<&'a DynamicFastPathSpec>, pub cache_config: Option<&'a RouteCacheConfig>, + /// @DX-3.4: per-route body size limit — overrides the global MAX_BODY_BYTES + /// when set. None means use the global default. + pub max_body_bytes: Option, } #[derive(Clone)] @@ -83,6 +98,8 @@ struct DynamicRouteSpec { needs_query: bool, fast_path: Option, cache_config: Option, + /// @DX-3.4: per-route body size limit in bytes. + max_body_bytes: Option, } #[derive(Clone, Copy, Eq, Hash, PartialEq)] @@ -134,6 +151,9 @@ impl RadixNode { /// Insert a route into the radix tree fn insert(&mut self, segments: &[RouteSegment], spec: DynamicRouteSpec) { if segments.is_empty() { + if self.handler.is_some() { + log_warn!("duplicate route registered, overwriting previous handler"); + } self.handler = Some(spec); return; } @@ -171,11 +191,14 @@ impl RadixNode { } } - /// Match a request path against this radix tree — O(M) where M = segment count + /// Match a request path against this radix tree — O(M) where M = segment count. 
+ /// + /// /* @param segments — path segments split on '/' */ + /// /* @param param_values — stack-allocated accumulator for captured :param values */ fn match_path<'a, 'b>( &'a self, segments: &[&'b str], - param_values: &mut Vec<&'b str>, + param_values: &mut ArrayVec<&'b str, MAX_STACK_SEGMENTS>, ) -> Option<&'a DynamicRouteSpec> { if segments.is_empty() { return self.handler.as_ref(); @@ -320,7 +343,7 @@ impl Router { return Some(MatchedRoute { handler_id: route_spec.handler_id, cache_namespace: route_spec.cache_namespace, - param_values: Vec::new(), + param_values: ArrayVec::new(), param_names: route_spec.param_names.as_ref(), header_keys: route_spec.header_keys.as_ref(), full_headers: route_spec.full_headers, @@ -329,6 +352,7 @@ impl Router { needs_query: route_spec.needs_query, fast_path: route_spec.fast_path.as_ref(), cache_config: route_spec.cache_config.as_ref(), + max_body_bytes: route_spec.max_body_bytes, }); } @@ -336,7 +360,7 @@ impl Router { let tree = self.radix_trees.get(&method_key)?; let mut seg_buf = [""; MAX_STACK_SEGMENTS]; let seg_count = split_segments_stack(path, &mut seg_buf); - let mut param_values = Vec::with_capacity(4); + let mut param_values = ArrayVec::new(); let spec = if seg_count <= MAX_STACK_SEGMENTS { tree.match_path(&seg_buf[..seg_count], &mut param_values)? } else { @@ -356,6 +380,7 @@ impl Router { needs_query: spec.needs_query, fast_path: spec.fast_path.as_ref(), cache_config: spec.cache_config.as_ref(), + max_body_bytes: spec.max_body_bytes, }) } @@ -480,6 +505,7 @@ fn compile_dynamic_route_spec(route: &RouteInput, middlewares: &[MiddlewareInput needs_query: route.needs_query, fast_path: analyze_dynamic_fast_path(route, middlewares), cache_config, + max_body_bytes: route.max_body_bytes, } } @@ -642,8 +668,9 @@ fn build_exact_static_route_from_spec( } } +/// /* @param value — cache namespace string (e.g. 
"GET:/users/:id") */ fn hash_cache_namespace(value: &str) -> u64 { - let mut hasher = DefaultHasher::new(); + let mut hasher = keyed_state().build_hasher(); value.hash(&mut hasher); hasher.finish() } @@ -706,73 +733,7 @@ fn build_close_response(status: u16, headers: &HashMap, body: &[ build_response_bytes(status, headers, body, false) } -fn build_response_bytes( - status: u16, - headers: &HashMap, - body: &[u8], - keep_alive: bool, -) -> Vec { - let mut response = format!( - "HTTP/1.1 {} {}\r\ncontent-length: {}\r\nconnection: {}\r\n", - status, - status_reason(status), - body.len(), - if keep_alive { "keep-alive" } else { "close" } - ) - .into_bytes(); - - for (name, value) in headers { - // Security: skip headers with CRLF injection - if name.contains('\r') - || name.contains('\n') - || value.contains('\r') - || value.contains('\n') - { - continue; - } - response.extend_from_slice(name.as_bytes()); - response.extend_from_slice(b": "); - response.extend_from_slice(value.as_bytes()); - response.extend_from_slice(b"\r\n"); - } - - response.extend_from_slice(b"\r\n"); - response.extend_from_slice(body); - response -} - -// Todo: Are these expensive? if so remove them. 
- -fn status_reason(status: u16) -> &'static str { - match status { - 200 => "OK", - 201 => "Created", - 202 => "Accepted", - 204 => "No Content", - 301 => "Moved Permanently", - 302 => "Found", - 304 => "Not Modified", - 400 => "Bad Request", - 401 => "Unauthorized", - 403 => "Forbidden", - 404 => "Not Found", - 405 => "Method Not Allowed", - 408 => "Request Timeout", - 409 => "Conflict", - 411 => "Length Required", - 413 => "Payload Too Large", - 415 => "Unsupported Media Type", - 422 => "Unprocessable Entity", - 429 => "Too Many Requests", - 431 => "Request Header Fields Too Large", - 500 => "Internal Server Error", - 501 => "Not Implemented", - 502 => "Bad Gateway", - 503 => "Service Unavailable", - 504 => "Gateway Timeout", - _ => "Unknown", - } -} +use crate::http_utils::{build_response_bytes, status_reason}; // ─── Native Zero-Allocation LRU Cache ─── @@ -975,15 +936,23 @@ pub fn insert_cached_response(cache_namespace: u64, key: u64, entry: CacheEntry, }); } +/// Build a collision-resistant cache key from the route's vary configuration. +/// Uses per-process keyed SipHash (@S4/@S6) so attackers cannot predict collisions. 
+/// +/// /* @param config — route-level cache configuration with vary keys */ +/// /* @param headers — pre-parsed request header pairs */ +/// /* @param url — full request URL (for query param extraction) */ +/// /* @param param_names — route parameter names from the route spec */ +/// /* @param param_values — captured parameter values from the radix match */ pub fn interpolate_cache_key( config: &RouteCacheConfig, - parsed: &crate::ParsedRequest<'_>, + headers: &[(&str, &str)], url: &str, param_names: &[Box], param_values: &[&str], ) -> u64 { - let mut hasher = DefaultHasher::new(); - + let mut hasher = keyed_state().build_hasher(); + for vary_key in config.vary_keys.iter() { match vary_key { CacheVaryKey::QueryParam(name) => { @@ -1029,7 +998,7 @@ pub fn interpolate_cache_key( CacheVaryKey::Header(name) => { let name_str = name.as_ref(); let mut found = false; - for (h_name, h_val) in parsed.headers.iter() { + for (h_name, h_val) in headers.iter() { if h_name.eq_ignore_ascii_case(name_str) { name_str.hash(&mut hasher); h_val.hash(&mut hasher); @@ -1044,6 +1013,6 @@ pub fn interpolate_cache_key( } } } - + hasher.finish() } diff --git a/rsrc/src/session.rs b/rsrc/src/session.rs index c39353c..57db824 100644 --- a/rsrc/src/session.rs +++ b/rsrc/src/session.rs @@ -49,13 +49,17 @@ impl SameSite { SameSite::None => "None", } } +} + +impl std::str::FromStr for SameSite { + type Err = std::convert::Infallible; - pub fn from_str(s: &str) -> Self { - match s.to_ascii_lowercase().as_str() { + fn from_str(s: &str) -> Result { + Ok(match s.to_ascii_lowercase().as_str() { "strict" => SameSite::Strict, "none" => SameSite::None, _ => SameSite::Lax, - } + }) } } @@ -105,11 +109,6 @@ impl SessionEntry { self.last_accessed = now; self.expires_at = now + max_age; } - - /// Total size of all stored data in bytes. 
- fn data_size(&self) -> usize { - self.data.iter().map(|(k, v)| k.len() + v.len()).sum() - } } // ─── Session Shard ──────────────────────── @@ -131,16 +130,19 @@ impl SessionShard { pub struct SessionStore { shards: Box<[RwLock]>, config: SessionConfig, + max_per_shard: usize, } impl SessionStore { pub fn new(config: SessionConfig) -> Self { let shards: Vec> = (0..SHARD_COUNT).map(|_| RwLock::new(SessionShard::new())).collect(); + let max_per_shard = (config.max_sessions / SHARD_COUNT).max(1); Self { shards: shards.into_boxed_slice(), config, + max_per_shard, } } @@ -263,15 +265,43 @@ impl SessionStore { /// Create or update a session with the given data mutations. /// `mutations` contains only the changed keys. Existing keys not in /// `mutations` are preserved. + /// Returns `true` if the upsert succeeded, `false` if the projected data + /// size would exceed `max_data_size` (no modifications are applied). pub fn upsert( &self, id: &[u8; SESSION_ID_BYTES], mutations: HashMap>, deleted_keys: &[String], - ) { + ) -> bool { let shard_idx = self.shard_index(id); let mut shard = self.shards[shard_idx].write(); let max_age = Duration::from_secs(self.config.max_age_secs); + let is_new = !shard.map.contains_key(id); + + // --- D5: Project the data size BEFORE applying mutations --- + let current_data = if is_new { + None + } else { + Some(&shard.map[id].data) + }; + + let projected_size = Self::projected_data_size(current_data, &mutations, deleted_keys); + if projected_size > self.config.max_data_size { + return false; + } + + // --- S3: Enforce per-shard capacity with LRU eviction --- + if is_new && shard.map.len() >= self.max_per_shard { + // Evict the session with the oldest last_accessed timestamp + let lru_id = shard + .map + .iter() + .min_by_key(|(_, entry)| entry.last_accessed) + .map(|(id, _)| *id); + if let Some(lru_id) = lru_id { + shard.map.remove(&lru_id); + } + } let entry = shard .map @@ -288,14 +318,39 @@ impl SessionStore { entry.data.insert(key, 
value); } - // Enforce per-session data size limit - if entry.data_size() > self.config.max_data_size { - // Truncate by removing oldest entries until under limit. - // Simple strategy: just clear if over limit. - entry.data.clear(); + entry.touch(max_age); + true + } + + /// Calculate the projected data size after applying mutations and deletions + /// without actually modifying the entry. + fn projected_data_size( + current_data: Option<&HashMap>>, + mutations: &HashMap>, + deleted_keys: &[String], + ) -> usize { + let mut size: usize = 0; + + if let Some(data) = current_data { + for (k, v) in data { + // Skip keys that will be deleted + if deleted_keys.contains(k) { + continue; + } + // Skip keys that will be overwritten by mutations + if mutations.contains_key(k) { + continue; + } + size += k.len() + v.len(); + } } - entry.touch(max_age); + // Add sizes of all mutation entries + for (k, v) in mutations { + size += k.len() + v.len(); + } + + size } /// Destroy a session. diff --git a/src/audit-log.js b/src/audit-log.js new file mode 100644 index 0000000..fda993b --- /dev/null +++ b/src/audit-log.js @@ -0,0 +1,128 @@ +/** + * http-native audit logging middleware + * + * Emits structured security-relevant events for compliance (SOC2, PCI-DSS). + * Events are sent to a configurable sink — file, stream, or custom function. 
+ * + * Usage: + * import { auditLog } from "@http-native/core/audit-log"; + * app.use(auditLog({ sink: (event) => console.log(JSON.stringify(event)) })); + */ + +/** + * @param {Object} options + * @param {Function} options.sink - Receives each audit event object + * @param {string[]} [options.events] - Event types to capture (default: all) + * @param {boolean} [options.includeHeaders] - Include request headers (default: false) + * @param {string[]} [options.redactHeaders] - Header names to redact from logs + * @param {boolean} [options.includeBody] - Include request body (default: false) + */ +export function auditLog(options = {}) { + if (typeof options.sink !== "function") { + throw new Error("auditLog requires a sink function"); + } + + const sink = options.sink; + const allowedEvents = options.events ? new Set(options.events) : null; + const includeHeaders = options.includeHeaders ?? false; + const includeBody = options.includeBody ?? false; + const redactSet = options.redactHeaders + ? new Set(options.redactHeaders.map((h) => h.toLowerCase())) + : new Set(["authorization", "cookie", "set-cookie"]); + + function emit(event) { + if (allowedEvents && !allowedEvents.has(event.type)) return; + event.timestamp = new Date().toISOString(); + try { + sink(event); + } catch { + /* Audit sink failure must not crash the request */ + } + } + + return async function auditLogMiddleware(req, res, next) { + const start = performance.now(); + + /* Capture request metadata */ + const event = { + type: "http.request", + method: req.method, + path: req.path, + ip: req.ip, + requestId: req.id ?? undefined, + userId: undefined, + statusCode: undefined, + durationMs: undefined, + }; + + if (includeHeaders) { + const headers = { ...req.headers }; + for (const name of redactSet) { + if (headers[name]) headers[name] = "[REDACTED]"; + } + event.headers = headers; + } + + if (includeBody && req.body != null) { + event.body = + typeof req.body === "string" ? 
req.body : JSON.stringify(req.body); + } + + try { + await next(); + } catch (err) { + event.type = "http.error"; + event.error = err.message ?? String(err); + event.statusCode = err.status ?? 500; + event.durationMs = Math.round((performance.now() - start) * 100) / 100; + emit(event); + throw err; + } + + event.statusCode = res._state?.status ?? 200; + event.durationMs = Math.round((performance.now() - start) * 100) / 100; + event.userId = req.userId ?? req.session?.userId ?? undefined; + emit(event); + }; +} + +/** + * Create pre-defined audit event emitters for use outside middleware. + * + * @param {Function} sink - The same sink passed to auditLog() + */ +export function createAuditEmitter(sink) { + return { + emit(type, data = {}) { + sink({ type, timestamp: new Date().toISOString(), ...data }); + }, + authLogin: (userId, ip) => + sink({ + type: "auth.login", + userId, + ip, + timestamp: new Date().toISOString(), + }), + authLogout: (userId, ip) => + sink({ + type: "auth.logout", + userId, + ip, + timestamp: new Date().toISOString(), + }), + authFailed: (reason, ip) => + sink({ + type: "auth.failed", + reason, + ip, + timestamp: new Date().toISOString(), + }), + rateLimitExceeded: (ip, path) => + sink({ + type: "rate_limit.exceeded", + ip, + path, + timestamp: new Date().toISOString(), + }), + }; +} diff --git a/src/body-limit.js b/src/body-limit.js new file mode 100644 index 0000000..b8d13ad --- /dev/null +++ b/src/body-limit.js @@ -0,0 +1,89 @@ +/** + * http-native per-route request body size limit middleware. + * + * Provides a JS-side guard for body size enforcement. The primary + * enforcement is in Rust (via `maxBodyBytes` in the route manifest), + * but this middleware offers a composable fallback for routes where + * the Rust-side limit is not configured or for post-parse validation. 
+ * + * Usage: + * import { bodyLimit } from "@http-native/core/body-limit"; + * + * app.post("/upload", bodyLimit("50mb"), handler); + * app.post("/api/data", bodyLimit("1mb"), handler); + * app.post("/small", bodyLimit(1024), handler); // bytes + */ + +// ─── Size Parser ────────────────────────── + +const SIZE_UNITS = { + b: 1, + kb: 1024, + mb: 1024 * 1024, + gb: 1024 * 1024 * 1024, +}; + +/** + * Parse a human-readable size string into bytes. + * + * @param {string|number} input - e.g. "50mb", "1kb", 1024 + * @returns {number} Size in bytes + */ +function parseSize(input) { + if (typeof input === "number") { + if (!Number.isFinite(input) || input < 0) { + throw new TypeError("bodyLimit size must be a non-negative number"); + } + return Math.floor(input); + } + + if (typeof input !== "string") { + throw new TypeError("bodyLimit size must be a string or number"); + } + + const match = input.trim().match(/^(\d+(?:\.\d+)?)\s*(b|kb|mb|gb)?$/i); + if (!match) { + throw new TypeError(`Invalid body limit size: "${input}"`); + } + + const value = parseFloat(match[1]); + const unit = (match[2] || "b").toLowerCase(); + const multiplier = SIZE_UNITS[unit]; + + if (!multiplier) { + throw new TypeError(`Unknown size unit: "${match[2]}"`); + } + + return Math.floor(value * multiplier); +} + +// ─── Middleware Factory ─────────────────── + +/** + * Create a body size limit middleware. + * + * @param {string|number} limit - Maximum body size (e.g. 
"50mb", 1024) + * @returns {Function} Middleware function + */ +export function bodyLimit(limit) { + const maxBytes = parseSize(limit); + + return async function bodyLimitMiddleware(req, res, next) { + /* Check Content-Length header first (fast reject before body access) */ + const contentLength = req.header("content-length"); + if (contentLength !== undefined) { + const length = parseInt(contentLength, 10); + if (Number.isFinite(length) && length > maxBytes) { + return res.status(413).json({ error: "Payload Too Large" }); + } + } + + /* Also check actual body size — Content-Length can be absent or wrong */ + const body = req.body; + if (body != null && body.length > maxBytes) { + return res.status(413).json({ error: "Payload Too Large" }); + } + + await next(); + }; +} diff --git a/src/bridge.js b/src/bridge.js index c30d884..db04ce1 100644 --- a/src/bridge.js +++ b/src/bridge.js @@ -1,4 +1,5 @@ import { Buffer } from "node:buffer"; +import { parse as acornParse } from "acorn"; const textEncoder = new TextEncoder(); const textDecoder = new TextDecoder(); @@ -112,9 +113,198 @@ export function compileRouteShape(method, path) { * @returns {Object} Frozen access plan describing required request fields */ export function analyzeRequestAccess(source) { - const plan = createEmptyAccessPlan(); const normalizedSource = String(source ?? ""); + /* @R3: try AST-based analysis first — falls back to regex on parse failure */ + const astPlan = analyzeRequestAccessAST(normalizedSource); + if (astPlan) { + astPlan.jsonFastPath = detectJsonFastPath(normalizedSource); + return freezeAccessPlan(astPlan); + } + + return analyzeRequestAccessRegex(normalizedSource); +} + +/* @R3: AST-based access analysis using acorn. Walks MemberExpression nodes + * to find req.params.*, req.query.*, req.headers.*, req.method, etc. + * Correctly ignores string literals and comments that fool the regex path. 
*/ +function analyzeRequestAccessAST(source) { + let ast; + try { + /* wrap source as a program — handles arrow functions, function declarations, + * and function expressions returned by Function.prototype.toString() */ + ast = acornParse(`(${source})`, { + ecmaVersion: "latest", + sourceType: "module", + allowAwaitOutsideFunction: true, + allowReturnOutsideFunction: true, + }); + } catch { + return null; /* parse failed — caller falls back to regex */ + } + + const plan = createEmptyAccessPlan(); + const reqNames = findReqParamNames(ast); + if (reqNames.size === 0) { + return plan; /* handler doesn't declare a req parameter */ + } + + walkNode(ast, (node) => { + /* Handle destructuring: const { headers, query, params } = req */ + if (node.type === "VariableDeclarator" + && node.id?.type === "ObjectPattern" + && isReqIdentifier(node.init, reqNames)) { + /* Destructuring from req — mark all destructured properties as full access */ + for (const prop of node.id.properties) { + const key = prop.key?.name ?? 
prop.key?.value; + if (!key) { plan.fullParams = true; plan.fullQuery = true; plan.fullHeaders = true; plan.dispatchKind = "generic_fallback"; continue; } + switch (key) { + case "method": plan.method = true; break; + case "path": plan.path = true; break; + case "url": plan.url = true; break; + case "params": plan.fullParams = true; plan.dispatchKind = "generic_fallback"; break; + case "query": plan.fullQuery = true; plan.dispatchKind = "generic_fallback"; break; + case "headers": plan.fullHeaders = true; plan.dispatchKind = "generic_fallback"; break; + default: break; + } + } + return; + } + + if (node.type !== "MemberExpression") return; + if (!isReqIdentifier(node.object, reqNames)) return; + + const prop = memberPropName(node); + if (!prop) { + /* dynamic bracket access on req itself: req[variable] */ + plan.method = true; + plan.path = true; + plan.url = true; + plan.fullParams = true; + plan.fullQuery = true; + plan.fullHeaders = true; + plan.dispatchKind = "generic_fallback"; + return; + } + + switch (prop) { + case "method": plan.method = true; break; + case "path": plan.path = true; break; + case "url": plan.url = true; break; + case "params": markSubAccess(node, plan, "paramKeys", "fullParams", identity); break; + case "query": markSubAccess(node, plan, "queryKeys", "fullQuery", identity); break; + case "headers": markSubAccess(node, plan, "headerKeys", "fullHeaders", normalizeHeaderLookup); break; + case "header": + /* req.header("content-type") call pattern */ + markCallArg(node, plan, "headerKeys", "fullHeaders", normalizeHeaderLookup); + break; + case "body": break; /* body is always materialized when present */ + default: break; + } + }); + + return plan; +} + +/* find the parameter name(s) that represent the request object. 
+ * handles: (req, res) =>, function(request, response), destructuring */ +function findReqParamNames(ast) { + const names = new Set(); + walkNode(ast, (node) => { + if ( + node.type === "ArrowFunctionExpression" || + node.type === "FunctionExpression" || + node.type === "FunctionDeclaration" + ) { + const first = node.params?.[0]; + if (first?.type === "Identifier") { + names.add(first.name); + } else if (first?.type === "ObjectPattern") { + /* destructured req: ({ params, query }) => ... — full access */ + names.add("__destructured__"); + } + } + }); + return names; +} + +/* check if a node is an Identifier matching a known req param name */ +function isReqIdentifier(node, reqNames) { + return node?.type === "Identifier" && reqNames.has(node.name); +} + +/* extract a static property name from a MemberExpression, or null if dynamic */ +function memberPropName(node) { + if (!node.computed && node.property?.type === "Identifier") { + return node.property.name; + } + if (node.computed && node.property?.type === "Literal" && typeof node.property.value === "string") { + return node.property.value; + } + return null; +} + +/* given req.params (or .query/.headers) as a MemberExpression, check + * if the parent accesses a specific key or the whole object */ +function markSubAccess(parentNode, plan, keysField, fullField, transform) { + const grandparent = parentNode._parent; + if (!grandparent || grandparent.type !== "MemberExpression" || grandparent.object !== parentNode) { + /* bare `req.params` usage (passed as argument, spread, etc.) 
*/ + plan[fullField] = true; + plan.dispatchKind = "generic_fallback"; + return; + } + const key = memberPropName(grandparent); + if (key) { + plan[keysField].add(transform(key)); + } else { + /* dynamic bracket: req.params[variable] */ + plan[fullField] = true; + plan.dispatchKind = "generic_fallback"; + } +} + +/* given req.header as a MemberExpression, check if it's called with a + * static string argument: req.header("content-type") */ +function markCallArg(parentNode, plan, keysField, fullField, transform) { + const grandparent = parentNode._parent; + if (grandparent?.type === "CallExpression" && grandparent.callee === parentNode) { + const arg = grandparent.arguments?.[0]; + if (arg?.type === "Literal" && typeof arg.value === "string") { + plan[keysField].add(transform(arg.value)); + return; + } + } + /* dynamic call: req.header(variable) */ + plan[fullField] = true; + plan.dispatchKind = "generic_fallback"; +} + +/* minimal recursive AST walker that sets _parent on each child node */ +function walkNode(node, visitor) { + if (!node || typeof node !== "object") return; + visitor(node); + for (const key of Object.keys(node)) { + if (key === "_parent") continue; + const child = node[key]; + if (Array.isArray(child)) { + for (const item of child) { + if (item && typeof item.type === "string") { + item._parent = node; + walkNode(item, visitor); + } + } + } else if (child && typeof child.type === "string") { + child._parent = node; + walkNode(child, visitor); + } + } +} + +/* @R3: regex fallback — original analysis path for unparseable sources */ +function analyzeRequestAccessRegex(normalizedSource) { + const plan = createEmptyAccessPlan(); + plan.method = /\breq\.method\b/.test(normalizedSource); plan.path = /\breq\.path\b/.test(normalizedSource); plan.url = /\breq\.url\b/.test(normalizedSource); @@ -215,6 +405,35 @@ export function mergeRequestAccessPlans(plans) { const REQUEST_POOL_MAX = 512; const requestPool = []; +// @B2.5 — Extended pools for sub-objects 
(headers, params, query) +// These pools reuse null-prototype objects instead of creating fresh ones each request. +const HEADER_POOL_MAX = 256; +const PARAMS_POOL_MAX = 128; +const headerPool = []; +const paramsPool = []; + +export function acquireHeaderObject() { + return headerPool.pop() || Object.create(null); +} + +export function releaseHeaderObject(obj) { + if (headerPool.length >= HEADER_POOL_MAX) return; + const keys = Object.keys(obj); + for (let i = 0; i < keys.length; i++) delete obj[keys[i]]; + headerPool.push(obj); +} + +export function acquireParamsObject() { + return paramsPool.pop() || Object.create(null); +} + +export function releaseParamsObject(obj) { + if (paramsPool.length >= PARAMS_POOL_MAX) return; + const keys = Object.keys(obj); + for (let i = 0; i < keys.length; i++) delete obj[keys[i]]; + paramsPool.push(obj); +} + function acquireRequestObject() { return requestPool.pop() || null; } @@ -485,7 +704,45 @@ export function createRequestFactory( * @param {string} [mode="fallback"] - Serialization mode hint ("fallback"|"generic"|"specialized") * @returns {Function & { kind: string }} Serializer: (value) => Buffer */ -export function createJsonSerializer(mode = "fallback") { +/** + * Create a JSON serializer optimized for common response shapes (@B4.3). + * + * "fallback" mode uses JSON.stringify (safe for any shape). + * "shape" mode generates a hand-rolled serializer for objects with known keys, + * avoiding the overhead of JSON.stringify's generic traversal. + * + * @param {"fallback"|"shape"} mode + * @param {string[]} [keys] — known object keys (required for "shape" mode) + * @returns {(value: unknown) => Buffer} + */ +export function createJsonSerializer(mode = "fallback", keys) { + if (mode === "shape" && keys && keys.length > 0) { + /* Build a hand-rolled serializer for the known shape. + * For { id, name, email } produces: '{"id":' + JSON(id) + ',"name":' + JSON(name) + ... + * This avoids Object.keys() and generic property traversal. 
*/ + const prefix = keys.map((k, i) => (i === 0 ? '{"' : ',"') + escapeJsonString(k) + '":'); + const serializer = (value) => { + if (value === null || value === undefined) return Buffer.from("null"); + let out = ""; + for (let i = 0; i < keys.length; i++) { + out += prefix[i]; + const v = value[keys[i]]; + if (typeof v === "string") { + out += '"' + escapeJsonString(v) + '"'; + } else if (v === null || v === undefined) { + out += "null"; + } else { + out += JSON.stringify(v); + } + } + out += "}"; + return Buffer.from(out, "utf8"); + }; + serializer.kind = "shape"; + serializer.keys = keys; + return serializer; + } + const serializer = (value) => { const serialized = JSON.stringify(value); return Buffer.from(serialized, "utf8"); @@ -494,6 +751,22 @@ export function createJsonSerializer(mode = "fallback") { return serializer; } +/** Escape a string for safe embedding in JSON (@B4.3 helper). */ +function escapeJsonString(str) { + let escaped = ""; + for (let i = 0; i < str.length; i++) { + const ch = str.charCodeAt(i); + if (ch === 0x22) escaped += '\\"'; // " + else if (ch === 0x5c) escaped += "\\\\"; // \ + else if (ch === 0x0a) escaped += "\\n"; + else if (ch === 0x0d) escaped += "\\r"; + else if (ch === 0x09) escaped += "\\t"; + else if (ch < 0x20) escaped += "\\u" + ch.toString(16).padStart(4, "0"); + else escaped += str[i]; + } + return escaped; +} + // ─── Binary Protocol Codec ────────────── /** @@ -1086,7 +1359,7 @@ function readU16(bytes, offset) { if (offset + 2 > bytes.byteLength) { throw new Error("Request envelope truncated"); } - return bytes[offset] | (bytes[offset + 1] << 8); + return (bytes[offset] | (bytes[offset + 1] << 8)) >>> 0; } function readU32(bytes, offset) { diff --git a/src/circuit-breaker.js b/src/circuit-breaker.js new file mode 100644 index 0000000..a73be6d --- /dev/null +++ b/src/circuit-breaker.js @@ -0,0 +1,146 @@ +/** + * http-native circuit breaker + * + * Protects downstream services from cascading failures using the + * closed 
→ open → half-open state machine. + * + * Usage: + * import { circuitBreaker } from "@http-native/core/circuit-breaker"; + * + * const dbBreaker = circuitBreaker({ name: "database", threshold: 5, timeout: 30000 }); + * const result = await dbBreaker.call(() => db.query("SELECT 1")); + */ + +const STATE_CLOSED = "closed"; +const STATE_OPEN = "open"; +const STATE_HALF_OPEN = "half-open"; + +/** + * @param {Object} options + * @param {string} options.name - Circuit name (for logging/metrics) + * @param {number} [options.threshold=5] - Consecutive failures before opening + * @param {number} [options.timeout=30000] - ms in open state before half-open probe + * @param {number} [options.halfOpenMax=1] - Max concurrent requests in half-open state + * @param {Function} [options.onOpen] - Called when circuit opens + * @param {Function} [options.onHalfOpen] - Called when circuit transitions to half-open + * @param {Function} [options.onClose] - Called when circuit closes (healthy again) + * @param {Function} [options.isFailure] - Custom failure detection (default: any thrown error) + */ +export function circuitBreaker(options) { + const { + name, + threshold = 5, + timeout = 30000, + halfOpenMax = 1, + onOpen, + onHalfOpen, + onClose, + isFailure = () => true, + } = options; + + if (!name) throw new Error("Circuit breaker requires a name"); + + let state = STATE_CLOSED; + let failureCount = 0; + let lastFailureTime = 0; + let halfOpenActive = 0; + + function trip() { + if (state === STATE_OPEN) return; + state = STATE_OPEN; + lastFailureTime = Date.now(); + if (typeof onOpen === "function") { + try { onOpen(); } catch { /* lifecycle callback failure must not propagate */ } + } + } + + function reset() { + failureCount = 0; + halfOpenActive = 0; + if (state !== STATE_CLOSED) { + state = STATE_CLOSED; + if (typeof onClose === "function") { + try { onClose(); } catch { /* lifecycle callback failure must not propagate */ } + } + } + } + + function tryHalfOpen() { + if (state !== 
STATE_OPEN && state !== STATE_HALF_OPEN) return false; + if (Date.now() - lastFailureTime < timeout) return false; + if (halfOpenActive >= halfOpenMax) return false; + + if (state !== STATE_HALF_OPEN) { + state = STATE_HALF_OPEN; + if (typeof onHalfOpen === "function") { + try { onHalfOpen(); } catch { /* lifecycle callback failure must not propagate */ } + } + } + halfOpenActive++; + return true; + } + + return { + get name() { return name; }, + get state() { return state; }, + get failureCount() { return failureCount; }, + + /** + * Execute a function through the circuit breaker. + * + * @template T + * @param {() => Promise} fn + * @returns {Promise} + */ + async call(fn) { + /* Closed — allow the call */ + if (state === STATE_CLOSED) { + try { + const result = await fn(); + failureCount = 0; + return result; + } catch (err) { + if (isFailure(err)) { + failureCount++; + if (failureCount >= threshold) trip(); + } + throw err; + } + } + + /* Open or half-open — every probe must reserve a slot so halfOpenMax holds */ + if (state === STATE_OPEN || state === STATE_HALF_OPEN) { + if (!tryHalfOpen()) { + throw new CircuitOpenError(name); + } + } + + /* Half-open — probe the downstream */ + try { + const result = await fn(); + reset(); + return result; + } catch (err) { + if (halfOpenActive > 0) halfOpenActive--; + trip(); + throw err; + } + }, + + /** Manually reset the circuit to closed state */ + reset, + + /** Manually trip the circuit to open state */ + trip, + }; +} + +export class CircuitOpenError extends Error { + constructor(circuitName) { + super(`Circuit "${circuitName}" is open — request rejected`); + this.name = "CircuitOpenError"; + this.circuit = circuitName; + this.status = 503; + this.code = "CIRCUIT_OPEN"; + } +} diff --git a/src/cors.js b/src/cors.js index 3a36758..1001d09 100644 --- a/src/cors.js +++ b/src/cors.js @@ -70,7 +70,14 @@ export function cors(options = {}) { } if (effectiveOrigin !== "*") { - res.set("Vary", "Origin"); + const existing = res.get("vary"); + if (existing) { + if (!existing.toLowerCase().includes("origin")) { + res.set("vary", `${existing}, 
Origin`); + } + } else { + res.set("vary", "Origin"); + } } if (exposedHeadersString) { diff --git a/src/csrf.js b/src/csrf.js new file mode 100644 index 0000000..3a2be5b --- /dev/null +++ b/src/csrf.js @@ -0,0 +1,194 @@ +/** + * http-native CSRF protection middleware. + * + * Implements the double-submit cookie pattern: a random token is set in + * a cookie and must be echoed back in a request header or body field. + * Safe methods (GET, HEAD, OPTIONS) are skipped by default. + * + * Usage: + * import { csrf } from "@http-native/core/csrf"; + * + * app.use(csrf()); + * + * app.use(csrf({ + * cookie: { name: "_csrf", httpOnly: true, sameSite: "strict" }, + * ignoreMethods: ["GET", "HEAD", "OPTIONS"], + * tokenHeader: "x-csrf-token", + * tokenField: "_csrf", + * })); + */ + +import { randomBytes, timingSafeEqual } from "node:crypto"; + +// ─── Constants ──────────────────────────── + +const TOKEN_BYTES = 32; +const SAFE_METHODS = new Set(["GET", "HEAD", "OPTIONS"]); + +// ─── Token Generation ───────────────────── + +/** + * Generate a cryptographically random CSRF token. + * + * @returns {string} Hex-encoded 32-byte token + */ +function generateToken() { + return randomBytes(TOKEN_BYTES).toString("hex"); +} + +// ─── Cookie Helpers ─────────────────────── + +/** + * Build a Set-Cookie header value from options. + * + * @param {string} name + * @param {string} value + * @param {Object} options + * @returns {string} + */ +function buildSetCookie(name, value, options) { + let cookie = `${name}=${value}; Path=${options.path ?? "/"}`; + + if (options.httpOnly !== false) { + cookie += "; HttpOnly"; + } + if (options.secure) { + cookie += "; Secure"; + } + if (options.sameSite) { + cookie += `; SameSite=${options.sameSite}`; + } + if (typeof options.maxAge === "number") { + cookie += `; Max-Age=${options.maxAge}`; + } + + return cookie; +} + +/** + * Extract a cookie value by name from the Cookie header. 
+ * + * @param {string|undefined} cookieHeader + * @param {string} name + * @returns {string|undefined} + */ +function parseCookieValue(cookieHeader, name) { + if (!cookieHeader) return undefined; + + const prefix = `${name}=`; + const cookies = cookieHeader.split(";"); + + for (let i = 0; i < cookies.length; i++) { + const trimmed = cookies[i].trim(); + if (trimmed.startsWith(prefix)) { + return trimmed.slice(prefix.length); + } + } + + return undefined; +} + +// ─── Middleware Factory ─────────────────── + +/** + * Create a CSRF protection middleware. + * + * @param {Object} [options] + * @param {Object} [options.cookie] - Cookie options + * @param {string} [options.cookie.name="_csrf"] - Cookie name + * @param {boolean} [options.cookie.httpOnly=true] + * @param {string} [options.cookie.sameSite="strict"] + * @param {boolean} [options.cookie.secure=false] + * @param {string} [options.cookie.path="/"] + * @param {number} [options.cookie.maxAge] + * @param {string[]} [options.ignoreMethods] - Methods to skip (default: GET, HEAD, OPTIONS) + * @param {string} [options.tokenHeader] - Header to read token from (default: x-csrf-token) + * @param {string} [options.tokenField] - Body field to read token from (default: _csrf) + * @returns {Function} Middleware function + */ +export function csrf(options = {}) { + if (typeof options !== "object" || options === null) { + throw new TypeError("csrf(options) expects an object"); + } + + const cookieOpts = { + name: "_csrf", + httpOnly: true, + sameSite: "strict", + secure: false, + path: "/", + ...options.cookie, + }; + + const ignoreMethods = options.ignoreMethods + ? new Set(options.ignoreMethods.map((m) => m.toUpperCase())) + : SAFE_METHODS; + + const tokenHeader = (options.tokenHeader ?? "x-csrf-token").toLowerCase(); + const tokenField = options.tokenField ?? 
"_csrf"; + + return async function csrfMiddleware(req, res, next) { + const cookieHeader = req.header("cookie"); + let cookieToken = parseCookieValue(cookieHeader, cookieOpts.name); + + /* Ensure a CSRF cookie is always set — new visitors get a token on + * their first request (even safe methods) so forms can include it. */ + if (!cookieToken) { + cookieToken = generateToken(); + res.set( + "set-cookie", + buildSetCookie(cookieOpts.name, cookieToken, cookieOpts), + ); + } + + /* Expose token on req so templates/handlers can embed it in forms */ + req.csrfToken = cookieToken; + + /* Safe methods pass through — only state-changing requests need validation */ + if (ignoreMethods.has(req.method)) { + return next(); + } + + /* Validate: token must match in header or body field */ + const headerToken = req.header(tokenHeader); + + /* Extract body token: req.body is a Buffer, so parse JSON first. + * Only attempt JSON parse for content-types that could carry form data. */ + let bodyToken; + if (req.body && Buffer.isBuffer(req.body) && req.body.length > 0) { + try { + const parsed = JSON.parse(req.body.toString("utf8")); + if (parsed && typeof parsed === "object") { + bodyToken = parsed[tokenField]; + } + } catch { + /* Not JSON — ignore body token */ + } + } + + /* Constant-time comparison to prevent timing side-channel attacks */ + if (safeTokenEquals(headerToken, cookieToken) || safeTokenEquals(bodyToken, cookieToken)) { + return next(); + } + + return res.status(403).json({ + error: "Forbidden", + code: "CSRF_TOKEN_MISMATCH", + message: "CSRF token validation failed", + }); + }; +} + +/** + * Constant-time string comparison to prevent timing side-channel attacks. + * Returns false if either value is not a string or lengths differ. 
+ */ +function safeTokenEquals(a, b) { + if (typeof a !== "string" || typeof b !== "string") return false; + if (a.length !== b.length) return false; + try { + return timingSafeEqual(Buffer.from(a), Buffer.from(b)); + } catch { + return false; + } +} diff --git a/src/env.js b/src/env.js new file mode 100644 index 0000000..d65cf45 --- /dev/null +++ b/src/env.js @@ -0,0 +1,90 @@ +/** + * http-native environment configuration (DX-6.4) + * + * Type-safe .env loading with validation and coercion. Reads from + * process.env and an optional .env file (via Bun's built-in support). + * + * Usage: + * import { loadEnv } from "@http-native/core/env"; + * + * const env = loadEnv({ + * PORT: { type: "number", default: 3000 }, + * DATABASE_URL: { type: "string", required: true }, + * DEBUG: { type: "boolean", default: false }, + * }); + * console.log(env.PORT); // 3000 + */ + +/** + * @param {Record} schema + * @param {{ prefix?: string, envFile?: string }} [options] + * @returns {Record} + */ +export function loadEnv(schema, options = {}) { + const prefix = options.prefix ?? ""; + const env = Object.create(null); + const errors = []; + + for (const [key, spec] of Object.entries(schema)) { + const envKey = prefix + key; + const raw = process.env[envKey]; + + if (raw === undefined || raw === "") { + if (spec.required) { + errors.push(`Missing required env var: ${envKey}`); + continue; + } + env[key] = spec.default ?? undefined; + continue; + } + + try { + env[key] = coerce(raw, spec.type ?? 
"string", envKey); + } catch (err) { + errors.push(err.message); + } + } + + if (errors.length > 0) { + throw new EnvValidationError(errors); + } + + return Object.freeze(env); +} + +/** + * @param {string} value + * @param {"string"|"number"|"boolean"|"json"} type + * @param {string} key + */ +function coerce(value, type, key) { + switch (type) { + case "string": + return value; + case "number": { + const n = Number(value); + if (Number.isNaN(n)) throw new Error(`Env var ${key} must be a number, got "${value}"`); + return n; + } + case "boolean": + if (value === "true" || value === "1" || value === "yes") return true; + if (value === "false" || value === "0" || value === "no" || value === "") return false; + throw new Error(`Env var ${key} must be a boolean, got "${value}"`); + case "json": + try { + return JSON.parse(value); + } catch { + throw new Error(`Env var ${key} must be valid JSON, got "${value}"`); + } + default: + return value; + } +} + +export class EnvValidationError extends Error { + constructor(errors) { + super(`Environment validation failed:\n - ${errors.join("\n - ")}`); + this.name = "EnvValidationError"; + this.errors = errors; + } +} diff --git a/src/error.js b/src/error.js new file mode 100644 index 0000000..24fe445 --- /dev/null +++ b/src/error.js @@ -0,0 +1,222 @@ +/** + * http-native structured error types. + * + * Provides a typed error system with error codes for programmatic handling. + * HttpError instances are automatically serialized into structured JSON + * responses by the framework's error handler. 
+ * + * Usage: + * import { HttpError, BadRequest, NotFound, Unauthorized } from "@http-native/core/error"; + * + * throw new HttpError(422, "VALIDATION_FAILED", "Invalid input", { fields: errors }); + * throw new BadRequest("Missing required field: name"); + * throw new NotFound("User not found"); + * throw new Unauthorized("Invalid token"); + */ + +// ─── Base Error Class ───────────────────── + +/** + * Structured HTTP error with status code, machine-readable error code, + * human-readable message, and optional detail payload. + * + * @extends Error + */ +export class HttpError extends Error { + /** + * @param {number} status - HTTP status code + * @param {string} [code] - Machine-readable error code (e.g. "VALIDATION_FAILED") + * @param {string} [message] - Human-readable error message + * @param {Object} [details] - Additional error details + */ + constructor(status, code, message, details) { + /* Support (status, message) shorthand — shift arguments */ + if (typeof code === "string" && message === undefined && details === undefined) { + if (!code.includes("_") && code.length > 20) { + /* Looks like a message, not a code */ + super(code); + this.status = status; + this.code = defaultCodeForStatus(status); + this.details = undefined; + } else { + super(code); + this.status = status; + this.code = code; + this.details = undefined; + } + } else { + super(message ?? `HTTP ${status}`); + this.status = status; + this.code = code ?? defaultCodeForStatus(status); + this.details = details; + } + + this.name = "HttpError"; + } + + /** + * Serialize to a plain object suitable for JSON response. 
+ * + * @returns {{ status: number, code: string, message: string, details?: Object }} + */ + toJSON() { + const json = { + status: this.status, + code: this.code, + message: this.message, + }; + if (this.details !== undefined) { + json.details = this.details; + } + return json; + } +} + +// ─── Common Error Factories ─────────────── + +/** + * 400 Bad Request + * @param {string} [message] + * @param {Object} [details] + */ +export class BadRequest extends HttpError { + constructor(message, details) { + super(400, "BAD_REQUEST", message ?? "Bad Request", details); + this.name = "BadRequest"; + } +} + +/** + * 401 Unauthorized + * @param {string} [message] + * @param {Object} [details] + */ +export class Unauthorized extends HttpError { + constructor(message, details) { + super(401, "UNAUTHORIZED", message ?? "Unauthorized", details); + this.name = "Unauthorized"; + } +} + +/** + * 403 Forbidden + * @param {string} [message] + * @param {Object} [details] + */ +export class Forbidden extends HttpError { + constructor(message, details) { + super(403, "FORBIDDEN", message ?? "Forbidden", details); + this.name = "Forbidden"; + } +} + +/** + * 404 Not Found + * @param {string} [message] + * @param {Object} [details] + */ +export class NotFound extends HttpError { + constructor(message, details) { + super(404, "NOT_FOUND", message ?? "Not Found", details); + this.name = "NotFound"; + } +} + +/** + * 409 Conflict + * @param {string} [message] + * @param {Object} [details] + */ +export class Conflict extends HttpError { + constructor(message, details) { + super(409, "CONFLICT", message ?? "Conflict", details); + this.name = "Conflict"; + } +} + +/** + * 422 Unprocessable Entity + * @param {string} [message] + * @param {Object} [details] + */ +export class UnprocessableEntity extends HttpError { + constructor(message, details) { + super(422, "UNPROCESSABLE_ENTITY", message ?? 
"Unprocessable Entity", details); + this.name = "UnprocessableEntity"; + } +} + +/** + * 429 Too Many Requests + * @param {string} [message] + * @param {Object} [details] + */ +export class TooManyRequests extends HttpError { + constructor(message, details) { + super(429, "TOO_MANY_REQUESTS", message ?? "Too Many Requests", details); + this.name = "TooManyRequests"; + } +} + +/** + * 500 Internal Server Error + * @param {string} [message] + * @param {Object} [details] + */ +export class InternalServerError extends HttpError { + constructor(message, details) { + super(500, "INTERNAL_SERVER_ERROR", message ?? "Internal Server Error", details); + this.name = "InternalServerError"; + } +} + +/** + * 502 Bad Gateway + * @param {string} [message] + * @param {Object} [details] + */ +export class BadGateway extends HttpError { + constructor(message, details) { + super(502, "BAD_GATEWAY", message ?? "Bad Gateway", details); + this.name = "BadGateway"; + } +} + +/** + * 503 Service Unavailable + * @param {string} [message] + * @param {Object} [details] + */ +export class ServiceUnavailable extends HttpError { + constructor(message, details) { + super(503, "SERVICE_UNAVAILABLE", message ?? "Service Unavailable", details); + this.name = "ServiceUnavailable"; + } +} + +// ─── Helpers ────────────────────────────── + +/** + * Map common HTTP status codes to default error code strings. 
+ * + * @param {number} status + * @returns {string} + */ +function defaultCodeForStatus(status) { + switch (status) { + case 400: return "BAD_REQUEST"; + case 401: return "UNAUTHORIZED"; + case 403: return "FORBIDDEN"; + case 404: return "NOT_FOUND"; + case 405: return "METHOD_NOT_ALLOWED"; + case 409: return "CONFLICT"; + case 413: return "PAYLOAD_TOO_LARGE"; + case 422: return "UNPROCESSABLE_ENTITY"; + case 429: return "TOO_MANY_REQUESTS"; + case 500: return "INTERNAL_SERVER_ERROR"; + case 502: return "BAD_GATEWAY"; + case 503: return "SERVICE_UNAVAILABLE"; + case 504: return "GATEWAY_TIMEOUT"; + default: return `HTTP_${status}`; + } +} diff --git a/src/helmet.js b/src/helmet.js new file mode 100644 index 0000000..8e4288e --- /dev/null +++ b/src/helmet.js @@ -0,0 +1,221 @@ +/** + * http-native security headers middleware (helmet). + * + * Sets sensible default security headers on every response. Each header + * is individually configurable or can be disabled by setting it to `false`. + * Header values are pre-computed at middleware creation time — per-request + * cost is the minimum: a flat loop of `res.set()` calls. + * + * Usage: + * import { helmet } from "@http-native/core/helmet"; + * + * // Sane defaults + * app.use(helmet()); + * + * // Customize + * app.use(helmet({ + * hsts: { maxAge: 63072000, includeSubDomains: true, preload: true }, + * contentSecurityPolicy: { directives: { defaultSrc: ["'self'"], scriptSrc: ["'self'", "cdn.example.com"] } }, + * xFrameOptions: "SAMEORIGIN", + * permissionsPolicy: { camera: [], microphone: [], geolocation: ["self"] }, + * })); + * + * // Disable a specific header + * app.use(helmet({ xFrameOptions: false })); + */ + +// ─── Default Values ──────────────────── + +const DEFAULT_HSTS_MAX_AGE = 31536000; // 1 year + +// ─── Header Builders ─────────────────── + +/** + * Build HSTS header value from options. 
+ * + * @param {Object|boolean} hsts + * @returns {string|null} + */ +function buildHsts(hsts) { + if (hsts === false) return null; + + const config = typeof hsts === "object" && hsts !== null ? hsts : {}; + const maxAge = Number(config.maxAge ?? DEFAULT_HSTS_MAX_AGE); + if (!Number.isFinite(maxAge) || maxAge < 0) { + throw new TypeError("helmet hsts.maxAge must be a non-negative number"); + } + + let value = `max-age=${Math.floor(maxAge)}`; + if (config.includeSubDomains !== false) { + value += "; includeSubDomains"; + } + if (config.preload === true) { + value += "; preload"; + } + return value; +} + +/** + * Build CSP header value from directives. + * + * @param {Object|boolean} csp + * @returns {string|null} + */ +function buildContentSecurityPolicy(csp) { + if (csp === false || csp === undefined || csp === null) return null; + if (csp === true) return "default-src 'self'"; + + const directives = csp.directives ?? csp; + if (typeof directives !== "object" || directives === null) { + throw new TypeError("helmet contentSecurityPolicy.directives must be an object"); + } + + const parts = []; + for (const [key, sources] of Object.entries(directives)) { + /* camelCase → kebab-case: defaultSrc → default-src */ + const directive = key.replace(/[A-Z]/g, (c) => `-${c.toLowerCase()}`); + const value = Array.isArray(sources) ? sources.join(" ") : String(sources); + parts.push(`${directive} ${value}`); + } + + return parts.join("; "); +} + +/** + * Build Permissions-Policy header value. 
+ * + * @param {Object|boolean} policy + * @returns {string|null} + */ +function buildPermissionsPolicy(policy) { + if (policy === false || policy === undefined || policy === null) return null; + if (policy === true) { + return "camera=(), microphone=(), geolocation=()"; + } + + if (typeof policy !== "object") { + throw new TypeError("helmet permissionsPolicy must be an object"); + } + + const parts = []; + for (const [feature, allowlist] of Object.entries(policy)) { + const directive = feature.replace(/[A-Z]/g, (c) => `-${c.toLowerCase()}`); + if (Array.isArray(allowlist)) { + parts.push(`${directive}=(${allowlist.join(" ")})`); + } else { + parts.push(`${directive}=${String(allowlist)}`); + } + } + + return parts.join(", "); +} + +/** + * Build the full list of [name, value] header pairs from options. + * + * @param {Object} options + * @returns {Array<[string, string]>} + */ +function buildHeaderMap(options) { + const headers = []; + + /* X-Content-Type-Options: prevents MIME-type sniffing */ + if (options.xContentTypeOptions !== false) { + headers.push(["x-content-type-options", "nosniff"]); + } + + /* X-Frame-Options: clickjacking protection */ + if (options.xFrameOptions !== false) { + const value = typeof options.xFrameOptions === "string" + ? options.xFrameOptions.toUpperCase() + : "DENY"; + headers.push(["x-frame-options", value]); + } + + /* X-XSS-Protection: modern best practice is to disable it (CSP replaces it) */ + if (options.xXssProtection !== false) { + headers.push(["x-xss-protection", "0"]); + } + + /* Referrer-Policy */ + if (options.referrerPolicy !== false) { + const value = typeof options.referrerPolicy === "string" + ? 
options.referrerPolicy + : "strict-origin-when-cross-origin"; + headers.push(["referrer-policy", value]); + } + + /* Strict-Transport-Security */ + const hstsValue = buildHsts(options.hsts); + if (hstsValue !== null) { + headers.push(["strict-transport-security", hstsValue]); + } + + /* Content-Security-Policy */ + const cspValue = buildContentSecurityPolicy(options.contentSecurityPolicy); + if (cspValue !== null) { + headers.push(["content-security-policy", cspValue]); + } + + /* Cross-Origin-Opener-Policy */ + if (options.crossOriginOpenerPolicy !== false) { + const value = typeof options.crossOriginOpenerPolicy === "string" + ? options.crossOriginOpenerPolicy + : "same-origin"; + headers.push(["cross-origin-opener-policy", value]); + } + + /* Cross-Origin-Resource-Policy */ + if (options.crossOriginResourcePolicy !== false) { + const value = typeof options.crossOriginResourcePolicy === "string" + ? options.crossOriginResourcePolicy + : "same-origin"; + headers.push(["cross-origin-resource-policy", value]); + } + + /* Permissions-Policy */ + const ppValue = buildPermissionsPolicy(options.permissionsPolicy); + if (ppValue !== null) { + headers.push(["permissions-policy", ppValue]); + } + + /* X-DNS-Prefetch-Control */ + if (options.xDnsPrefetchControl !== false) { + const value = options.xDnsPrefetchControl === "on" ? "on" : "off"; + headers.push(["x-dns-prefetch-control", value]); + } + + /* X-Permitted-Cross-Domain-Policies */ + if (options.xPermittedCrossDomainPolicies !== false) { + const value = typeof options.xPermittedCrossDomainPolicies === "string" + ? options.xPermittedCrossDomainPolicies + : "none"; + headers.push(["x-permitted-cross-domain-policies", value]); + } + + return headers; +} + +// ─── Middleware Factory ──────────────── + +/** + * Create a security headers middleware. 
+ * + * @param {Object} [options] - Per-header configuration; set any to `false` to disable + * @returns {Function} Middleware function + */ +export function helmet(options = {}) { + if (typeof options !== "object" || options === null) { + throw new TypeError("helmet(options) expects an object"); + } + + /* Pre-compute all header pairs at startup — zero allocation per request */ + const headers = buildHeaderMap(options); + + return async function helmetMiddleware(req, res, next) { + for (let i = 0; i < headers.length; i++) { + res.set(headers[i][0], headers[i][1]); + } + await next(); + }; +} diff --git a/src/http-server.config.js b/src/http-server.config.js index 5d84ceb..4ae9445 100644 --- a/src/http-server.config.js +++ b/src/http-server.config.js @@ -95,7 +95,7 @@ export function normalizeHttpServerConfig(overrides = {}) { overrides.headerTransferEncodingPrefix ?? httpServerConfig.headerTransferEncodingPrefix, ), - tls: normalizeTlsConfig(overrides.tls ?? httpServerConfig.tls), + tls: normalizeTlsConfig("tls" in overrides ? 
overrides.tls : httpServerConfig.tls), }; } diff --git a/src/index.d.ts b/src/index.d.ts index 5c7cd09..2576fad 100644 --- a/src/index.d.ts +++ b/src/index.d.ts @@ -9,6 +9,12 @@ export interface Request { /** HTTP method (GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD) */ readonly method: string; + /** Request ID (available when requestId middleware is used) */ + id?: string; + + /** CSRF token (available when csrf middleware is used) */ + csrfToken?: string; + /** URL path without query string */ readonly path: string; @@ -47,11 +53,35 @@ export interface Request { /** Current session ID (available when session middleware is used) */ readonly sessionId?: string; + + /** Trace ID (available when otel middleware is used) */ + traceId?: string; + + /** Span ID (available when otel middleware is used) */ + spanId?: string; + + /** Parsed multipart fields (available when multipart middleware is used) */ + fields?: Record; + + /** Parsed multipart files (available when multipart middleware is used) */ + files?: MultipartFile[]; + + /** Validated request body (available when validate middleware is used) */ + validatedBody?: unknown; + + /** Validated query params (available when validate middleware is used) */ + validatedQuery?: unknown; + + /** Validated route params (available when validate middleware is used) */ + validatedParams?: unknown; + + /** Decorator values attached via app.decorate() */ + [key: string]: unknown; } export interface Session { - /** Get a session value by key */ - get(key: string): T | undefined; + /** Get a session value by key (sync for MemoryStore, async for RedisStore) */ + get(key: string): T | undefined | Promise; /** Set a session value */ set(key: string, value: unknown): void; @@ -59,8 +89,8 @@ export interface Session { /** Delete a session key */ delete(key: string): void; - /** Check if a key exists */ - has(key: string): boolean; + /** Check if a key exists (sync for MemoryStore, async for RedisStore) */ + has(key: string): boolean | 
Promise; /** Destroy the entire session */ destroy(): void; @@ -84,6 +114,20 @@ export interface SessionOptions { sameSite?: "strict" | "lax" | "none"; /** Cookie path (default "/") */ path?: string; + /** Maximum concurrent sessions in Rust store (default 100000) */ + maxSessions?: number; + /** Maximum session data size in bytes (default 4096) */ + maxDataSize?: number; + /** Custom session store (default: MemoryStore backed by Rust) */ + store?: SessionStore; +} + +export interface SessionStore { + get(sessionId: string, key: string): unknown | Promise; + set(sessionId: string, key: string, value: unknown): void | Promise; + delete(sessionId: string, key: string): void | Promise; + destroy(sessionId: string): void | Promise; + getAll(sessionId: string): Record | null | Promise | null>; } export interface Response { @@ -129,6 +173,31 @@ export interface Response { * @param options.maxEntries - Max LRU entries per route (default 256) */ ncache(data: unknown, ttl: number, options?: { maxEntries?: number }): Response; + + /** + * Redirect the client to a different URL. + * + * @param url - Target URL + * @param status - HTTP redirect status code (default 302) + */ + redirect(url: string, status?: number): Response; + + /** + * Start a chunked streaming response. Returns a StreamWriter for writing + * chunks incrementally, or null if the response is already finished. 
+ */ + stream(options?: { contentType?: string }): StreamWriter | null; +} + +export interface StreamWriter { + /** Write a chunk to the stream (string, Buffer, Uint8Array, or object → JSON) */ + write(data: string | Buffer | Uint8Array | object): boolean; + + /** End the stream, optionally writing a final chunk */ + end(finalChunk?: string | Buffer | Uint8Array | object): boolean; + + /** The internal stream ID */ + readonly id: number; } export type NextFunction = () => Promise; @@ -260,6 +329,12 @@ export interface ServerHandle { /** Gracefully close the server */ close(): Promise; + + /** + * Graceful shutdown — stop accepting new connections, drain in-flight + * requests up to the timeout, then force-stop workers. + */ + shutdown(options?: ShutdownOptions): Promise; } export interface ListenHandle extends Promise { @@ -277,6 +352,20 @@ export interface ListenHandle extends Promise { /** Enable runtime hot reload respawn for self-starting apps */ hot(options?: boolean | HotReloadOptions): ListenHandle; + + /** + * Enable HTTP/3 (QUIC) support. Requires TLS. + * Binds a UDP listener on the same port as TCP. + * Alt-Svc headers are automatically injected in HTTP/1.1 and HTTP/2 responses. + */ + http3(options?: boolean | Http3Options): ListenHandle; +} + +export interface Http3Options { + /** Enable HTTP/3 (default: true) */ + enabled?: boolean; + /** Max idle timeout in ms before closing QUIC connections (default: 30000) */ + maxIdleTimeout?: number; } export interface OptimizationSnapshot { @@ -340,6 +429,18 @@ export interface AppConfig { // ─── Application ──────────────────────── +export interface RouteOptions { + /** Native cache configuration for this route */ + cache?: { + /** Cache TTL in seconds */ + ttl?: number; + /** Fields to vary cache key by (e.g. 
"query.page", "header.accept") */ + varyBy?: string[]; + /** Max LRU entries (default 256) */ + maxEntries?: number; + }; +} + export interface Application { /** Register path-scoped or global middleware */ use(middleware: Middleware): Application; @@ -348,44 +449,146 @@ export interface Application { /** Register a global error / not-found handler */ error(handler: ErrorHandler): Application; - /** - * Deprecate soon - * Not very needed rn - * Register a global error handler */ + /** Register a global error handler */ onError(handler: ErrorHandler): Application; + /** + * Create route registrars scoped to a fixed HTTP status code. + * @example app.status(201).post("/users", handler) + */ + status(code: number): Record<"get" | "post" | "put" | "delete" | "patch" | "options" | "head" | "all", (path: string, handler: RouteHandler) => Application>; + + /** + * Register a 404 error handler — called when no route matches. + * @example app[404]((req, res) => res.status(404).json({ error: "Not found" })) + */ + 404(handler: RouteHandler): Application; + + /** + * Register a 401 error handler — called when UNAUTHORIZED error is thrown. 
+ * @example app[401]((req, res) => res.status(401).json({ error: "Unauthorized" })) + */ + 401(handler: RouteHandler): Application; + /** Group routes under a shared path prefix */ group(pathPrefix: string, registerGroup: (group: Application) => void): Application; /** Register a GET route handler */ get(path: string, handler: RouteHandler): Application; + get(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register a POST route handler */ post(path: string, handler: RouteHandler): Application; + post(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register a PUT route handler */ put(path: string, handler: RouteHandler): Application; + put(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register a DELETE route handler */ delete(path: string, handler: RouteHandler): Application; + delete(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register a PATCH route handler */ patch(path: string, handler: RouteHandler): Application; + patch(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register an OPTIONS route handler */ options(path: string, handler: RouteHandler): Application; + options(path: string, options: RouteOptions, handler: RouteHandler): Application; + + /** Register a HEAD route handler */ + head(path: string, handler: RouteHandler): Application; + head(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register a handler for all HTTP methods */ all(path: string, handler: RouteHandler): Application; + all(path: string, options: RouteOptions, handler: RouteHandler): Application; /** Register an exact GET HTML route served from the native static fast path */ static(path: string, html: string, options?: HtmlResponseOptions): Application; + /** + * Register a health check endpoint served from the Rust static fast path. + * Zero JS dispatch overhead — response is pre-built at startup. 
+ */ + health(path: string, options?: HealthCheckOptions): Application; + + /** Register a WebSocket upgrade handler for a path */ + ws(path: string, handlers: WebSocketHandlers): Application; + /** Configure first-class app reload behavior for dev runtimes */ reload(options?: ReloadOptions): Application; /** Start the server and listen for connections */ listen(options?: ListenOptions): ListenHandle; + + /** + * Attach a named property/service to every request object. + * Decorators are set once at startup — zero per-request overhead. + */ + decorate(name: string, value: T): Application; + + /** + * Register a lifecycle hook called at the specified event. + */ + addHook( + event: "onRequest" | "onRoute" | "onResponse" | "onError" | "onClose", + fn: (...args: unknown[]) => void | Promise, + ): Application; + + /** + * Install a plugin. Plugins can register routes, middleware, hooks, and decorators. + */ + register(plugin: Plugin, options?: Record): Application; +} + +// ─── WebSocket Types ─────────────────── + +export interface WebSocketHandlers { + /** Called when a new WebSocket connection opens */ + open?(ws: WebSocketConnection): void | Promise; + + /** Called when a message is received from the client */ + message?(ws: WebSocketConnection, data: string | Buffer): void | Promise; + + /** Called when the WebSocket connection closes */ + close?(ws: WebSocketConnection, code?: number, reason?: string): void | Promise; + + /** Maximum payload length in bytes (default: 65536) */ + maxPayloadLength?: number; + + /** Backpressure strategy: "drop" discards, "buffer" queues, "block" awaits (default: "drop") */ + backpressure?: "drop" | "buffer" | "block"; + + /** Idle timeout in seconds — connections with no activity are closed (default: 120) */ + idleTimeout?: number; + + /** Enable per-message deflate compression (RFC 7692) (default: false) */ + perMessageDeflate?: boolean; +} + +export interface WebSocketConnection { + /** Send a text or binary message to the client 
*/ + send(data: string | Buffer | Uint8Array): void; + + /** Close the WebSocket connection */ + close(code?: number, reason?: string): void; + + /** Subscribe this connection to a pub/sub topic */ + subscribe(topic: string): void; + + /** Unsubscribe this connection from a pub/sub topic */ + unsubscribe(topic: string): void; + + /** Publish a message to all subscribers of a topic */ + publish(topic: string, data: string | Buffer | Uint8Array): number; + + /** Get the number of subscribers for a topic */ + subscriberCount(topic: string): number; + + /** The internal connection ID */ + readonly id: number; } /** Create a new http-native application */ @@ -522,3 +725,489 @@ export function createNativeRateLimiter(options?: NativeRateLimiterOptions): Nat /** Create a high-level middleware wrapper around the native limiter handle. */ export function rateLimit(options: RateLimitOptions): Middleware; + +// ─── Health Check Types ──────────────── + +export interface HealthCheckOptions { + /** JSON body to return (default: { status: "ok" }) */ + body?: Record; + /** HTTP status code (default: 200) */ + status?: number; + /** Additional response headers */ + headers?: Record; +} + +// ─── Security Headers Types ──────────── + +export interface HstsOptions { + /** Max age in seconds (default: 31536000 = 1 year) */ + maxAge?: number; + /** Include subdomains (default: true) */ + includeSubDomains?: boolean; + /** Add preload flag */ + preload?: boolean; +} + +export interface ContentSecurityPolicyOptions { + /** CSP directives as camelCase keys → source arrays */ + directives?: Record; +} + +export interface HelmetOptions { + /** HSTS header (default: enabled with 1-year max-age). Set false to disable. */ + hsts?: HstsOptions | boolean; + /** Content-Security-Policy. Set false or omit to disable. */ + contentSecurityPolicy?: ContentSecurityPolicyOptions | boolean; + /** X-Frame-Options value (default: "DENY"). Set false to disable. 
*/ + xFrameOptions?: string | boolean; + /** X-Content-Type-Options (default: "nosniff"). Set false to disable. */ + xContentTypeOptions?: boolean; + /** X-XSS-Protection (default: "0"). Set false to disable. */ + xXssProtection?: boolean; + /** Referrer-Policy (default: "strict-origin-when-cross-origin"). Set false to disable. */ + referrerPolicy?: string | boolean; + /** Cross-Origin-Opener-Policy (default: "same-origin"). Set false to disable. */ + crossOriginOpenerPolicy?: string | boolean; + /** Cross-Origin-Resource-Policy (default: "same-origin"). Set false to disable. */ + crossOriginResourcePolicy?: string | boolean; + /** Permissions-Policy. Set false or omit to disable. */ + permissionsPolicy?: Record | boolean; + /** X-DNS-Prefetch-Control (default: "off"). Set false to disable. */ + xDnsPrefetchControl?: string | boolean; + /** X-Permitted-Cross-Domain-Policies (default: "none"). Set false to disable. */ + xPermittedCrossDomainPolicies?: string | boolean; +} + +/** Create a security headers middleware with sane defaults */ +export function helmet(options?: HelmetOptions): Middleware; + +// ─── Request ID Types ────────────────── + +export interface RequestIdOptions { + /** Incoming request header to read (default: "x-request-id") */ + header?: string; + /** Response header to set (default: same as header; false to disable) */ + responseHeader?: string | false; + /** Custom ID generator function (default: crypto.randomUUID) */ + generate?: () => string; +} + +/** Create a request ID middleware for distributed tracing correlation */ +export function requestId(options?: RequestIdOptions): Middleware; + +// ─── Body Limit Types ──────────────────── + +/** + * Create a per-route body size limit middleware. + * Accepts a human-readable size string ("50mb", "1kb") or byte count. 
+ */ +export function bodyLimit(limit: string | number): Middleware; + +// ─── CSRF Types ────────────────────────── + +export interface CsrfCookieOptions { + /** Cookie name (default: "_csrf") */ + name?: string; + /** HttpOnly flag (default: true) */ + httpOnly?: boolean; + /** SameSite policy (default: "strict") */ + sameSite?: "strict" | "lax" | "none"; + /** Secure flag (default: false) */ + secure?: boolean; + /** Cookie path (default: "/") */ + path?: string; + /** Cookie max age in seconds */ + maxAge?: number; +} + +export interface CsrfOptions { + /** Cookie configuration for the CSRF token */ + cookie?: CsrfCookieOptions; + /** HTTP methods to skip CSRF validation (default: GET, HEAD, OPTIONS) */ + ignoreMethods?: string[]; + /** Request header to read the CSRF token from (default: "x-csrf-token") */ + tokenHeader?: string; + /** Request body field to read the CSRF token from (default: "_csrf") */ + tokenField?: string; +} + +/** Create a CSRF protection middleware using double-submit cookie pattern */ +export function csrf(options?: CsrfOptions): Middleware; + +// ─── IP Filter Types ───────────────────── + +export interface IpFilterOptions { + /** CIDR ranges to allow (e.g. ["10.0.0.0/8", "192.168.0.0/16"]) */ + allow?: string[]; + /** CIDR ranges to deny (e.g. ["0.0.0.0/0"]) */ + deny?: string[]; + /** Use X-Forwarded-For header for client IP (default: false) */ + trustProxy?: boolean; + /** Custom handler for denied requests */ + onDenied?: (req: Request, res: Response) => void | Promise; +} + +/** Create an IP allowlist/denylist middleware with CIDR range matching */ +export function ipFilter(options: IpFilterOptions): Middleware; + +// ─── Error Types ───────────────────────── + +export class HttpError extends Error { + /** HTTP status code */ + readonly status: number; + /** Machine-readable error code (e.g. 
"VALIDATION_FAILED") */ + readonly code: string; + /** Additional error details */ + readonly details?: unknown; + + constructor(status: number, code?: string, message?: string, details?: unknown); + + /** Serialize to JSON-safe plain object */ + toJSON(): { status: number; code: string; message: string; details?: unknown }; +} + +export class BadRequest extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class Unauthorized extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class Forbidden extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class NotFound extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class Conflict extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class UnprocessableEntity extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class TooManyRequests extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class InternalServerError extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class BadGateway extends HttpError { + constructor(message?: string, details?: unknown); +} + +export class ServiceUnavailable extends HttpError { + constructor(message?: string, details?: unknown); +} + +// ─── Shutdown Types ────────────────────── + +export interface ShutdownOptions { + /** Maximum drain time in milliseconds (default: 30000) */ + timeout?: number; + /** Force kill after this many ms (default: timeout + 5000) */ + forceAfter?: number; + /** Called when draining starts — mark service as unhealthy for load balancers */ + onDraining?: () => void; + /** Called after all in-flight requests finish (before close) */ + onDrained?: () => void | Promise; +} + +export interface ShutdownResult { + /** Whether all in-flight requests completed before the timeout */ + drained: boolean; + /** Number of 
requests still in-flight when the timeout expired */ + remaining: number; +} + +// ─── Plugin Types ────────────────────── + +export interface Plugin { + /** Unique plugin name */ + name: string; + /** Semver version string */ + version?: string; + /** Called once when the plugin is registered */ + setup(app: Application, options?: Record): void; + /** Called during server shutdown */ + teardown?(): void | Promise; +} + +/** + * Define a plugin with the standard interface. + */ +export function definePlugin(definition: Plugin): Plugin; + +// ─── Audit Log Types ─────────────────── + +export interface AuditEvent { + type: string; + timestamp: string; + method?: string; + path?: string; + ip?: string; + requestId?: string; + userId?: string; + statusCode?: number; + durationMs?: number; + error?: string; + headers?: Record; + body?: string; + [key: string]: unknown; +} + +export interface AuditLogOptions { + /** Receives each audit event — write to file, stream, SIEM, etc. */ + sink: (event: AuditEvent) => void; + /** Event types to capture (default: all) */ + events?: string[]; + /** Include request headers in events (default: false) */ + includeHeaders?: boolean; + /** Header names to redact (default: authorization, cookie, set-cookie) */ + redactHeaders?: string[]; + /** Include request body in events (default: false) */ + includeBody?: boolean; +} + +/** Create an audit logging middleware for compliance events */ +export function auditLog(options: AuditLogOptions): Middleware; + +export interface AuditEmitter { + emit(type: string, data?: Record): void; + authLogin(userId: string, ip: string): void; + authLogout(userId: string, ip: string): void; + authFailed(reason: string, ip: string): void; + rateLimitExceeded(ip: string, path: string): void; +} + +/** Create pre-defined audit event emitters for use outside middleware */ +export function createAuditEmitter( + sink: (event: AuditEvent) => void, +): AuditEmitter; + +// ─── Test Client Types ───────────────── + 
+export interface TestResponse { + status: number; + headers: Record; + ok: boolean; + json(): Promise; + text(): Promise; + raw: Response; +} + +export interface TestWebSocket { + send(data: string | Buffer | Uint8Array): void; + next(): Promise; + close(): void; + raw: WebSocket; +} + +export interface TestClient { + baseUrl: string; + request(path: string, init?: RequestInit & { json?: unknown }): Promise; + get(path: string, init?: RequestInit): Promise; + post(path: string, init?: RequestInit & { json?: unknown }): Promise; + put(path: string, init?: RequestInit & { json?: unknown }): Promise; + patch(path: string, init?: RequestInit & { json?: unknown }): Promise; + delete(path: string, init?: RequestInit): Promise; + ws(path: string): Promise; + close(): Promise; +} + +/** Create a test client that starts the app on an ephemeral port */ +export function testClient( + app: Application, + options?: { port?: number; host?: string }, +): Promise; + +// ─── Circuit Breaker Types ───────────── + +export interface CircuitBreakerOptions { + /** Unique circuit name (for logging/metrics) */ + name: string; + /** Consecutive failures before opening (default: 5) */ + threshold?: number; + /** ms in open state before half-open probe (default: 30000) */ + timeout?: number; + /** Max concurrent requests in half-open state (default: 1) */ + halfOpenMax?: number; + /** Called when circuit opens */ + onOpen?: () => void; + /** Called when circuit transitions to half-open */ + onHalfOpen?: () => void; + /** Called when circuit closes (healthy again) */ + onClose?: () => void; + /** Custom failure detection (default: any thrown error) */ + isFailure?: (err: unknown) => boolean; +} + +export interface CircuitBreaker { + readonly name: string; + readonly state: "closed" | "open" | "half-open"; + readonly failureCount: number; + call(fn: () => Promise): Promise; + reset(): void; + trip(): void; +} + +export function circuitBreaker(options: CircuitBreakerOptions): CircuitBreaker; + 
+export class CircuitOpenError extends Error { + circuit: string; + status: 503; + code: "CIRCUIT_OPEN"; +} + +// ─── Environment Config Types ────────── + +export interface EnvVarSpec { + type?: "string" | "number" | "boolean" | "json"; + required?: boolean; + default?: unknown; +} + +export interface LoadEnvOptions { + /** Prefix all env var names (e.g. "APP_") */ + prefix?: string; +} + +type EnvVarType = + S["type"] extends "number" ? number : + S["type"] extends "boolean" ? boolean : + S["type"] extends "json" ? unknown : + string; + +/** Load and validate environment variables */ +export function loadEnv>( + schema: T, + options?: LoadEnvOptions, +): { [K in keyof T]: EnvVarType }; + +export class EnvValidationError extends Error { + errors: string[]; +} + +// ─── OpenAPI Types ───────────────────── + +export interface OpenApiOptions { + /** OpenAPI info object */ + info?: { title: string; version: string; description?: string }; + /** Server URLs */ + servers?: { url: string; description?: string }[]; + /** Path to serve the raw JSON spec (default: "/openapi.json") */ + json?: string; + /** Path to serve Swagger UI (optional) */ + ui?: string; + /** Extra OpenAPI components to merge */ + components?: Record; + /** Top-level tag definitions */ + tags?: (string | { name: string; description?: string })[]; +} + +export function openapi(options?: OpenApiOptions): Middleware; +export function generateSpec( + appMeta: { routes?: { path: string; method: string; meta?: Record }[] }, + options?: OpenApiOptions, +): Record; + +// ─── Multipart Types ─────────────────── + +export interface MultipartOptions { + /** Max size per file (e.g. "10mb", 1048576). Default: 10MB */ + maxFileSize?: string | number; + /** Max number of files. Default: 10 */ + maxFiles?: number; + /** Max size per text field. 
Default: 1MB */ + maxFieldSize?: string | number; + /** Auto-save directory (optional) */ + uploadDir?: string; +} + +export interface MultipartFile { + name: string; + fieldName: string; + mimetype: string; + size: number; + data: Buffer; + saveTo?(destPath?: string): Promise; +} + +export function multipart(options?: MultipartOptions): Middleware; + +// ─── Logger Types ─────────────────────── + +export interface LoggerOptions { + /** Minimum log level (default: "info") */ + level?: "debug" | "info" | "warn" | "error" | "silent"; + /** Output format (default: "json") */ + format?: "json" | "pretty"; + /** Dot-paths to redact from log output (e.g. "req.headers.authorization") */ + redact?: string[]; + /** Custom output function (default: JSON to stderr) */ + sink?: (entry: Record) => void; + /** Include timestamps (default: true) */ + timestamp?: boolean; + /** Extra fields to include per request */ + customProps?: (req: Request) => Record; +} + +export interface Logger { + debug(msg: string, fields?: Record): void; + info(msg: string, fields?: Record): void; + warn(msg: string, fields?: Record): void; + error(msg: string, fields?: Record): void; + child(defaults: Record): Logger; +} + +export function logger(options?: LoggerOptions): Middleware; +export function createLogger(options?: Omit): Logger; + +// ─── OpenTelemetry Types ──────────────── + +export interface OtelOptions { + /** Service name for trace/metric resource */ + serviceName?: string; + /** OTLP collector endpoint */ + endpoint?: string; + /** Context propagation format (default: "w3c") */ + propagation?: "w3c" | "b3" | "jaeger"; + /** Fraction of requests to trace, 0.0–1.0 (default: 1.0) */ + sampleRate?: number; + /** Custom span exporter function */ + exporter?: (spans: OtelSpan[]) => void; + /** Enable request metrics collection (default: true) */ + metrics?: boolean; + /** Metrics flush interval in ms (default: 60000) */ + metricsInterval?: number; +} + +export interface OtelSpan { + traceId: 
string; + spanId: string; + parentSpanId: string | null; + operationName: string; + serviceName: string; + startTime: number; + duration: number; + tags: Record; + status: "OK" | "ERROR"; +} + +export type OtelMiddleware = Middleware & { + /** Flush pending spans to the exporter immediately */ + flushSpans(): void; + /** Get the number of pending (unbatched) spans */ + pendingSpans(): number; +}; + +export function otel(options?: OtelOptions): OtelMiddleware; + +/** Flush pending spans from the otel middleware (convenience for shutdown) */ +export function flushSpans(middleware: OtelMiddleware): void; diff --git a/src/index.js b/src/index.js index eee5b54..ccc1161 100644 --- a/src/index.js +++ b/src/index.js @@ -14,7 +14,6 @@ import { mergeRequestAccessPlans, releaseRequestObject, } from "./bridge.js"; -import { encodeSessionTrailer } from "./session.js"; import { loadNativeModule } from "./native.js"; import defaultHttpServerConfig, { normalizeHttpServerConfig, @@ -24,7 +23,7 @@ import { createRouteDevCommentWriter } from "./dev/comments.js"; import { createRuntimeHotReloadController } from "./dev/hot-reload.js"; import { createRuntimeOptimizer } from "./opt/runtime.js"; -const HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]; +const HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]; const ACTIVE_NATIVE_SERVERS = new Set(); const EMPTY_BUFFER = Buffer.alloc(0); const NOOP_NEXT = () => undefined; @@ -265,6 +264,11 @@ function acquireResponseState() { pooled.ncache = null; pooled.streaming = false; pooled.streamId = null; + /* Clear session state to prevent cross-request data leakage */ + pooled._sessionId = undefined; + pooled._sessionIsNew = undefined; + pooled._sessionCookie = undefined; + pooled._sessionState = undefined; return pooled; } @@ -486,6 +490,28 @@ const RESPONSE_PROTO = { id: streamId, }; }, + + /** + * Redirect the client to a different URL. 
+ * + * @param {string} url - Target URL + * @param {number} [status=302] - HTTP redirect status code (301, 302, 307, 308) + * @returns {Response} + */ + redirect(url, status = 302) { + const state = this._state; + if (state.finished) return this; + state.status = Number(status); + const urlStr = String(url); + /* Block CRLF injection in redirect URLs — same check as res.set() */ + if (urlStr.includes("\r") || urlStr.includes("\n")) { + throw new Error("[http-native] CRLF injection blocked in redirect URL"); + } + state.headers["location"] = urlStr; + state.body = EMPTY_BUFFER; + state.finished = true; + return this; + }, }; function createResponseEnvelope(jsonSerializer = DEFAULT_JSON_SERIALIZER) { @@ -650,6 +676,16 @@ function isPromiseLike(value) { // ─── Dispatcher ───────────────────────── function buildDispatchState(snapshot) { + /* @B4.6 — compile lifecycle hooks into direct function references. + * When no hooks are registered for an event, the function is undefined + * and the dispatch path skips the call entirely (zero cost). */ + const hooks = snapshot.hooks ?? {}; + const compileHookChain = (fns) => { + if (!fns || fns.length === 0) return undefined; + if (fns.length === 1) return fns[0]; + return async (...args) => { for (const fn of fns) await fn(...args); }; + }; + return { snapshot, ...snapshot, @@ -657,6 +693,10 @@ function buildDispatchState(snapshot) { wsRoutesById: new Map(snapshot.wsRoutes.map((route) => [route.handlerId, route])), trackDispatchTiming: snapshot.runtimeOptimizer?.shouldCaptureDispatchTiming?.() === true, + /* Compiled hook chains — undefined when no hooks registered */ + onRequest: compileHookChain(hooks.onRequest), + onResponse: compileHookChain(hooks.onResponse), + onError: compileHookChain(hooks.onError), }; } @@ -724,14 +764,33 @@ function createDispatcher(initialSnapshot) { const native = loadNativeModule(); const ws = { + id: wsId, + /** Send a text or binary message to this connection. 
*/ send(msg) { const chunk = typeof msg === "string" ? Buffer.from(msg, "utf8") : Buffer.from(msg); native.streamWrite(Number(wsId), chunk); }, + /** Close the WebSocket connection. */ close(code = 1000, reason = "") { native.streamEnd(Number(wsId)); }, - id: wsId, + /** Subscribe this connection to a pub/sub topic. */ + subscribe(topic) { + native.wsSubscribe(Number(wsId), topic); + }, + /** Unsubscribe this connection from a pub/sub topic. */ + unsubscribe(topic) { + native.wsUnsubscribe(Number(wsId), topic); + }, + /** Publish a message to all subscribers of a topic (including self). */ + publish(topic, msg) { + const chunk = typeof msg === "string" ? Buffer.from(msg, "utf8") : Buffer.from(msg); + return native.wsPublish(topic, chunk); + }, + /** Get the number of subscribers for a topic. */ + subscriberCount(topic) { + return native.wsSubscriberCount(topic); + }, }; try { @@ -741,10 +800,16 @@ function createDispatcher(initialSnapshot) { await route.handlers.open?.(ws); break; case 0x02: { + /* Text message — decode to string */ const textData = data ? new TextDecoder().decode(data) : ""; await route.handlers.message?.(ws, textData); break; } + case 0x04: { + /* Binary message — pass raw Buffer */ + await route.handlers.message?.(ws, data ?? Buffer.alloc(0)); + break; + } case 0x03: activeWebSocketIds.delete(wsId); await route.handlers.close?.(ws); @@ -827,13 +892,7 @@ function createDispatcher(initialSnapshot) { ? 
performance.now() - dispatchStartMs : undefined; state.runtimeOptimizer?.recordDispatch(route, req, responseSnapshot, dispatchDurationMs); - let encoded = encodeResponseEnvelope(responseSnapshot); - - // Append session trailer if session mutations exist - const sessionTrailer = encodeSessionTrailer(res._sessionState); - if (sessionTrailer) { - encoded = Buffer.concat([encoded, sessionTrailer]); - } + const encoded = encodeResponseEnvelope(responseSnapshot); maybePromoteRouteResponseCache( route, @@ -929,6 +988,16 @@ function normalizeStaticRouteRegistration(path, html, options = {}) { }; } +const _existsCache = new Map(); +function cachedExistsSync(filePath) { + let result = _existsCache.get(filePath); + if (result === undefined) { + result = existsSync(filePath); + _existsCache.set(filePath, result); + } + return result; +} + function captureRouteRegistrationLocation() { const stack = new Error().stack; if (!stack) { @@ -961,7 +1030,7 @@ function captureRouteRegistrationLocation() { filePath = path.resolve(process.cwd(), filePath); } - if (!existsSync(filePath)) { + if (!cachedExistsSync(filePath)) { continue; } @@ -1061,8 +1130,7 @@ function maybePromoteRouteResponseCache(route, snapshot, encoded, devRouteCommen } function buildSnapshotCacheKey(snapshot) { - let hash = 0x811c9dc5; - hash = fnv1aString(hash, String(snapshot.status ?? 200)); + let hash = fnv1aString(0x811c9dc5, String(snapshot.status ?? 200)); const headers = snapshot.headers ?? Object.create(null); const headerNames = Object.keys(headers); @@ -1078,25 +1146,43 @@ function buildSnapshotCacheKey(snapshot) { : EMPTY_BUFFER; hash = fnv1aBytes(hash, body); - return `${hash}:${body.length}:${headerNames.length}`; + return `${fnv1aFinish(hash)}:${body.length}:${headerNames.length}`; } +/* @B3: dual-lane FNV-1a — two independent 32-bit hashes with different offsets + * and primes, yielding 64-bit effective collision resistance in JS without + * BigInt overhead. 
The "lo" lane is standard FNV-1a-32; the "hi" lane uses + * FNV-1a-32 seeded at a different offset with a co-prime multiplier. */ + +/* @param seed — 64-bit state as { lo, hi } or a plain 32-bit number (legacy) */ +/* @param value — string to hash */ function fnv1aString(seed, value) { - let hash = seed >>> 0; + let lo = typeof seed === "number" ? seed >>> 0 : seed.lo >>> 0; + let hi = typeof seed === "number" ? (seed ^ 0x6c62272e) >>> 0 : seed.hi >>> 0; for (let index = 0; index < value.length; index += 1) { - hash ^= value.charCodeAt(index); - hash = Math.imul(hash, 0x01000193); + const c = value.charCodeAt(index); + lo = Math.imul(lo ^ c, 0x01000193) >>> 0; + hi = Math.imul(hi ^ c, 0x01000193) >>> 0; } - return hash >>> 0; + return { lo, hi }; } +/* @param seed — 64-bit state as { lo, hi } or a plain 32-bit number (legacy) */ +/* @param bytes — Uint8Array to hash */ function fnv1aBytes(seed, bytes) { - let hash = seed >>> 0; + let lo = typeof seed === "number" ? seed >>> 0 : seed.lo >>> 0; + let hi = typeof seed === "number" ? (seed ^ 0x6c62272e) >>> 0 : seed.hi >>> 0; for (let index = 0; index < bytes.length; index += 1) { - hash ^= bytes[index]; - hash = Math.imul(hash, 0x01000193); + const b = bytes[index]; + lo = Math.imul(lo ^ b, 0x01000193) >>> 0; + hi = Math.imul(hi ^ b, 0x01000193) >>> 0; } - return hash >>> 0; + return { lo, hi }; +} + +/* @param h — dual-lane hash state { lo, hi } */ +function fnv1aFinish(h) { + return ((h.hi >>> 0) * 0x100000000 + (h.lo >>> 0)).toString(16); } // ─── Fast-Path Probe ──────────────────── @@ -1125,6 +1211,12 @@ function probeHandlerForFastPath(route, originalSource) { return null; } + // Don't probe if the handler has side effects beyond req/res calls. + // Detect external mutations: array pushes, property assignments on non-res objects, etc. + if (/(? 
- analyzeRequestAccess(Function.prototype.toString.call(handler)), - ); + /* Error handlers have signature (error, req, res) — the request is the + * SECOND parameter, not the first. analyzeRequestAccess assumes the first + * parameter is req, which would misidentify `error`. Create a plan that + * requests the basic fields error handlers typically need (method, path, url) + * without forcing generic_fallback on every route. */ + const ERROR_HANDLER_PLAN = Object.freeze({ + method: true, + path: true, + url: true, + fullParams: false, + fullQuery: false, + fullHeaders: false, + paramKeys: new Set(), + queryKeys: new Set(), + headerKeys: new Set(), + dispatchKind: "specialized", + jsonFastPath: "fallback", + }); + const errorHandlerPlans = app._errorHandlers.map(() => ERROR_HANDLER_PLAN); const routes = app._routes.map((route) => { let handlerSource = @@ -1637,7 +1745,7 @@ function buildCompiledApplication(app, normalizedOptions) { cert: normalizedOptions.tls.cert, key: normalizedOptions.tls.key, ca: normalizedOptions.tls.ca, - passphrase: normalizedOptions.tls.passphrase, + passphrase: normalizedOptions.tls.passphrase ? "[REDACTED]" : undefined, }; } @@ -1781,6 +1889,56 @@ async function startCompiledServer(compiledSnapshot, normalizedOptions) { hotReloadController.dispose(); return closeServerHandle(); }, + + /** + * @DX-6.3: graceful shutdown — stop accepting new connections, drain + * in-flight requests up to `timeout` ms, then force-stop workers. + * + * @param {Object} [options] + * @param {number} [options.timeout=30000] - Maximum drain time in ms + * @param {number} [options.forceAfter] - Hard kill after this many ms (defaults to timeout + 5000) + * @param {Function} [options.onDraining] - Called when draining starts + * @param {Function} [options.onDrained] - Called when all requests finish (before close) + * @returns {Promise<{ drained: boolean, remaining: number }>} + */ + async shutdown(options = {}) { + const drainTimeout = options.timeout ?? 
30000; + const forceTimeout = options.forceAfter ?? (drainTimeout + 5000); + + if (typeof options.onDraining === "function") { + options.onDraining(); + } + + /* Use the Rust-side shutdown which sets the draining flag, polls + * the in-flight counter, and then force-closes workers. The force + * timeout is handled by a JS-side race. */ + const shutdownPromise = new Promise((resolve) => { + try { + const remaining = nativeHandle.shutdown(drainTimeout); + resolve(remaining); + } catch { + resolve(0); + } + }); + + const forcePromise = new Promise((resolve) => + setTimeout(() => resolve(-1), forceTimeout), + ); + + const result = await Promise.race([shutdownPromise, forcePromise]); + const remaining = result === -1 ? 0 : result; + const drained = remaining === 0; + + if (drained && typeof options.onDrained === "function") { + await options.onDrained(); + } + + hotReloadController.dispose(); + dispatcher.dispose(); + ACTIVE_NATIVE_SERVERS.delete(nativeHandle); + + return { drained, remaining }; + }, }; return serverHandle; @@ -2077,6 +2235,17 @@ export function createApp(config = {}) { selectedTls = tlsConfig; return chainableListen; }, + http3(h3Options = { enabled: true }) { + if (startPromise) { + return startPromise; + } + + selectedOpt = { + ...(selectedOpt ?? {}), + http3: typeof h3Options === "boolean" ? { enabled: h3Options } : h3Options, + }; + return chainableListen; + }, then(onFulfilled, onRejected) { return start().then(onFulfilled, onRejected); }, @@ -2088,14 +2257,6 @@ export function createApp(config = {}) { }, }; - // Preserve previous behavior by starting even when callers don't await. 
- queueMicrotask(() => { - if (startBlockedByRemovedOpt) { - return; - } - void start(); - }); - return chainableListen; }, }; @@ -2106,8 +2267,17 @@ export function createApp(config = {}) { } app._wsRoutes.push({ path: normalizeRoutePath("GET", path), - handlers, + handlers: { + open: handlers.open, + message: handlers.message, + close: handlers.close, + }, handlerId: app._allocateHandlerId(), + /* DX-4.4 WebSocket config */ + maxPayloadLength: handlers.maxPayloadLength ?? 64 * 1024, + backpressure: handlers.backpressure ?? "drop", + idleTimeout: handlers.idleTimeout ?? 120, + perMessageDeflate: handlers.perMessageDeflate ?? false, }); return app; }; @@ -2118,6 +2288,7 @@ export function createApp(config = {}) { app.delete = createMethodRegistrar(app, "DELETE"); app.patch = createMethodRegistrar(app, "PATCH"); app.options = createMethodRegistrar(app, "OPTIONS"); + app.head = createMethodRegistrar(app, "HEAD"); app.all = createMethodRegistrar(app, "ALL"); app.static = (routePath, html, options = {}) => { const groupPrefix = app._groupPrefix ?? "/"; @@ -2137,9 +2308,135 @@ export function createApp(config = {}) { return app; }; + /** + * Register a health check endpoint served entirely from the Rust static + * fast path — zero JS dispatch, zero allocation per request. The response + * is pre-built at startup with Brotli/Gzip compression variants. + * + * @param {string} routePath - Health check path (e.g. "/healthz") + * @param {Object} [options] + * @param {Object} [options.body] - JSON body (default: { status: "ok" }) + * @param {number} [options.status] - HTTP status code (default: 200) + * @param {Object} [options.headers] - Additional response headers + * @returns {Application} + */ + app.health = (routePath, options = {}) => { + const body = JSON.stringify(options.body ?? { status: "ok" }); + const headers = { + "content-type": "application/json; charset=utf-8", + ...(options.headers ?? 
{}), + }; + return app.static(routePath, body, { + status: options.status ?? 200, + headers, + }); + }; + + // ─── DX-5.4: Decorator Pattern ────────── + // + // Attach custom properties to every request object. Decorators are set once + // at startup and available in every handler as `req.`. This avoids + // per-request middleware overhead for injecting services/pools. + + const decorators = Object.create(null); + + /** + * Attach a named property to every request object. + * + * @param {string} name - Property name accessible on `req` + * @param {*} value - Value or service instance to attach + * @returns {Application} + */ + app.decorate = (name, value) => { + if (name in decorators) { + throw new Error(`Decorator "${name}" is already registered`); + } + decorators[name] = value; + return app; + }; + + /** @internal — called by bridge to attach decorators to each request */ + app._applyDecorators = (req) => { + for (const key in decorators) { + req[key] = decorators[key]; + } + }; + + // ─── DX-5.2: Plugin System ──────────── + // + // Plugins are objects with a `setup(app, options)` function called at + // registration time and an optional `teardown()` called on shutdown. + // Lifecycle hooks are aggregated across plugins and called in order. + + const registeredPlugins = []; + const lifecycleHooks = { + onRequest: [], + onRoute: [], + onResponse: [], + onError: [], + onClose: [], + }; + + /** + * Register lifecycle hooks from plugins. + * + * @param {"onRequest"|"onRoute"|"onResponse"|"onError"|"onClose"} event + * @param {Function} fn + * @returns {Application} + */ + app.addHook = (event, fn) => { + if (!lifecycleHooks[event]) { + throw new Error(`Unknown hook event "${event}"`); + } + lifecycleHooks[event].push(fn); + return app; + }; + + /** + * Install a plugin. Plugins receive the app instance and can register + * routes, middleware, hooks, and decorators. 
+ * + * @param {{ name: string, setup: Function, teardown?: Function }} plugin + * @param {Object} [pluginOptions] - Options forwarded to the plugin + * @returns {Application} + */ + app.register = (plugin, pluginOptions = {}) => { + if (!plugin || typeof plugin.setup !== "function") { + throw new Error("Plugin must have a setup(app, options) function"); + } + if (registeredPlugins.some((p) => p.name === plugin.name)) { + throw new Error(`Plugin "${plugin.name}" is already registered`); + } + plugin.setup(app, pluginOptions); + registeredPlugins.push(plugin); + return app; + }; + + /** @internal — expose hooks and plugins for bridge/shutdown integration */ + app._hooks = lifecycleHooks; + app._plugins = registeredPlugins; + return app; } +/** + * Define a plugin object with the standard interface. + * + * @param {{ name: string, version?: string, setup: Function, teardown?: Function }} definition + * @returns {{ name: string, version: string, setup: Function, teardown?: Function }} + */ +export function definePlugin(definition) { + if (!definition.name) throw new Error("Plugin must have a name"); + if (typeof definition.setup !== "function") + throw new Error("Plugin must have a setup function"); + return { + name: definition.name, + version: definition.version ?? "0.0.0", + setup: definition.setup, + teardown: definition.teardown, + }; +} + export { buildCompiledApplication as _buildCompiledApplication, normalizeListenOptions as _normalizeListenOptions, diff --git a/src/ip-filter.js b/src/ip-filter.js new file mode 100644 index 0000000..6a2e44e --- /dev/null +++ b/src/ip-filter.js @@ -0,0 +1,162 @@ +/** + * http-native IP allowlist / denylist middleware. + * + * Filters requests based on client IP address using CIDR range matching. + * Supports both IPv4 and IPv6 addresses. CIDR ranges are parsed at startup + * into efficient binary representations for O(1)-per-bit prefix matching. 
/**
 * http-native IP allowlist / denylist middleware.
 *
 * Filters requests by client IP using CIDR range matching. Ranges are
 * parsed once at startup into integer network/mask pairs so that each
 * per-request check is a couple of bitwise operations.
 *
 * NOTE(review): despite the original header's mention of IPv6, only IPv4
 * and IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are actually matched —
 * a bare IPv6 client address fails to parse and is denied by default.
 *
 * Usage:
 *   import { ipFilter } from "@http-native/core/ip-filter";
 *
 *   // Allow only private networks
 *   app.use(ipFilter({
 *     allow: ["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"],
 *     deny: ["0.0.0.0/0"],
 *   }));
 *
 *   // Deny specific ranges
 *   app.use(ipFilter({ deny: ["203.0.113.0/24"] }));
 *
 *   // Trust proxy (use X-Forwarded-For)
 *   app.use(ipFilter({ allow: ["10.0.0.0/8"], trustProxy: true }));
 */

// ─── CIDR Parsing ─────────────────────────

/**
 * Parse a dotted-quad IPv4 address into an unsigned 32-bit integer.
 *
 * Each octet must be a pure 1–3 digit decimal token in [0, 255]. This is
 * stricter than the previous parseInt-based version, which silently
 * accepted trailing garbage ("10.0.0.1x"), signs, and surrounding
 * whitespace — a correctness hazard for a security filter.
 *
 * @param {string} ip
 * @returns {number|null} unsigned 32-bit value, or null when invalid
 */
function parseIPv4(ip) {
  const parts = ip.split(".");
  if (parts.length !== 4) return null;

  let result = 0;
  for (const part of parts) {
    /* Reject anything that is not a bare decimal octet. */
    if (!/^\d{1,3}$/.test(part)) return null;
    const octet = Number(part);
    if (octet > 255) return null;
    result = (result << 8) | octet;
  }

  return result >>> 0; // force unsigned 32-bit
}

/**
 * Parse a CIDR notation string into a range descriptor.
 *
 * @param {string} cidr - e.g. "10.0.0.0/8", "192.168.1.0/24"
 * @returns {{ ip: number, mask: number }} network address and mask
 * @throws {TypeError} when the IP or prefix length is invalid
 */
function parseCIDR(cidr) {
  const [ipStr, prefixStr] = cidr.split("/");
  const ip = parseIPv4(ipStr);

  if (ip === null) {
    throw new TypeError(`Invalid IP address in CIDR: "${cidr}"`);
  }

  /* A bare address ("10.0.0.1") is treated as a /32 host route. */
  const prefix = prefixStr !== undefined ? parseInt(prefixStr, 10) : 32;
  if (!Number.isFinite(prefix) || prefix < 0 || prefix > 32) {
    throw new TypeError(`Invalid prefix length in CIDR: "${cidr}"`);
  }

  /* Build the network mask: /24 → 0xFFFFFF00. The prefix === 0 special
   * case avoids `~0 << 32`, which JS reduces to a shift by 0. */
  const mask = prefix === 0 ? 0 : (~0 << (32 - prefix)) >>> 0;
  const network = (ip & mask) >>> 0;

  return { ip: network, mask };
}

/**
 * Check whether an IPv4 integer falls inside a parsed CIDR range.
 *
 * @param {number} ip
 * @param {{ ip: number, mask: number }} range
 * @returns {boolean}
 */
function matchesCIDR(ip, range) {
  return ((ip & range.mask) >>> 0) === range.ip;
}

/**
 * Extract an IPv4 integer from a possibly IPv4-mapped IPv6 string.
 *
 * @param {string} ip
 * @returns {number|null}
 */
function extractIPv4(ip) {
  if (!ip) return null;

  /* Strip the IPv6 prefix of IPv4-mapped addresses (::ffff:10.0.0.1). */
  const stripped = ip.startsWith("::ffff:") ? ip.slice(7) : ip;
  return parseIPv4(stripped);
}

// ─── Middleware Factory ───────────────────

/**
 * Create an IP filter middleware.
 *
 * Denial rules: an unparseable client IP is always denied (fail closed);
 * an IP on the deny list is blocked unless it is also on the allow list;
 * when an allow list exists, any IP not on it is blocked.
 *
 * @param {Object} options
 * @param {string[]} [options.allow] - CIDR ranges to allow
 * @param {string[]} [options.deny] - CIDR ranges to deny
 * @param {boolean} [options.trustProxy] - Use X-Forwarded-For header
 * @param {Function} [options.onDenied] - Custom denial handler (req, res)
 * @returns {Function} async middleware (req, res, next)
 * @throws {TypeError} on malformed options or CIDR strings
 */
export function ipFilter(options = {}) {
  if (typeof options !== "object" || options === null) {
    throw new TypeError("ipFilter(options) expects an object");
  }

  /* Parse all CIDR ranges once at startup; bad input throws here, not
   * on the request path. */
  const allowRanges = (options.allow ?? []).map(parseCIDR);
  const denyRanges = (options.deny ?? []).map(parseCIDR);
  const trustProxy = options.trustProxy === true;
  const onDenied = options.onDenied ?? null;

  if (allowRanges.length === 0 && denyRanges.length === 0) {
    throw new TypeError("ipFilter requires at least one allow or deny range");
  }

  return async function ipFilterMiddleware(req, res, next) {
    let clientIp = req.ip;

    /* When behind a reverse proxy, use the first X-Forwarded-For entry.
     * Only enable this when the proxy is trusted — the header is
     * client-controlled otherwise. */
    if (trustProxy) {
      const xff = req.header("x-forwarded-for");
      if (xff) {
        clientIp = xff.split(",")[0].trim();
      }
    }

    const ipInt = extractIPv4(clientIp);

    if (ipInt === null) {
      /* Cannot parse IP — deny by default for safety. */
      if (typeof onDenied === "function") {
        return onDenied(req, res);
      }
      return res.status(403).json({ error: "Forbidden" });
    }

    const isDenied = denyRanges.some((r) => matchesCIDR(ipInt, r));
    const isAllowed = allowRanges.length > 0
      ? allowRanges.some((r) => matchesCIDR(ipInt, r))
      : true; /* No allow list = all IPs implicitly allowed */

    /* Deny if: explicitly denied and not allowed, OR an allow list exists
     * and the IP is not on it. */
    const shouldDeny = (isDenied && !isAllowed) || (allowRanges.length > 0 && !isAllowed);

    if (shouldDeny) {
      if (typeof onDenied === "function") {
        return onDenied(req, res);
      }
      return res.status(403).json({ error: "Forbidden" });
    }

    await next();
  };
}
/**
 * http-native structured logging middleware (DX-2.1)
 *
 * Emits one structured JSON log entry per request/response cycle.
 * Entries are plain objects, so any sink (pino, winston, a test array)
 * can consume them.
 *
 * Usage:
 *   import { logger } from "@http-native/core/logger";
 *   app.use(logger());
 *   app.use(logger({ level: "debug", format: "pretty", redact: ["req.headers.authorization"] }));
 */

const LEVELS = { debug: 10, info: 20, warn: 30, error: 40, silent: 100 };

/**
 * Create the request-logging middleware.
 *
 * @param {Object} [options]
 * @param {"debug"|"info"|"warn"|"error"|"silent"} [options.level="info"] - Minimum level to emit
 * @param {"json"|"pretty"} [options.format="json"] - Output format for the default sink
 * @param {string[]} [options.redact] - Dot-paths to redact; both "req.x.y" (legacy)
 *   and bare entry paths ("secret", "user.token") are supported
 * @param {(entry: Object) => void} [options.sink] - Custom output function (default: stderr)
 * @param {boolean} [options.timestamp=true] - Include `time` (epoch ms) in each entry
 * @param {(req: Object) => Object} [options.customProps] - Extra fields per request
 * @returns {Function} async middleware (req, res, next)
 */
export function logger(options = {}) {
  const {
    level = "info",
    format = "json",
    redact = [],
    sink,
    timestamp = true,
    customProps,
  } = options;

  const minLevel = LEVELS[level] ?? LEVELS.info;
  const redactSet = new Set(redact);
  const write = sink ?? ((entry) => process.stderr.write(formatEntry(entry, format) + "\n"));

  return async function loggerMiddleware(req, res, next) {
    const start = performance.now();
    const reqId = req.id; // from requestId middleware, if present

    try {
      await next();
    } finally {
      /* Log in `finally` so that failing handlers are still recorded. */
      const duration = performance.now() - start;
      /* NOTE(review): assumes res._state.status mirrors the status
       * actually sent by the bridge — confirm against the response impl. */
      const status = res._state?.status ?? 200;
      const entryLevel = status >= 500 ? "error" : status >= 400 ? "warn" : "info";

      if (LEVELS[entryLevel] >= minLevel) {
        const entry = {
          level: entryLevel,
          method: req.method,
          path: req.path,
          status,
          duration_ms: Math.round(duration * 100) / 100,
        };

        if (timestamp) entry.time = Date.now();
        if (reqId) entry.requestId = reqId;
        if (customProps) Object.assign(entry, customProps(req));

        /* Apply redaction before output */
        applyRedaction(entry, req, res, redactSet);

        write(entry);
      }
    }
  };
}

/**
 * Create a standalone logger instance for use outside middleware.
 *
 * @param {Object} [options]
 * @param {"debug"|"info"|"warn"|"error"|"silent"} [options.level="info"]
 * @param {"json"|"pretty"} [options.format="json"]
 * @param {(entry: Object) => void} [options.sink]
 * @returns {{ debug: Function, info: Function, warn: Function, error: Function, child: Function }}
 */
export function createLogger(options = {}) {
  const {
    level = "info",
    format = "json",
    sink,
  } = options;

  const minLevel = LEVELS[level] ?? LEVELS.info;
  const write = sink ?? ((entry) => process.stderr.write(formatEntry(entry, format) + "\n"));

  /* Shared emit path: drop entries below the threshold, stamp the rest. */
  function emit(entryLevel, msg, fields = {}) {
    if (LEVELS[entryLevel] < minLevel) return;
    const entry = { level: entryLevel, msg, time: Date.now(), ...fields };
    write(entry);
  }

  return {
    debug: (msg, fields) => emit("debug", msg, fields),
    info: (msg, fields) => emit("info", msg, fields),
    warn: (msg, fields) => emit("warn", msg, fields),
    error: (msg, fields) => emit("error", msg, fields),
    /* Child loggers merge `defaults` under every entry before writing. */
    child(defaults) {
      return createLogger({
        level,
        format,
        sink: (entry) => write({ ...defaults, ...entry }),
      });
    },
  };
}

/**
 * Render a log entry as a single line.
 *
 * @param {Object} entry
 * @param {"json"|"pretty"} format
 * @returns {string}
 */
function formatEntry(entry, format) {
  if (format === "pretty") {
    const ts = entry.time ? new Date(entry.time).toISOString() : "";
    const lvl = (entry.level ?? "info").toUpperCase().padEnd(5);
    const dur = entry.duration_ms != null ? ` ${entry.duration_ms}ms` : "";
    const id = entry.requestId ? ` [${entry.requestId}]` : "";
    const msg = entry.msg ?? `${entry.method} ${entry.path} ${entry.status}`;
    return `${ts} ${lvl}${id} ${msg}${dur}`;
  }
  return JSON.stringify(entry);
}

/**
 * Redact specific fields from the log entry in place.
 *
 * Two path styles are accepted:
 *  - "req.headers.authorization" (legacy) — the leading "req." is dropped
 *    and the remainder is resolved against the entry, matching the
 *    previous behavior;
 *  - bare paths ("secret", "user.token") — resolved directly against the
 *    entry. Previously these were silently ignored, making redaction a
 *    no-op for fields added via customProps.
 *
 * Matched leaves are replaced with "[REDACTED]".
 */
function applyRedaction(entry, req, res, redactSet) {
  for (const path of redactSet) {
    const parts = path.split(".");
    const keys = parts[0] === "req" && parts.length >= 2 ? parts.slice(1) : parts;

    /* Walk to the parent object of the leaf key. */
    let target = entry;
    for (let i = 0; i < keys.length - 1; i++) {
      target = target?.[keys[i]];
      if (!target || typeof target !== "object") break;
    }
    if (target && typeof target === "object") {
      const leafKey = keys[keys.length - 1];
      if (target[leafKey] !== undefined) {
        target[leafKey] = "[REDACTED]";
      }
    }
  }
}
/**
 * http-native streaming multipart parser (DX-5.3)
 *
 * Parses multipart/form-data request bodies. Text parts land on
 * `req.fields`, file parts on `req.files` (bytes held in memory); when
 * `uploadDir` is configured each file additionally gains a `saveTo()`
 * helper that writes the bytes to disk.
 *
 * Usage:
 *   import { multipart } from "@http-native/core/multipart";
 *   app.post("/upload", multipart({ maxFileSize: "10mb", maxFiles: 5 }), handler);
 */

import { createWriteStream } from "node:fs";
import { join, basename, resolve, sep } from "node:path";
import { randomBytes } from "node:crypto";

const DEFAULT_MAX_FILE_SIZE = 10 * 1024 * 1024; // 10 MB
const DEFAULT_MAX_FILES = 10;
const DEFAULT_MAX_FIELD_SIZE = 1024 * 1024; // 1 MB

/**
 * Create the multipart parsing middleware.
 *
 * Non-multipart requests are passed through untouched. On limit
 * violations the middleware responds 413; unexpected errors propagate.
 *
 * @param {Object} [options]
 * @param {string|number} [options.maxFileSize="10mb"] - Max size per file
 * @param {number} [options.maxFiles=10] - Max number of files
 * @param {string|number} [options.maxFieldSize="1mb"] - Max size per text field
 * @param {string} [options.uploadDir] - Directory for `file.saveTo()` (optional)
 * @returns {Function} async middleware (req, res, next)
 */
export function multipart(options = {}) {
  const maxFileSize = parseSize(options.maxFileSize ?? DEFAULT_MAX_FILE_SIZE);
  const maxFiles = options.maxFiles ?? DEFAULT_MAX_FILES;
  const maxFieldSize = parseSize(options.maxFieldSize ?? DEFAULT_MAX_FIELD_SIZE);
  const uploadDir = options.uploadDir;

  return async function multipartMiddleware(req, res, next) {
    const contentType = req.header("content-type") ?? req.headers?.["content-type"] ?? "";
    if (!contentType.startsWith("multipart/form-data")) {
      return next();
    }

    const boundary = extractBoundary(contentType);
    if (!boundary) {
      res.status(400).json({ error: "Missing multipart boundary" });
      return;
    }

    try {
      const rawBody = typeof req.body === "string" ? Buffer.from(req.body) : req.body;
      if (!rawBody || rawBody.length === 0) {
        req.fields = Object.create(null);
        req.files = [];
        return await next();
      }
      const { fields, files } = parseMultipartBody(
        rawBody,
        boundary,
        { maxFileSize, maxFiles, maxFieldSize },
      );

      req.fields = fields;
      req.files = files;

      /* If uploadDir is set, attach a saveTo helper to each file. */
      if (uploadDir) {
        for (const file of files) {
          file.saveTo = (destPath) => {
            /* Sanitize filename: use only the base name and strip any
             * remaining traversal sequences. */
            const safeName = basename(file.name).replace(/\.\./g, "");
            const fullPath = destPath ?? join(uploadDir, `${randomBytes(16).toString("hex")}-${safeName}`);
            /* Containment check: the resolved path must be the upload
             * directory itself or live underneath it. A bare
             * startsWith(root) prefix test would also accept sibling
             * directories such as "/uploads-evil" for root "/uploads". */
            const root = resolve(uploadDir);
            const target = resolve(fullPath);
            if (target !== root && !target.startsWith(root + sep)) {
              return Promise.reject(new Error("Path traversal detected in upload filename"));
            }
            return new Promise((resolvePromise, reject) => {
              const ws = createWriteStream(fullPath);
              ws.on("finish", () => resolvePromise(fullPath));
              ws.on("error", reject);
              ws.end(file.data);
            });
          };
        }
      }

      await next();
    } catch (err) {
      /* Limit violations carry an HTTP status; anything else is a bug
       * and is rethrown for the framework's error handling. */
      if (err.status) {
        res.status(err.status).json({ error: err.message });
      } else {
        throw err;
      }
    }
  };
}

/**
 * Pull the boundary token out of a Content-Type header value.
 * Supports both quoted and unquoted boundary parameters.
 *
 * @param {string} contentType
 * @returns {string|null}
 */
function extractBoundary(contentType) {
  const match = contentType.match(/boundary=(?:"([^"]+)"|([^\s;]+))/);
  return match ? match[1] || match[2] : null;
}

/**
 * Parse a multipart body buffer into fields and files.
 *
 * @param {Buffer} body
 * @param {string} boundary
 * @param {{ maxFileSize: number, maxFiles: number, maxFieldSize: number }} limits
 * @returns {{ fields: Record<string, string>, files: Object[] }}
 * @throws {Error} with `.status = 413` when a limit is exceeded
 */
function parseMultipartBody(body, boundary, limits) {
  const delimiter = Buffer.from(`--${boundary}`);
  const fields = Object.create(null);
  const files = [];

  let offset = 0;

  /* Find the first boundary. */
  const firstBoundary = indexOf(body, delimiter, offset);
  if (firstBoundary === -1) return { fields, files };
  offset = firstBoundary + delimiter.length;

  while (offset < body.length) {
    /* Skip CRLF after boundary. */
    if (body[offset] === 0x0d && body[offset + 1] === 0x0a) offset += 2;

    /* "--" right after a boundary marks the closing delimiter. */
    if (body[offset] === 0x2d && body[offset + 1] === 0x2d) break;

    /* Parse headers of this part. */
    const headerEnd = indexOf(body, Buffer.from("\r\n\r\n"), offset);
    if (headerEnd === -1) break;

    const headerSection = body.subarray(offset, headerEnd).toString("utf-8");
    offset = headerEnd + 4;

    const headers = parsePartHeaders(headerSection);
    const disposition = parseContentDisposition(headers["content-disposition"] ?? "");

    /* Find the end of this part's body (next boundary). */
    const nextBoundary = indexOf(body, delimiter, offset);
    if (nextBoundary === -1) break;

    /* Part body is between current offset and (nextBoundary - 2) for trailing CRLF. */
    const partBody = body.subarray(offset, nextBoundary - 2);
    offset = nextBoundary + delimiter.length;

    if (disposition.filename !== undefined) {
      /* File part */
      if (files.length >= limits.maxFiles) {
        const err = new Error(`Too many files (max ${limits.maxFiles})`);
        err.status = 413;
        throw err;
      }
      if (partBody.length > limits.maxFileSize) {
        const err = new Error(`File "${disposition.filename}" exceeds max size`);
        err.status = 413;
        throw err;
      }
      files.push({
        name: disposition.filename,
        fieldName: disposition.name,
        mimetype: headers["content-type"] ?? "application/octet-stream",
        size: partBody.length,
        data: partBody,
      });
    } else if (disposition.name) {
      /* Text field */
      if (partBody.length > limits.maxFieldSize) {
        const err = new Error(`Field "${disposition.name}" exceeds max size`);
        err.status = 413;
        throw err;
      }
      fields[disposition.name] = partBody.toString("utf-8");
    }
  }

  return { fields, files };
}

/**
 * Parse a part's header section ("Name: value" lines) into a map with
 * lower-cased header names.
 */
function parsePartHeaders(headerSection) {
  const headers = Object.create(null);
  for (const line of headerSection.split("\r\n")) {
    const colon = line.indexOf(":");
    if (colon === -1) continue;
    const name = line.slice(0, colon).trim().toLowerCase();
    const value = line.slice(colon + 1).trim();
    headers[name] = value;
  }
  return headers;
}

/**
 * Extract `name` and `filename` parameters from a Content-Disposition
 * header value. Missing parameters stay undefined.
 */
function parseContentDisposition(value) {
  const result = { name: undefined, filename: undefined };
  const nameMatch = value.match(/\bname="([^"]+)"/);
  const filenameMatch = value.match(/\bfilename="([^"]+)"/);
  if (nameMatch) result.name = nameMatch[1];
  if (filenameMatch) result.filename = filenameMatch[1];
  return result;
}

/**
 * Locate `search` inside `buffer` starting at `fromIndex`.
 *
 * Delegates to the native Buffer#indexOf, which is significantly faster
 * than the previous byte-by-byte JS loop for large request bodies.
 *
 * @returns {number} index of the match, or -1
 */
function indexOf(buffer, search, fromIndex) {
  return buffer.indexOf(search, fromIndex);
}

/**
 * Normalize a size option: numbers pass through; strings like "10mb"
 * are converted to bytes. Unparseable input falls back to the default
 * file-size limit.
 */
function parseSize(input) {
  if (typeof input === "number") return input;
  const match = String(input).match(/^(\d+(?:\.\d+)?)\s*(b|kb|mb|gb)?$/i);
  if (!match) return DEFAULT_MAX_FILE_SIZE;
  const num = parseFloat(match[1]);
  switch ((match[2] ?? "b").toLowerCase()) {
    case "gb": return num * 1024 * 1024 * 1024;
    case "mb": return num * 1024 * 1024;
    case "kb": return num * 1024;
    default: return num;
  }
}
buffer.length - search.length; i++) { + let found = true; + for (let j = 0; j < search.length; j++) { + if (buffer[i + j] !== search[j]) { found = false; break; } + } + if (found) return i; + } + return -1; +} + +function parseSize(input) { + if (typeof input === "number") return input; + const match = String(input).match(/^(\d+(?:\.\d+)?)\s*(b|kb|mb|gb)?$/i); + if (!match) return DEFAULT_MAX_FILE_SIZE; + const num = parseFloat(match[1]); + switch ((match[2] ?? "b").toLowerCase()) { + case "gb": return num * 1024 * 1024 * 1024; + case "mb": return num * 1024 * 1024; + case "kb": return num * 1024; + default: return num; + } +} diff --git a/src/openapi.js b/src/openapi.js new file mode 100644 index 0000000..fb05b35 --- /dev/null +++ b/src/openapi.js @@ -0,0 +1,220 @@ +/** + * http-native OpenAPI 3.1 auto-generation (DX-5.1) + * + * Derives an OpenAPI spec from registered routes and validation schemas. + * Serves the JSON spec and optional Swagger UI from static fast-path routes. + * + * Usage: + * import { openapi } from "@http-native/core/openapi"; + * app.use(openapi({ + * info: { title: "My API", version: "1.0.0" }, + * json: "/openapi.json", + * ui: "/docs", + * })); + */ + +/** + * @param {Object} options + * @param {Object} options.info - OpenAPI info object (title, version, description) + * @param {Object[]} [options.servers] - Server URLs + * @param {string} [options.json="/openapi.json"] - Path to serve the raw JSON spec + * @param {string} [options.ui] - Path to serve Swagger UI (optional) + * @param {Object} [options.components] - Extra OpenAPI components to merge + * @param {string[]} [options.tags] - Top-level tag definitions + */ +export function openapi(options = {}) { + const { + info = { title: "API", version: "1.0.0" }, + servers = [], + json: jsonPath = "/openapi.json", + ui: uiPath, + components = {}, + tags = [], + } = options; + + let cachedSpec = null; + + return function openapiMiddleware(req, res, next) { + if (req.method === "GET" && 
req.path === jsonPath) { + if (!cachedSpec) { + cachedSpec = JSON.stringify(generateSpec({ routes: [] }, { info, servers, components, tags })); + } + res.set("content-type", "application/json; charset=utf-8"); + return res.send(cachedSpec); + } + + if (uiPath && req.method === "GET" && req.path === uiPath) { + const html = ` +${info.title ?? "API Docs"} + + +
/**
 * Derive an OpenAPI 3.1 document from compiled route metadata.
 * Usable standalone or via the openapi() middleware.
 *
 * @param {Object} appMeta - Route metadata from the compiled app ({ routes })
 * @param {Object} [options] - Same shape as openapi(): info, servers, components, tags
 * @returns {Object} OpenAPI 3.1 spec object
 */
export function generateSpec(appMeta, options = {}) {
  const {
    info = { title: "API", version: "1.0.0" },
    servers = [],
    components = {},
    tags = [],
  } = options;

  const document = {
    openapi: "3.1.0",
    info,
    /* Empty collections are emitted as undefined so JSON.stringify drops them. */
    servers: servers.length > 0 ? servers : undefined,
    tags: tags.length > 0
      ? tags.map((tag) => (typeof tag === "string" ? { name: tag } : tag))
      : undefined,
    paths: {},
    components: {
      schemas: {},
      ...components,
    },
  };

  for (const route of appMeta.routes ?? []) {
    const pathKey = toOpenApiPath(route.path);
    const pathItem = document.paths[pathKey] ?? (document.paths[pathKey] = {});
    pathItem[route.method.toLowerCase()] = buildOperation(route, pathKey);
  }

  return document;
}

/**
 * Build one OpenAPI operation object for a route.
 * Validation schemas on route.meta feed the requestBody and query params.
 */
function buildOperation(route, pathKey) {
  const verb = route.method.toLowerCase();
  const meta = route.meta;

  const operation = {
    summary: meta?.summary ?? undefined,
    tags: meta?.tags ?? undefined,
    operationId: meta?.operationId ?? `${verb}_${pathKey.replace(/[^a-zA-Z0-9]/g, "_")}`,
    parameters: extractPathParams(route.path),
    responses: {
      200: { description: "Successful response" },
    },
  };

  const bodySchema = meta?.validation?.body;
  if (bodySchema) {
    const jsonSchema = extractJsonSchema(bodySchema);
    if (jsonSchema) {
      operation.requestBody = {
        required: true,
        content: {
          "application/json": { schema: jsonSchema },
        },
      };
    }
  }

  const querySchema = meta?.validation?.query;
  if (querySchema) {
    operation.parameters = [
      ...(operation.parameters ?? []),
      ...extractQueryParams(querySchema),
    ];
  }

  return operation;
}

/** Convert an express-style path to OpenAPI form: /users/:id → /users/{id} */
function toOpenApiPath(routePath) {
  return routePath.replace(/:([^/]+)/g, "{$1}");
}

/** Collect ":param" segments of a route path as OpenAPI path parameters. */
function extractPathParams(routePath) {
  const params = [...routePath.matchAll(/:([^/]+)/g)].map((hit) => ({
    name: hit[1],
    in: "path",
    required: true,
    schema: { type: "string" },
  }));
  return params.length > 0 ? params : undefined;
}

/** Best-effort conversion of a Zod/TypeBox schema to JSON Schema. */
function extractJsonSchema(schema) {
  /* Zod schemas carry a ._def descriptor. */
  if (schema?._def?.typeName) {
    return zodToJsonSchema(schema);
  }
  /* TypeBox schemas are already JSON Schema. */
  if (schema?.type || schema?.properties) {
    return schema;
  }
  return { type: "object" };
}

/** Minimal Zod → JSON Schema conversion covering the common primitives. */
function zodToJsonSchema(schema) {
  const def = schema._def;
  const kind = def.typeName;

  if (kind === "ZodString") return { type: "string" };
  if (kind === "ZodNumber") return { type: "number" };
  if (kind === "ZodBoolean") return { type: "boolean" };
  if (kind === "ZodArray") return { type: "array", items: zodToJsonSchema(def.type) };
  if (kind === "ZodOptional") return zodToJsonSchema(def.innerType);
  if (kind === "ZodEnum") return { type: "string", enum: def.values };

  if (kind === "ZodObject") {
    const properties = {};
    const required = [];
    if (def.shape) {
      const shape = typeof def.shape === "function" ? def.shape() : def.shape;
      for (const [key, member] of Object.entries(shape)) {
        properties[key] = zodToJsonSchema(member);
        if (!member.isOptional?.()) required.push(key);
      }
    }
    return {
      type: "object",
      properties,
      required: required.length > 0 ? required : undefined,
    };
  }

  /* Unknown Zod node — fall back to a permissive object. */
  return { type: "object" };
}

/** Flatten a query validation schema into OpenAPI query parameters. */
function extractQueryParams(schema) {
  const zodShape = schema?._def?.shape;
  const shape = zodShape
    ? typeof zodShape === "function"
      ? zodShape()
      : zodShape
    : schema?.properties ?? {};

  return Object.entries(shape).map(([name, member]) => ({
    name,
    in: "query",
    required: !member?.isOptional?.(),
    schema: extractJsonSchema(member),
  }));
}
/**
 * http-native OpenTelemetry integration middleware (DX-2.2)
 *
 * Emits trace spans (W3C / B3 / Jaeger propagation) and request metrics.
 * Spans are buffered and handed to `exporter` in batches; metrics are
 * flushed to the same exporter on a timer.
 */

import { randomBytes } from "node:crypto";

/**
 * Create the tracing/metrics middleware.
 *
 * NOTE(review): the sampling decision is an independent coin flip per
 * request; the parent's sampled flag from the incoming trace context is
 * extracted but not honored — confirm whether parent-based sampling is
 * intended.
 *
 * @param {Object} options
 * @param {string} [options.serviceName="http-native"] - Resource service name
 * @param {"w3c"|"b3"|"jaeger"} [options.propagation="w3c"] - Context propagation format
 * @param {number} [options.sampleRate=1.0] - Fraction of requests to trace (0.0 - 1.0)
 * @param {(spans: Object[]) => void} [options.exporter] - Span/metrics exporter
 * @param {boolean} [options.metrics=true] - Enable request metrics collection
 * @param {number} [options.metricsInterval=60000] - Metrics flush interval in ms
 * @returns {Function} async middleware with `.flushSpans()` / `.pendingSpans()` attached
 */
export function otel(options = {}) {
  const {
    serviceName = "http-native",
    propagation = "w3c",
    sampleRate = 1.0,
    exporter,
    metrics: enableMetrics = true,
    metricsInterval = 60000,
  } = options;

  /* Metrics accumulators */
  const counters = { requests: 0, errors: 0, rateLimitRejections: 0 };
  const histograms = { duration: [] };
  const statusCounts = Object.create(null);
  const methodCounts = Object.create(null);

  /* Span buffer for batch export */
  const spanBuffer = [];
  const MAX_SPAN_BUFFER = 512;

  /* Periodic metrics flush — only when there is somewhere to send them. */
  let metricsTimer;
  if (enableMetrics && exporter) {
    metricsTimer = setInterval(() => {
      flushMetrics(exporter, serviceName, counters, histograms, statusCounts, methodCounts);
    }, metricsInterval);
    /* Don't keep the process alive just for metrics. */
    if (metricsTimer.unref) metricsTimer.unref();
  }

  const middleware = async function otelMiddleware(req, res, next) {
    /* Extract or generate trace context */
    const parentCtx = extractTraceContext(req, propagation);
    const sampled = Math.random() < sampleRate;

    const traceId = parentCtx.traceId || generateTraceId();
    const spanId = generateSpanId();
    const parentSpanId = parentCtx.spanId || null;

    /* Attach trace context to the request for downstream use. */
    req.traceId = traceId;
    req.spanId = spanId;

    const start = performance.now();
    let error = null;

    try {
      await next();
    } catch (err) {
      error = err;
      throw err;
    } finally {
      const duration = performance.now() - start;
      const status = res._state?.status ?? (error ? 500 : 200);

      /* Inject trace context into response headers */
      injectTraceContext(res, propagation, traceId, spanId, sampled);

      /* Build span */
      if (sampled) {
        const span = {
          traceId,
          spanId,
          parentSpanId,
          operationName: "http.request",
          serviceName,
          startTime: Date.now() - duration,
          duration: Math.round(duration * 1000) / 1000,
          tags: {
            "http.method": req.method,
            "http.url": req.path,
            "http.status_code": status,
            "http.route": req._matchedRoute || req.path,
          },
          status: error ? "ERROR" : "OK",
        };

        if (error) {
          span.tags["error.message"] = error.message;
          span.tags["error.type"] = error.constructor.name;
        }

        if (req.id) span.tags["http.request_id"] = req.id;

        spanBuffer.push(span);
        if (spanBuffer.length >= MAX_SPAN_BUFFER) {
          if (exporter) {
            exporter(spanBuffer.splice(0, spanBuffer.length));
          } else {
            /* No exporter configured: previously the buffer grew without
             * bound (the flush was gated on `exporter`), leaking memory
             * under sustained traffic. Drop the oldest spans to cap it. */
            spanBuffer.splice(0, spanBuffer.length - MAX_SPAN_BUFFER + 1);
          }
        }
      }

      /* Metrics */
      if (enableMetrics) {
        counters.requests++;
        if (status >= 500) counters.errors++;
        statusCounts[status] = (statusCounts[status] || 0) + 1;
        methodCounts[req.method] = (methodCounts[req.method] || 0) + 1;
        histograms.duration.push(duration);
        /* Keep histogram bounded */
        if (histograms.duration.length > 10000) histograms.duration.splice(0, 5000);
      }
    }
  };

  /** Flush pending spans — exposed for graceful shutdown. */
  middleware.flushSpans = () => {
    if (spanBuffer.length > 0 && exporter) {
      exporter(spanBuffer.splice(0, spanBuffer.length));
    }
  };

  /** Access the current span buffer length. */
  middleware.pendingSpans = () => spanBuffer.length;

  return middleware;
}

/**
 * Flush pending spans to the exporter.
 * Call this on graceful shutdown to ensure no spans are lost.
 *
 * @param {Function} middleware - The otel middleware returned by otel()
 */
export function flushSpans(middleware) {
  if (typeof middleware?.flushSpans === "function") {
    middleware.flushSpans();
  }
}

/* ─── Trace Context Propagation ─── */

/**
 * Pull an upstream trace context from the request headers in the
 * configured propagation format. Returns nulls when absent/unparseable.
 */
function extractTraceContext(req, format) {
  const get = (name) => req.header?.(name) ?? req.headers?.[name] ?? "";

  if (format === "w3c") {
    /* W3C Trace Context: traceparent header */
    const tp = get("traceparent");
    const match = tp.match(/^00-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$/);
    if (match) {
      return { traceId: match[1], spanId: match[2], sampled: match[3] === "01" };
    }
  } else if (format === "b3") {
    /* B3 single-header or multi-header */
    const single = get("b3");
    if (single) {
      const parts = single.split("-");
      return { traceId: parts[0], spanId: parts[1], sampled: parts[2] !== "0" };
    }
    const traceId = get("x-b3-traceid");
    const spanId = get("x-b3-spanid");
    if (traceId && spanId) {
      return { traceId, spanId, sampled: get("x-b3-sampled") !== "0" };
    }
  } else if (format === "jaeger") {
    /* Jaeger uber-trace-id: {trace-id}:{span-id}:{parent-span-id}:{flags} */
    const uber = get("uber-trace-id");
    const parts = uber.split(":");
    if (parts.length >= 4) {
      return { traceId: parts[0], spanId: parts[1], sampled: parts[3] !== "0" };
    }
  }

  return { traceId: null, spanId: null, sampled: true };
}

/** Write the outgoing trace context header in the configured format. */
function injectTraceContext(res, format, traceId, spanId, sampled) {
  if (format === "w3c") {
    res.set("traceparent", `00-${traceId}-${spanId}-${sampled ? "01" : "00"}`);
  } else if (format === "b3") {
    res.set("b3", `${traceId}-${spanId}-${sampled ? "1" : "0"}`);
  } else if (format === "jaeger") {
    res.set("uber-trace-id", `${traceId}:${spanId}:0:${sampled ? "1" : "0"}`);
  }
}

/* ─── ID Generation ─── */

/** 128-bit trace id as 32 lowercase hex chars. */
function generateTraceId() {
  return randomBytes(16).toString("hex");
}

/** 64-bit span id as 16 lowercase hex chars. */
function generateSpanId() {
  return randomBytes(8).toString("hex");
}

/* ─── Metrics Flush ─── */

/**
 * Summarize and export the accumulated counters/histograms as a single
 * synthetic span, then reset the accumulators. No-op when nothing was
 * recorded since the last flush.
 */
function flushMetrics(exporter, serviceName, counters, histograms, statusCounts, methodCounts) {
  if (counters.requests === 0) return;

  const durations = histograms.duration;
  const sorted = durations.slice().sort((a, b) => a - b);
  const p50 = sorted[Math.floor(sorted.length * 0.5)] ?? 0;
  const p95 = sorted[Math.floor(sorted.length * 0.95)] ?? 0;
  const p99 = sorted[Math.floor(sorted.length * 0.99)] ?? 0;

  const metricsSpan = {
    traceId: generateTraceId(),
    spanId: generateSpanId(),
    operationName: "metrics.flush",
    serviceName,
    startTime: Date.now(),
    duration: 0,
    tags: {
      "metric.type": "summary",
      "http.request.count": counters.requests,
      "http.error.count": counters.errors,
      "http.request.duration.p50": Math.round(p50 * 1000) / 1000,
      "http.request.duration.p95": Math.round(p95 * 1000) / 1000,
      "http.request.duration.p99": Math.round(p99 * 1000) / 1000,
      "http.status_counts": { ...statusCounts },
      "http.method_counts": { ...methodCounts },
    },
    status: "OK",
  };

  exporter([metricsSpan]);

  /* Reset accumulators */
  counters.requests = 0;
  counters.errors = 0;
  histograms.duration.length = 0;
  for (const k in statusCounts) delete statusCounts[k];
  for (const k in methodCounts) delete methodCounts[k];
}
+ * Users behind a reverse proxy should provide a custom `key` function + * that extracts the real IP from trusted proxy headers. */ if (typeof req?.ip === "string" && req.ip.trim() !== "") { return req.ip.trim(); } - const forwarded = parseForwardedFor(getHeader(req, "x-forwarded-for")); - if (forwarded) { - return forwarded; - } - - const realIp = getHeader(req, "x-real-ip"); - if (realIp) { - return realIp.trim(); - } - return "unknown"; } diff --git a/src/request-id.js b/src/request-id.js new file mode 100644 index 0000000..93e7bc6 --- /dev/null +++ b/src/request-id.js @@ -0,0 +1,75 @@ +/** + * http-native request ID middleware. + * + * Generates or propagates a unique request identifier on every request. + * The ID is attached to `req.id` and echoed back in a configurable + * response header for end-to-end correlation in distributed systems. + * + * Usage: + * import { requestId } from "@http-native/core/request-id"; + * + * // Defaults: reads/writes "x-request-id", generates crypto.randomUUID() + * app.use(requestId()); + * + * // Custom header + generator + * app.use(requestId({ + * header: "x-correlation-id", + * generate: () => `req-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, + * })); + * + * // Disable response header echo + * app.use(requestId({ responseHeader: false })); + */ + +import { randomUUID } from "node:crypto"; + +// ─── Default Generator ───────────────── + +/** + * Default ID generator using crypto.randomUUID() — RFC 4122 v4 UUID. + * Available in Bun, Node 19+, and modern browsers. + * + * @returns {string} UUID v4 string + */ +function defaultGenerate() { + return randomUUID(); +} + +// ─── Middleware Factory ──────────────── + +/** + * Create a request ID middleware. 
+ * + * @param {Object} [options] + * @param {string} [options.header="x-request-id"] - Incoming header to read + * @param {string|false} [options.responseHeader] - Response header to set (defaults to same as header; false to disable) + * @param {() => string} [options.generate] - Custom ID generator function + * @returns {Function} Middleware function + */ +export function requestId(options = {}) { + if (typeof options !== "object" || options === null) { + throw new TypeError("requestId(options) expects an object"); + } + + const headerName = String(options.header ?? "x-request-id").toLowerCase(); + const responseHeader = options.responseHeader === false + ? null + : String(options.responseHeader ?? headerName).toLowerCase(); + const generate = options.generate ?? defaultGenerate; + + if (typeof generate !== "function") { + throw new TypeError("requestId generate must be a function"); + } + + return async function requestIdMiddleware(req, res, next) { + /* Propagate existing ID from upstream proxy, or generate a new one */ + const id = req.header(headerName) || generate(); + req.id = id; + + if (responseHeader) { + res.set(responseHeader, id); + } + + await next(); + }; +} diff --git a/src/session.d.ts b/src/session.d.ts new file mode 100644 index 0000000..6640fb3 --- /dev/null +++ b/src/session.d.ts @@ -0,0 +1,116 @@ +import type { Middleware, Session } from "./index.js"; + +// ─── Session Store Interface ─────────── + +/** + * Pluggable session store contract. Implementations must provide + * get/set/delete/destroy/getAll operations. Operations may be sync + * (MemoryStore) or async (RedisStore, custom stores). 
+ */ +export interface SessionStore { + /** Retrieve a single session value by key */ + get(sessionId: string, key: string): unknown | Promise<unknown>; + + /** Set a single session value */ + set(sessionId: string, key: string, value: unknown): void | Promise<void>; + + /** Delete a single session key */ + delete(sessionId: string, key: string): void | Promise<void>; + + /** Destroy the entire session (all keys) */ + destroy(sessionId: string): void | Promise<void>; + + /** Retrieve all session data as a key-value record */ + getAll(sessionId: string): Record<string, unknown> | null | Promise<Record<string, unknown> | null>; + + /** Replace all session data (optional — used for bulk session restoration) */ + setAll?(sessionId: string, data: Record<string, unknown>): void | Promise<void>; +} + +// ─── Built-in Stores ─────────────────── + +/** + * In-memory session store backed by Rust's native sharded RwLock. + * All operations are synchronous (direct NAPI calls into the Rust layer). + */ +export class MemoryStore implements SessionStore { + constructor(); + get(sessionId: string, key: string): unknown; + set(sessionId: string, key: string, value: unknown): void; + delete(sessionId: string, key: string): void; + destroy(sessionId: string): void; + getAll(sessionId: string): Record<string, unknown> | null; + setAll(sessionId: string, data: Record<string, unknown>): void; +} + +export interface RedisStoreOptions { + /** Key prefix for session hashes (default "sess:") */ + prefix?: string; + /** TTL in seconds (default: from session config maxAge) */ + maxAge?: number; +} + +/** + * Redis session store. Requires an ioredis-compatible client. + * All operations are async (Redis round-trips). 
+ */ +export class RedisStore implements SessionStore { + constructor(client: unknown, options?: RedisStoreOptions); + get(sessionId: string, key: string): Promise<unknown>; + set(sessionId: string, key: string, value: unknown): Promise<void>; + delete(sessionId: string, key: string): Promise<void>; + destroy(sessionId: string): Promise<void>; + getAll(sessionId: string): Promise<Record<string, unknown> | null>; + setAll(sessionId: string, data: Record<string, unknown>): Promise<void>; +} + +// ─── Session Middleware ───────────────── + +export interface SessionMiddlewareOptions { + /** HMAC signing secret (required) */ + secret: string; + + /** Session TTL in seconds (default 3600) */ + maxAge?: number; + + /** Cookie name (default "sid") */ + cookieName?: string; + + /** HttpOnly flag (default true) */ + httpOnly?: boolean; + + /** Secure flag (default false) */ + secure?: boolean; + + /** SameSite policy (default "lax") */ + sameSite?: "strict" | "lax" | "none"; + + /** Cookie path (default "/") */ + path?: string; + + /** Pluggable session store (default: MemoryStore — Rust-backed) */ + store?: SessionStore; + + /** Maximum sessions per shard before LRU eviction (default 100_000) */ + maxSessions?: number; + + /** Maximum serialized data size per session in bytes (default 4096) */ + maxDataSize?: number; +} + +/** + * Create a session middleware. + * + * Default store: Rust in-memory (sharded RwLock, cross-worker safe). + * Pluggable: pass any store with get/set/delete/destroy/getAll methods. 
+ * + * @example + * // In-memory (default, Rust-backed) + * app.use(session({ secret: "my-key" })); + * + * @example + * // Redis + * import Redis from "ioredis"; + * app.use(session({ secret: "my-key", store: new RedisStore(new Redis()) })); + */ +export function session(options: SessionMiddlewareOptions): Middleware; diff --git a/src/session.js b/src/session.js index db4f8c3..025ac73 100644 --- a/src/session.js +++ b/src/session.js @@ -357,13 +357,3 @@ function capitalize(s) { return s.charAt(0).toUpperCase() + s.slice(1).toLowerCase(); } -// ─── Session Trailer (for backward compat) ── - -/** - * Encode session write trailer. Returns null since session ops now go - * directly through NAPI — no trailer needed for MemoryStore. - * Kept for API compatibility. - */ -export function encodeSessionTrailer(sessionState) { - return null; -} diff --git a/src/test.js b/src/test.js new file mode 100644 index 0000000..3d3b680 --- /dev/null +++ b/src/test.js @@ -0,0 +1,140 @@ +/** + * http-native test utilities + * + * Provides a lightweight test client that spins up the app on an ephemeral + * port and exposes fetch-style methods for integration testing. + * + * Usage: + * import { testClient } from "@http-native/core/test"; + * + * const client = await testClient(app); + * const res = await client.get("/users/1"); + * expect(res.status).toBe(200); + * await client.close(); + */ + +/** + * Create a test client bound to an ephemeral port. + * + * @param {import("./index.js").Application} app + * @param {{ port?: number, host?: string }} [options] + * @returns {Promise} + */ +export async function testClient(app, options = {}) { + const port = options.port ?? 0; + const host = options.host ?? 
"127.0.0.1"; + + /* Start the server on an ephemeral port — the OS assigns a free one */ + let serverHandle; + const baseUrl = await new Promise((resolve, reject) => { + try { + const builder = app.listen(port); + if (typeof builder.host === "function") builder.host(host); + builder.start((handle) => { + serverHandle = handle; + const addr = handle.address?.() ?? { port: handle.port ?? port }; + resolve(`http://${host}:${addr.port ?? port}`); + }); + } catch (err) { + reject(err); + } + }); + + /** + * @param {string} path + * @param {RequestInit & { json?: unknown }} [init] + */ + async function request(path, init = {}) { + const url = `${baseUrl}${path}`; + const headers = { ...init.headers }; + + /* Convenience: if `json` is set, auto-serialize and set content-type */ + let body = init.body; + if (init.json !== undefined) { + body = JSON.stringify(init.json); + headers["content-type"] = headers["content-type"] ?? "application/json"; + } + + const res = await fetch(url, { ...init, headers, body }); + + /* Attach helper methods for easy assertion */ + const wrapped = { + status: res.status, + headers: Object.fromEntries(res.headers.entries()), + ok: res.ok, + /** Parse body as JSON */ + json: () => res.json(), + /** Read body as text */ + text: () => res.text(), + /** Access raw Response object */ + raw: res, + }; + return wrapped; + } + + return { + /** Base URL the server is listening on (e.g. 
"http://127.0.0.1:54321") */ + baseUrl, + + /** Raw request — any method */ + request, + + /** GET helper */ + get: (path, init) => request(path, { ...init, method: "GET" }), + + /** POST helper */ + post: (path, init) => request(path, { ...init, method: "POST" }), + + /** PUT helper */ + put: (path, init) => request(path, { ...init, method: "PUT" }), + + /** PATCH helper */ + patch: (path, init) => request(path, { ...init, method: "PATCH" }), + + /** DELETE helper */ + delete: (path, init) => request(path, { ...init, method: "DELETE" }), + + /** Open a WebSocket connection */ + ws: (path) => { + const wsUrl = baseUrl.replace(/^http/, "ws") + path; + const ws = new WebSocket(wsUrl); + return new Promise((resolve) => { + const messages = []; + let nextResolve = null; + + ws.onmessage = (event) => { + if (nextResolve) { + nextResolve(event.data); + nextResolve = null; + } else { + messages.push(event.data); + } + }; + + ws.onopen = () => { + resolve({ + send: (data) => ws.send(data), + /** Await the next message from the server */ + next: () => + messages.length > 0 + ? Promise.resolve(messages.shift()) + : new Promise((r) => { nextResolve = r; }), + close: () => ws.close(), + raw: ws, + }); + }; + }); + }, + + /** Shut down the test server */ + close: async () => { + if (serverHandle) { + if (typeof serverHandle.shutdown === "function") { + await serverHandle.shutdown({ timeout: 2000 }); + } else if (typeof serverHandle.close === "function") { + serverHandle.close(); + } + } + }, + }; +}