diff --git a/.bleep b/.bleep index bcde257d..a9416ff6 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -ed8657309187516d2e673037821a9fbd8405d703 \ No newline at end of file +d2c25d726c5738e6a8028dc3e7642ecfe6c1824e diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..3c1f3636 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[resolver] +incompatible-rust-versions = "fallback" \ No newline at end of file diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index cd10b8c0..6fe67dea 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -24,7 +24,7 @@ jobs: - name: Generate Cargo.lock # https://github.com/rustsec/audit-check/issues/27 - run: cargo generate-lockfile + run: cargo generate-lockfile --ignore-rust-version - name: Audit Check # https://github.com/rustsec/audit-check/issues/2 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 43c8aa9d..22a4c458 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,7 +8,7 @@ jobs: fail-fast: false matrix: # nightly, msrv, and latest stable - toolchain: [nightly, 1.83.0, 1.87.0] + toolchain: [nightly, 1.84.0, 1.91.1] runs-on: ubuntu-latest # Only run on "pull_request" event for external PRs. This is to avoid # duplicate builds for PRs created from internal branches. 
@@ -48,12 +48,12 @@ jobs: - name: Run cargo clippy run: | - [[ ${{ matrix.toolchain }} != 1.87.0 ]] || cargo clippy --all-targets --all -- --allow=unknown-lints --deny=warnings + [[ ${{ matrix.toolchain }} != 1.91.1 ]] || cargo clippy --all-targets --all -- --allow=unknown-lints --deny=warnings - name: Run cargo audit run: | - [[ ${{ matrix.toolchain }} != 1.87.0 ]] || (cargo install --locked cargo-audit && cargo audit) + [[ ${{ matrix.toolchain }} != 1.91.1 ]] || (cargo install --locked cargo-audit && cargo generate-lockfile --ignore-rust-version && cargo audit) - name: Run cargo machete run: | - [[ ${{ matrix.toolchain }} != 1.87.0 ]] || (cargo install cargo-machete --version 0.7.0 && cargo machete) + [[ ${{ matrix.toolchain }} != 1.91.1 ]] || (cargo install cargo-machete --version 0.7.0 && cargo machete) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7c25c7f..e8cb8dbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,125 @@ All notable changes to this project will be documented in this file. +## [0.8.0](https://github.com/cloudflare/pingora/compare/0.7.0...0.8.0) - 2026-03-02 + + +**πŸš€ Features** + +* Add support for client certificate verification in mTLS configuration. +* Add upstream\_write\_pending\_time to Session for upload diagnostics. +* Pipe subrequests utility: creates a state machine to treat subrequests as a "pipe," enabling direct sending of request body and writing of response tasks, with a handler for error propagation and support for reusing a preset or captured input body for chained subrequests. +* Add the ability to limit the number of times a downstream connection can be reused +* Add a system for specifying and using service-level dependencies +* Add a builder for pingora proxy service, e.g. to specify ServerOptions. + +**πŸ› Bug Fixes** + +* Fix various Windows compiler issues. +* Handle custom ALPNs in s2n impl of ALPN::to\_wire\_protocols() to fix s2n compile issues. +* Fix: don't use β€œall” permissions for socket. 
+* Fix a bug with the ketama load balancing where configurations were not persisted after updates. +* Ensure http1 downstream session is not reused on more body bytes than expected. +* Send RST\_STREAM CANCEL on application read timeouts for h2 client. +* Start close-delimited body mode after 101 is received for WebSocket upgrades. `UpgradedBody` is now an explicit HttpTask. +* Avoid close delimit mode on http/1.0 req. +* Reject invalid content-length http/1 requests to eliminate ambiguous request framing. +* Validate invalid content-length on http/1 resp by default, and removes content-length from the response if transfer-encoding is present, per RFC. +* Correct the custom protocol code for shutdown: changed the numeric code passed on shutdown to 0 to indicate an explicit shutdown rather than a transport error. + +**βš™οΈ Miscellaneous Tasks** + +* Remove `CacheKey::default` impl, users of caching should implement `cache_key_callback` themselves +* Allow server bootstrapping to take place in the context of services with dependents and dependencies +* Don't consider "bytes=" a valid range header: added an early check for an empty/whitespace-only range-set after the `bytes=` prefix, returning 416 Range Not Satisfiable, consistent with RFC 9110 14.1.2. +* Strip {content, transfer}-encoding from 416s to mirror the behavior for 304 Not Modified responses. +* Disable CONNECT method proxying by default, with an option to enable via server options; unsupported requests will now be automatically rejected. 
+ +## [0.7.0](https://github.com/cloudflare/pingora/compare/0.6.0...0.7.0) - 2026-01-30 + +### Highlights + +- Extensible SslDigest to save user-defined TLS context +- Add ConnectionFilter trait for early TCP connection filtering + +### πŸš€ Features + +- Add ConnectionFilter trait for early TCP connection filtering +- Introduce a virtual L4 stream abstraction +- Add support for verify_cert and verify_hostname using rustls +- Exposes the HttpProxy struct to allow external crates to customize the proxy logic. +- Exposes a new_mtls method for creating a HttpProxy with a client_cert_key to enable mtls peers. +- Add SSLKEYLOGFILE support to rustls connector +- Allow spawning background subrequests from main session +- Allow Extensions in cache LockCore and user tracing +- Add body-bytes tracking across H1/H2 and proxy metrics +- Allow setting max_weight on MissFinishType::Appended +- Allow adding SslDigestExtensions on downstream and upstream +- Add Custom session support for encapsulated HTTP + +### πŸ› Bug Fixes + +- Use write timeout consistently for h2 body writes +- Prevent downstream error prior to header from canceling cache fill +- Fix debug log and new tests +- Fix size calculation for buffer capacity +- Fix cache admission on header only misses +- Fix duplicate zero-size chunk on cache hit +- Fix chunked trailer end parsing +- Lock age timeouts cause lock reacquisition +- Fix transfer fd compile error for non linux os + +### Sec + +- Removed atty +- Upgrade lru to >= 0.16.3 crate version because of RUSTSEC-2026-0002 + +### Everything Else + +- Add tracing to log reason for not caching an asset on cache put +- Evict when asset count exceeds optional watermark +- Remove trailing comma from Display for HttpPeer +- Make ProxyHTTP::upstream_response_body_filter return an optional duration for rate limiting +- Restore daemonize STDOUT/STDERR when error log file is not specified +- Log task info when upstream header failed to send +- Check cache enablement to 
determine cache fill +- Update meta when revalidating before lock release +- Add ForceFresh status to cache hit filter +- Pass stale status to cache lock +- Bump max multipart ranges to 200 +- Downgrade Expires header warn to debug log +- CI and effective msrv bump to 1.83 +- Add default noop custom param to client Session +- Use static str in ErrorSource or ErrorType as_str +- Use bstr for formatting byte strings +- Tweak the implementation of and documentation of `connection_filter` feature +- Set h1.1 when proxying cacheable responses +- Add or remove accept-ranges on range header filter +- Update msrv in github ci, fixup .bleep +- Override request keepalive on process shutdown +- Add shutdown flag to proxy session +- Add ResponseHeader in pingora_http crate's prelude +- Add a configurable upgrade for pingora-ketama that reduces runtime cpu and memory +- Add to cache api spans +- Increase visibility of multirange items +- Use seek_multipart on body readers +- Log read error when reading trailers end +- Re-add the warning about cache-api volatility +- Default to close on downstream response before body finish +- Ensure idle_timeout is polled even if idle_timeout is unset so notify events are registered for h2 idle pool, filter out closed connections when retrieving from h2 in use pool. 
+- Add simple read test for invalid extra char in header end +- Allow customizing lock status on Custom NoCacheReasons +- Close h1 conn by default if req header unfinished +- Add configurable retries for upgrade sock connect/accept +- Deflake test by increasing write size +- Make the version restrictions on rmp and rmp-serde more strict to prevent forcing consumers to use 2024 edition +- Rewind preread bytes when parsing next H1 response +- Add epoch and epoch_override to CacheMeta + ## [0.6.0](https://github.com/cloudflare/pingora/compare/0.5.0...0.6.0) - 2025-08-15 - + ### Highlights -- This release bumps the minimum h2 crate dependency to guard against the [MadeYouReset]((https://blog.cloudflare.com/madeyoureset-an-http-2-vulnerability-thwarted-by-rapid-reset-mitigations/)) H2 attack +- This release bumps the minimum h2 crate dependency to guard against the [MadeYouReset](https://blog.cloudflare.com/madeyoureset-an-http-2-vulnerability-thwarted-by-rapid-reset-mitigations/) H2 attack ### πŸš€ Features @@ -63,7 +178,7 @@ All notable changes to this project will be documented in this file. ## [0.5.0](https://github.com/cloudflare/pingora/compare/0.4.0...0.5.0) - 2025-05-09 - + ### πŸš€ Features - [Add tweak_new_upstream_tcp_connection hook to invoke logic on new upstream TCP sockets prior to connection](https://github.com/cloudflare/pingora/commit/be4a023d18c2b061f64ad5efd0868f9498199c91)
- [Add get_stale and get_stale_while_update for memory-cache](https://github.com/cloudflare/pingora/commit/bb28044cbe9ac9251940b8a313d970c7d15aaff6) ### πŸ› Bug Fixes - + - [Fix deadloop if proxy_handle_upstream exits earlier than proxy_handle_downstream](https://github.com/cloudflare/pingora/commit/bb111aaa92b3753e650957df3a68f56b0cffc65d) - [Check on h2 stream end if error occurred for forwarding HTTP tasks](https://github.com/cloudflare/pingora/commit/e18f41bb6ddb1d6354e824df3b91d77f3255bea2) - [Check for content-length underflow on end of stream h2 header](https://github.com/cloudflare/pingora/commit/575d1aafd7c679a50a443701a4c55dcfdbc443b2) @@ -91,9 +206,9 @@ All notable changes to this project will be documented in this file. - [Always drain v1 request body before session reuse](https://github.com/cloudflare/pingora/commit/fda3317ec822678564d641e7cf1c9b77ee3759ff) - [Fixes HTTP1 client reads to properly timeout on initial read](https://github.com/cloudflare/pingora/commit/3c7db34acb0d930ae7043290a88bc56c1cd77e45) - [Fixes issue where if TLS client never sends any bytes, hangs forever](https://github.com/cloudflare/pingora/commit/d1bf0bcac98f943fd716278d674e7d10dce2223e) - + ### Everything Else - + - [Add builder api for pingora listeners](https://github.com/cloudflare/pingora/commit/3f564af3ae56e898478e13e71d67d095d7f5dbbd) - [Better handling for h1 requests that contain both transfer-encoding and content-length](https://github.com/cloudflare/pingora/commit/9287b82645be4a52b0b63530ba38aa0c7ddc4b77) - [Allow setting raw path in request to support non-UTF8 use cases](https://github.com/cloudflare/pingora/commit/e6b823c5d89860bb97713fdf14f197f799aed6af) @@ -209,7 +324,7 @@ All notable changes to this project will be documented in this file. 
## [0.1.1](https://github.com/cloudflare/pingora/compare/0.1.0...0.1.1) - 2024-04-05 ### πŸš€ Features -- `Server::new` now accepts `Into<Option<ServerConf>>` +- `Server::new` now accepts `Into<Option<ServerConf>>` - Implemented client `HttpSession::get_keepalive_values` for Keep-Alive parsing - Expose `ListenFds` and `Fds` to fix a voldemort types issue - Expose config options in `ServerConf`, provide new `Server` constructor diff --git a/Cargo.toml b/Cargo.toml index ce057972..d3c8603b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,16 +29,18 @@ members = [ ] [workspace.dependencies] +bstr = "1.12.0" tokio = "1" +tokio-stream = { version = "0.1" } async-trait = "0.1.42" httparse = "1" bytes = "1.0" derivative = "2.2.0" -http = "1.0.0" +http = "1" log = "0.4" h2 = ">=0.4.11" once_cell = "1" -lru = "0.14" +lru = "0.16.3" ahash = ">=0.8.9" [profile.bench] diff --git a/README.md b/README.md index 1cc716dc..94fd1a59 100644 --- a/README.md +++ b/README.md @@ -59,11 +59,11 @@ Both x86_64 and aarch64 architectures will be supported. ## Rust version -Pingora keeps a rolling MSRV (minimum supported Rust version) policy of 6 months. This means we will accept PRs that upgrade the MSRV as long as the new Rust version used is at least 6 months old. +Pingora keeps a rolling MSRV (minimum supported Rust version) policy of 6 months. This means we will accept PRs that upgrade the MSRV as long as the new Rust version used is at least 6 months old. However, we generally will not bump the highest MSRV across the workspace without a sufficiently compelling reason. -Our current MSRV is effectively 1.83. +Our current MSRV is 1.84. -Previously Pingora advertised an MSRV of 1.72. Older Rust versions may still be able to compile via `cargo update` pinning dependencies such as `backtrace@0.3.74`. The advertised MSRV in config files will be officially bumped to 1.83 in an upcoming release. +Currently not all crates enforce `rust-version` as it is possible to use some crates on lower versions.
## Build Requirements diff --git a/clippy.toml b/clippy.toml index ebba0354..83a5e087 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1 +1 @@ -msrv = "1.72" +msrv = "1.84" diff --git a/docs/user_guide/rate_limiter.md b/docs/user_guide/rate_limiter.md index fe337a19..31a6b5a9 100644 --- a/docs/user_guide/rate_limiter.md +++ b/docs/user_guide/rate_limiter.md @@ -20,7 +20,6 @@ Pingora provides a crate `pingora-limits` which provides a simple and easy to us ```rust use async_trait::async_trait; use once_cell::sync::Lazy; -use pingora::http::ResponseHeader; use pingora::prelude::*; use pingora_limits::rate::Rate; use std::sync::Arc; @@ -135,11 +134,11 @@ impl ProxyHttp for LB { ``` ## Testing -To use the example above, +To use the example above, -1. Run your program with `cargo run`. +1. Run your program with `cargo run`. 2. Verify the program is working with a few executions of ` curl localhost:6188 -H "appid:1" -v` - - The first request should work and any later requests that arrive within 1s of a previous request should fail with: + - The first request should work and any later requests that arrive within 1s of a previous request should fail with: ``` * Trying 127.0.0.1:6188... 
* Connected to localhost (127.0.0.1) port 6188 (#0) @@ -148,20 +147,20 @@ To use the example above, > User-Agent: curl/7.88.1 > Accept: */* > appid:1 - > + > < HTTP/1.1 429 Too Many Requests < X-Rate-Limit-Limit: 1 < X-Rate-Limit-Remaining: 0 < X-Rate-Limit-Reset: 1 < Date: Sun, 14 Jul 2024 20:29:02 GMT < Connection: close - < + < * Closing connection 0 ``` ## Complete Example -You can run the pre-made example code in the [`pingora-proxy` examples folder](https://github.com/cloudflare/pingora/tree/main/pingora-proxy/examples/rate_limiter.rs) with +You can run the pre-made example code in the [`pingora-proxy` examples folder](https://github.com/cloudflare/pingora/tree/main/pingora-proxy/examples/rate_limiter.rs) with ``` cargo run --example rate_limiter -``` \ No newline at end of file +``` diff --git a/pingora-boringssl/Cargo.toml b/pingora-boringssl/Cargo.toml index ec0b7bc0..03086460 100644 --- a/pingora-boringssl/Cargo.toml +++ b/pingora-boringssl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-boringssl" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" diff --git a/pingora-boringssl/src/boring_tokio.rs b/pingora-boringssl/src/boring_tokio.rs index 4dd2f91e..ef5d60c2 100644 --- a/pingora-boringssl/src/boring_tokio.rs +++ b/pingora-boringssl/src/boring_tokio.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -263,9 +263,7 @@ where return Poll::Pending; } Err(e) => { - return Poll::Ready(Err(e - .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))); + return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other))); } } diff --git a/pingora-boringssl/src/ext.rs b/pingora-boringssl/src/ext.rs index 0af2bb0b..256e4ac5 100644 --- a/pingora-boringssl/src/ext.rs +++ b/pingora-boringssl/src/ext.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-boringssl/src/lib.rs b/pingora-boringssl/src/lib.rs index dd560a84..9701c598 100644 --- a/pingora-boringssl/src/lib.rs +++ b/pingora-boringssl/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-cache/Cargo.toml b/pingora-cache/Cargo.toml index cd51b638..401d827c 100644 --- a/pingora-cache/Cargo.toml +++ b/pingora-cache/Cargo.toml @@ -1,9 +1,10 @@ [package] name = "pingora-cache" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" +rust-version = "1.84" repository = "https://github.com/cloudflare/pingora" categories = ["asynchronous", "network-programming"] keywords = ["async", "http", "cache"] @@ -17,19 +18,20 @@ name = "pingora_cache" path = "src/lib.rs" [dependencies] -pingora-core = { version = "0.6.0", path = "../pingora-core", default-features = false } -pingora-error = { version = "0.6.0", path = "../pingora-error" } -pingora-header-serde = { version = "0.6.0", path = "../pingora-header-serde" } -pingora-http = { version = "0.6.0", path = "../pingora-http" } -pingora-lru = { version = "0.6.0", path = "../pingora-lru" } -pingora-timeout = { version = "0.6.0", path = "../pingora-timeout" } +pingora-core = { version = "0.8.0", path = "../pingora-core", default-features = false } +pingora-error = { version = "0.8.0", path = "../pingora-error" } +pingora-header-serde = { version = "0.8.0", path = "../pingora-header-serde" } +pingora-http = { version = "0.8.0", path = "../pingora-http" } +pingora-lru = { version = "0.8.0", path = "../pingora-lru" } +pingora-timeout = { version = "0.8.0", path = "../pingora-timeout" } +bstr = { workspace = true } http = { workspace = true } indexmap = "1" once_cell = { workspace = true } regex = "1" blake2 = "0.10" serde = { version = "1.0", features = ["derive"] } -rmp-serde = "1" +rmp-serde = "1.3.0" bytes = { workspace = true } httpdate = "1.0.2" log = { workspace = true } @@ -37,7 +39,7 @@ async-trait = { workspace = true } parking_lot = "0.12" cf-rustracing = "1.0" cf-rustracing-jaeger = "1.0" -rmp = "0.8" +rmp = "0.8.14" tokio = { workspace = true } lru = { workspace = true } ahash = { workspace = true } @@ -49,7 +51,7 @@ rand = "0.8" 
[dev-dependencies] tokio-test = "0.4" tokio = { workspace = true, features = ["fs"] } -env_logger = "0.9" +env_logger = "0.11" dhat = "0" futures = "0.3" diff --git a/pingora-cache/benches/lru_memory.rs b/pingora-cache/benches/lru_memory.rs index 1d0678dc..67428671 100644 --- a/pingora-cache/benches/lru_memory.rs +++ b/pingora-cache/benches/lru_memory.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-cache/benches/lru_serde.rs b/pingora-cache/benches/lru_serde.rs index 5c0809e4..237a827e 100644 --- a/pingora-cache/benches/lru_serde.rs +++ b/pingora-cache/benches/lru_serde.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-cache/benches/simple_lru_memory.rs b/pingora-cache/benches/simple_lru_memory.rs index 30500c72..fa1199e3 100644 --- a/pingora-cache/benches/simple_lru_memory.rs +++ b/pingora-cache/benches/simple_lru_memory.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-cache/src/cache_control.rs b/pingora-cache/src/cache_control.rs index 8083298e..98af7fbb 100644 --- a/pingora-cache/src/cache_control.rs +++ b/pingora-cache/src/cache_control.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -255,13 +255,13 @@ impl CacheControl { self.has_key_without_value("private") } - fn get_field_names(&self, key: &str) -> Option<ListValueIter> { + fn get_field_names(&self, key: &str) -> Option<ListValueIter<'_>> { let value = self.directives.get(key)?.as_ref()?; Some(ListValueIter::from(value)) } /// Get the values of `private=` - pub fn private_field_names(&self) -> Option<ListValueIter> { + pub fn private_field_names(&self) -> Option<ListValueIter<'_>> { self.get_field_names("private") } @@ -271,7 +271,7 @@ impl CacheControl { } /// Get the values of `no-cache=` - pub fn no_cache_field_names(&self) -> Option<ListValueIter> { + pub fn no_cache_field_names(&self) -> Option<ListValueIter<'_>> { self.get_field_names("no-cache") } diff --git a/pingora-cache/src/eviction/lru.rs b/pingora-cache/src/eviction/lru.rs index 7b4846b9..d241ee69 100644 --- a/pingora-cache/src/eviction/lru.rs +++ b/pingora-cache/src/eviction/lru.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -62,6 +62,29 @@ impl Manager { Manager(Lru::with_capacity_and_watermark(limit, capacity, watermark)) } + /// Get the number of shards + pub fn shards(&self) -> usize { + self.0.shards() + } + + /// Get the weight (total size) of a specific shard + pub fn shard_weight(&self, shard: usize) -> usize { + self.0.shard_weight(shard) + } + + /// Get the number of items in a specific shard + pub fn shard_len(&self, shard: usize) -> usize { + self.0.shard_len(shard) + } + + /// Get the shard index for a given cache key + /// + /// This allows callers to know which shard was affected by an operation + /// without acquiring any locks.
+ pub fn get_shard_for_key(&self, key: &CompactCacheKey) -> usize { + (u64key(key) % N as u64) as usize + } + /// Serialize the given shard pub fn serialize_shard(&self, shard: usize) -> Result> { use rmp_serde::encode::Serializer; @@ -101,6 +124,12 @@ impl Manager { .or_err(InternalError, "when deserializing LRU")?; Ok(()) } + + /// Peek the weight associated with a cache key without changing its LRU order. + pub fn peek_weight(&self, item: &CompactCacheKey) -> Option { + let key = u64key(item); + self.0.peek_weight(key) + } } struct InsertToManager<'a, const N: usize> { @@ -171,9 +200,14 @@ impl EvictionManager for Manager { .collect() } - fn increment_weight(&self, item: CompactCacheKey, delta: usize) -> Vec { - let key = u64key(&item); - self.0.increment_weight(key, delta); + fn increment_weight( + &self, + item: &CompactCacheKey, + delta: usize, + max_weight: Option, + ) -> Vec { + let key = u64key(item); + self.0.increment_weight(key, delta, max_weight); self.0 .evict_to_limit() .into_iter() diff --git a/pingora-cache/src/eviction/mod.rs b/pingora-cache/src/eviction/mod.rs index cd48cd4a..0e78fbe1 100644 --- a/pingora-cache/src/eviction/mod.rs +++ b/pingora-cache/src/eviction/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -60,10 +60,18 @@ pub trait EvictionManager: Send + Sync { /// Adjust an item's weight upwards by a delta. If the item is not already admitted, /// nothing will happen. /// + /// An optional `max_weight` hint indicates the known max weight of the current key in case the + /// weight should not be incremented above this amount. + /// /// Return one or more items to evict. The sizes of these items are deducted /// from the total size already. The caller needs to make sure that these assets are actually /// removed from the storage. 
- fn increment_weight(&self, item: CompactCacheKey, delta: usize) -> Vec; + fn increment_weight( + &self, + item: &CompactCacheKey, + delta: usize, + max_weight: Option, + ) -> Vec; /// Remove an item from the eviction manager. /// diff --git a/pingora-cache/src/eviction/simple_lru.rs b/pingora-cache/src/eviction/simple_lru.rs index 3125dfb4..1c887552 100644 --- a/pingora-cache/src/eviction/simple_lru.rs +++ b/pingora-cache/src/eviction/simple_lru.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -124,7 +124,7 @@ impl Manager { if self.used.load(Ordering::Relaxed) <= self.limit && self .items_watermark - .map_or(true, |w| self.items.load(Ordering::Relaxed) <= w) + .is_none_or(|w| self.items.load(Ordering::Relaxed) <= w) { return vec![]; } @@ -235,8 +235,13 @@ impl EvictionManager for Manager { self.evict() } - fn increment_weight(&self, item: CompactCacheKey, delta: usize) -> Vec { - let key = u64key(&item); + fn increment_weight( + &self, + item: &CompactCacheKey, + delta: usize, + _max_weight: Option, + ) -> Vec { + let key = u64key(item); self.increase_weight(key, delta); self.evict() } diff --git a/pingora-cache/src/filters.rs b/pingora-cache/src/filters.rs index 20202ea2..607e6303 100644 --- a/pingora-cache/src/filters.rs +++ b/pingora-cache/src/filters.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -89,7 +89,7 @@ pub fn calculate_fresh_until( if authorization_present { let uncacheable = cache_control .as_ref() - .map_or(true, |cc| !cc.allow_caching_authorized_req()); + .is_none_or(|cc| !cc.allow_caching_authorized_req()); if uncacheable { return None; } } diff --git a/pingora-cache/src/hashtable.rs b/pingora-cache/src/hashtable.rs index 52292046..07ca5f3f 100644 --- a/pingora-cache/src/hashtable.rs +++ b/pingora-cache/src/hashtable.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -49,11 +49,11 @@ where } #[allow(dead_code)] - pub fn read(&self, key: u128) -> RwLockReadGuard<HashMap<u128, V>> { + pub fn read(&self, key: u128) -> RwLockReadGuard<'_, HashMap<u128, V>> { self.get(key).read() } - pub fn write(&self, key: u128) -> RwLockWriteGuard<HashMap<u128, V>> { + pub fn write(&self, key: u128) -> RwLockWriteGuard<'_, HashMap<u128, V>> { self.get(key).write() } @@ -103,7 +103,7 @@ where pub fn new(shard_capacity: usize) -> Self { use std::num::NonZeroUsize; // safe, 1 != 0 - const ONE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1) }; + const ONE: NonZeroUsize = NonZeroUsize::new(1).unwrap(); let mut cache = ConcurrentLruCache { lrus: Default::default(), }; @@ -119,11 +119,11 @@ where } #[allow(dead_code)] - pub fn read(&self, key: u128) -> RwLockReadGuard<LruCache<u128, V>> { + pub fn read(&self, key: u128) -> RwLockReadGuard<'_, LruCache<u128, V>> { self.get(key).read() } - pub fn write(&self, key: u128) -> RwLockWriteGuard<LruCache<u128, V>> { + pub fn write(&self, key: u128) -> RwLockWriteGuard<'_, LruCache<u128, V>> { self.get(key).write() } diff --git a/pingora-cache/src/key.rs b/pingora-cache/src/key.rs index 0e2d51a6..c606d85d 100644 --- a/pingora-cache/src/key.rs +++ b/pingora-cache/src/key.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,8 +14,6 @@ //! Cache key -use super::*; - use blake2::{Blake2b, Digest}; use http::Extensions; use serde::{Deserialize, Serialize}; @@ -214,18 +212,6 @@ impl CacheKey { hasher } - /// Create a default [CacheKey] from a request, which just takes its URI as the primary key. - pub fn default(req_header: &ReqHeader) -> Self { - CacheKey { - namespace: Vec::new(), - primary: format!("{}", req_header.uri).into_bytes(), - primary_bin_override: None, - variance: None, - user_tag: "".into(), - extensions: Extensions::new(), - } - } - /// Create a new [CacheKey] from the given namespace, primary, and user_tag input. /// /// Both `namespace` and `primary` will be used for the primary hash diff --git a/pingora-cache/src/lib.rs b/pingora-cache/src/lib.rs index 98b466c2..867cff08 100644 --- a/pingora-cache/src/lib.rs +++ b/pingora-cache/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -398,8 +398,7 @@ impl HttpCache { OriginNotCache | ResponseTooLarge | PredictedResponseTooLarge => { LockStatus::GiveUp } - // not sure which LockStatus make sense, we treat it as GiveUp for now - Custom(_) => LockStatus::GiveUp, + Custom(reason) => lock_ctx.cache_lock.custom_lock_status(reason), // should never happen, NeverEnabled shouldn't hold a lock NeverEnabled => panic!("NeverEnabled holds a write lock"), CacheLockGiveUp | CacheLockTimeout => { @@ -688,7 +687,7 @@ impl HttpCache { self.inner_mut() .max_file_size_tracker .as_mut() - .map_or(true, |t| t.add_body_bytes(bytes_len)) + .is_none_or(|t| t.add_body_bytes(bytes_len)) } /// Check if the max file size has been exceeded according to max file size tracker. 
@@ -823,6 +822,18 @@ impl HttpCache { } } + /// Return whether the underlying storage backend supports streaming partial write. + /// + /// Returns None if cache is not enabled. + pub fn support_streaming_partial_write(&self) -> Option { + self.inner.as_ref().and_then(|inner| { + inner + .enabled_ctx + .as_ref() + .map(|c| c.storage.support_streaming_partial_write()) + }) + } + /// Call this when cache hit is fully read. /// /// This call will release resource if any and log the timing in tracing if set. @@ -969,8 +980,8 @@ impl HttpCache { MissFinishType::Created(size) => { eviction.admit(cache_key, size, meta.0.internal.fresh_until) } - MissFinishType::Appended(size) => { - eviction.increment_weight(cache_key, size) + MissFinishType::Appended(size, max_size) => { + eviction.increment_weight(&cache_key, size, max_size) } }; // actual eviction can be done async @@ -1250,6 +1261,18 @@ impl HttpCache { } } + /// Return the [`CacheKey`] of this asset if any. + /// + /// This is allowed to be called in any phase. If the cache key callback was not called, + /// this will return None. 
+ pub fn maybe_cache_key(&self) -> Option<&CacheKey> { + (!matches!( + self.phase(), + CachePhase::Disabled(NoCacheReason::NeverEnabled) | CachePhase::Uninit + )) + .then(|| self.cache_key()) + } + /// Perform the cache lookup from the given cache storage with the given cache key /// /// A cache hit will return [CacheMeta] which contains the header and meta info about @@ -1426,7 +1449,7 @@ impl HttpCache { let mut span = inner_enabled.traces.child("cache_lock"); // should always call is_cache_locked() before this function, which should guarantee that // the inner cache has a read lock and lock ctx - if let Some(lock_ctx) = inner_enabled.lock_ctx.as_mut() { + let (read_lock, status) = if let Some(lock_ctx) = inner_enabled.lock_ctx.as_mut() { let lock = lock_ctx.lock.take(); // remove the lock from self if let Some(Locked::Read(r)) = lock { let now = Instant::now(); @@ -1437,23 +1460,26 @@ impl HttpCache { wait_timeout.saturating_sub(self.lock_duration().unwrap_or(Duration::ZERO)); match timeout(wait_timeout, r.wait()).await { Ok(()) => r.lock_status(), - // TODO: need to differentiate WaitTimeout vs. Lock(Age)Timeout (expired)? - Err(_) => LockStatus::Timeout, + Err(_) => LockStatus::WaitTimeout, } } else { r.wait().await; r.lock_status() }; self.digest.add_lock_duration(now.elapsed()); - let tag_value: &'static str = status.into(); - span.set_tag(|| Tag::new("status", tag_value)); - status + (r, status) } else { panic!("cache_lock_wait on wrong type of lock") } } else { panic!("cache_lock_wait without cache lock") + }; + if let Some(lock_ctx) = self.inner_enabled().lock_ctx.as_ref() { + lock_ctx + .cache_lock + .trace_lock_wait(&mut span, &read_lock, status); } + status } /// How long did this request wait behind the read lock diff --git a/pingora-cache/src/lock.rs b/pingora-cache/src/lock.rs index 680f609e..5633b09c 100644 --- a/pingora-cache/src/lock.rs +++ b/pingora-cache/src/lock.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. 
+// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,12 +15,14 @@ //! Cache lock use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey}; +use crate::{Span, Tag}; +use http::Extensions; use pingora_timeout::timeout; use std::sync::Arc; use std::time::Duration; -pub type CacheKeyLockImpl = (dyn CacheKeyLock + Send + Sync); +pub type CacheKeyLockImpl = dyn CacheKeyLock + Send + Sync; pub trait CacheKeyLock { /// Try to lock a cache fetch @@ -37,6 +39,19 @@ pub trait CacheKeyLock { /// When the write lock is dropped without being released, the read lock holders will consider /// it to be failed so that they will compete for the write lock again. fn release(&self, key: &CacheKey, permit: WritePermit, reason: LockStatus); + + /// Set tags on a trace span for the cache lock wait. + fn trace_lock_wait(&self, span: &mut Span, _read_lock: &ReadLock, lock_status: LockStatus) { + let tag_value: &'static str = lock_status.into(); + span.set_tag(|| Tag::new("status", tag_value)); + } + + /// Set a lock status for a custom `NoCacheReason`. + fn custom_lock_status(&self, _custom_no_cache: &'static str) -> LockStatus { + // treat custom no cache reasons as GiveUp by default + // (like OriginNotCache) + LockStatus::GiveUp + } } const N_SHARDS: usize = 16; @@ -106,7 +121,7 @@ impl CacheKeyLock for CacheLock { // requests ought to recreate the lock. 
if !matches!( lock.0.lock_status(), - LockStatus::Dangling | LockStatus::Timeout + LockStatus::Dangling | LockStatus::AgeTimeout ) { return Locked::Read(lock.read_lock()); } @@ -119,12 +134,13 @@ impl CacheKeyLock for CacheLock { if let Some(lock) = table.get(&key) { if !matches!( lock.0.lock_status(), - LockStatus::Dangling | LockStatus::Timeout + LockStatus::Dangling | LockStatus::AgeTimeout ) { return Locked::Read(lock.read_lock()); } } - let (permit, stub) = WritePermit::new(self.age_timeout_default, stale_writer); + let (permit, stub) = + WritePermit::new(self.age_timeout_default, stale_writer, Extensions::new()); table.insert(key, stub); Locked::Write(permit) } @@ -132,13 +148,13 @@ impl CacheKeyLock for CacheLock { fn release(&self, key: &CacheKey, mut permit: WritePermit, reason: LockStatus) { let hash = key.combined_bin(); let key = u128::from_be_bytes(hash); // endianness doesn't matter - if permit.lock.lock_status() == LockStatus::Timeout { + if permit.lock.lock_status() == LockStatus::AgeTimeout { // if lock age timed out, then readers are capable of // replacing the lock associated with this permit from the lock table // (see lock() implementation) // keep the lock status as Timeout accordingly when unlocking // (because we aren't removing it from the lock_table) - permit.unlock(LockStatus::Timeout); + permit.unlock(LockStatus::AgeTimeout); } else if let Some(_lock) = self.lock_table.write(key).remove(&key) { permit.unlock(reason); } @@ -150,25 +166,28 @@ impl CacheKeyLock for CacheLock { use log::warn; use std::sync::atomic::{AtomicU8, Ordering}; use std::time::Instant; -use strum::IntoStaticStr; +use strum::{FromRepr, IntoStaticStr}; use tokio::sync::Semaphore; /// Status which the read locks could possibly see. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoStaticStr)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoStaticStr, FromRepr)] +#[repr(u8)] pub enum LockStatus { /// Waiting for the writer to populate the asset - Waiting, + Waiting = 0, /// The writer finishes, readers can start - Done, + Done = 1, /// The writer encountered error, such as network issue. A new writer will be elected. - TransientError, + TransientError = 2, /// The writer observed that no cache lock is needed (e.g., uncacheable), readers should start /// to fetch independently without a new writer - GiveUp, + GiveUp = 3, /// The write lock is dropped without being unlocked - Dangling, - /// The lock is held for too long - Timeout, + Dangling = 4, + /// Reader has held onto cache locks for too long, give up + WaitTimeout = 5, + /// The lock is held for too long by the writer + AgeTimeout = 6, } impl From for u8 { @@ -179,22 +198,15 @@ impl From for u8 { LockStatus::TransientError => 2, LockStatus::GiveUp => 3, LockStatus::Dangling => 4, - LockStatus::Timeout => 5, + LockStatus::WaitTimeout => 5, + LockStatus::AgeTimeout => 6, } } } impl From for LockStatus { fn from(v: u8) -> Self { - match v { - 0 => Self::Waiting, - 1 => Self::Done, - 2 => Self::TransientError, - 3 => Self::GiveUp, - 4 => Self::Dangling, - 5 => Self::Timeout, - _ => Self::GiveUp, // placeholder - } + Self::from_repr(v).unwrap_or(Self::GiveUp) } } @@ -206,16 +218,18 @@ pub struct LockCore { // use u8 for Atomic enum lock_status: AtomicU8, stale_writer: bool, + extensions: Extensions, } impl LockCore { - pub fn new_arc(timeout: Duration, stale_writer: bool) -> Arc { + pub fn new_arc(timeout: Duration, stale_writer: bool, extensions: Extensions) -> Arc { Arc::new(LockCore { lock: Semaphore::new(0), age_timeout: timeout, lock_start: Instant::now(), lock_status: AtomicU8::new(LockStatus::Waiting.into()), stale_writer, + extensions, }) } @@ -224,6 +238,10 @@ impl LockCore { } pub fn unlock(&self, reason: LockStatus) { + assert!( + 
reason != LockStatus::WaitTimeout, + "WaitTimeout is not stored in LockCore" + ); self.lock_status.store(reason.into(), Ordering::SeqCst); // Any small positive number will do, 10 is used for RwLock as well. // No need to wake up all at once. @@ -238,6 +256,10 @@ impl LockCore { pub fn stale_writer(&self) -> bool { self.stale_writer } + + pub fn extensions(&self) -> &Extensions { + &self.extensions + } } // all 3 structs below are just Arc with different interfaces @@ -268,14 +290,14 @@ impl ReadLock { Err(_) => { self.0 .lock_status - .store(LockStatus::Timeout.into(), Ordering::SeqCst); + .store(LockStatus::AgeTimeout.into(), Ordering::SeqCst); } } } else { // expiration has already occurred, store timeout status self.0 .lock_status - .store(LockStatus::Timeout.into(), Ordering::SeqCst); + .store(LockStatus::AgeTimeout.into(), Ordering::SeqCst); } } @@ -295,11 +317,15 @@ impl ReadLock { pub fn lock_status(&self) -> LockStatus { let status = self.0.lock_status(); if matches!(status, LockStatus::Waiting) && self.expired() { - LockStatus::Timeout + LockStatus::AgeTimeout } else { status } } + + pub fn extensions(&self) -> &Extensions { + self.0.extensions() + } } /// WritePermit: requires who get it need to populate the cache and then release it @@ -311,8 +337,12 @@ pub struct WritePermit { impl WritePermit { /// Create a new lock, with a permit to be given to the associated writer. 
- pub fn new(timeout: Duration, stale_writer: bool) -> (WritePermit, LockStub) { - let lock = LockCore::new_arc(timeout, stale_writer); + pub fn new( + timeout: Duration, + stale_writer: bool, + extensions: Extensions, + ) -> (WritePermit, LockStub) { + let lock = LockCore::new_arc(timeout, stale_writer, extensions); let stub = LockStub(lock.clone()); ( WritePermit { @@ -336,6 +366,10 @@ impl WritePermit { pub fn lock_status(&self) -> LockStatus { self.lock.lock_status() } + + pub fn extensions(&self) -> &Extensions { + self.lock.extensions() + } } impl Drop for WritePermit { @@ -354,6 +388,10 @@ impl LockStub { pub fn read_lock(&self) -> ReadLock { ReadLock(self.0.clone()) } + + pub fn extensions(&self) -> &Extensions { + &self.0.extensions + } } #[cfg(test)] @@ -417,7 +455,7 @@ mod test { let handle = tokio::spawn(async move { // timed out lock.wait().await; - assert_eq!(lock.lock_status(), LockStatus::Timeout); + assert_eq!(lock.lock_status(), LockStatus::AgeTimeout); }); tokio::time::sleep(Duration::from_millis(2100)).await; @@ -462,7 +500,7 @@ mod test { let handle = tokio::spawn(async move { // timed out lock.wait().await; - assert_eq!(lock.lock_status(), LockStatus::Timeout); + assert_eq!(lock.lock_status(), LockStatus::AgeTimeout); }); tokio::time::sleep(Duration::from_millis(1100)).await; // let lock age time out @@ -512,9 +550,9 @@ mod test { }; // reader expires write permit lock.wait().await; - assert_eq!(lock.lock_status(), LockStatus::Timeout); - assert_eq!(permit.lock.lock_status(), LockStatus::Timeout); - permit.unlock(LockStatus::Timeout); + assert_eq!(lock.lock_status(), LockStatus::AgeTimeout); + assert_eq!(permit.lock.lock_status(), LockStatus::AgeTimeout); + permit.unlock(LockStatus::AgeTimeout); } #[tokio::test] diff --git a/pingora-cache/src/max_file_size.rs b/pingora-cache/src/max_file_size.rs index 106b012e..7c9eccd9 100644 --- a/pingora-cache/src/max_file_size.rs +++ b/pingora-cache/src/max_file_size.rs @@ -1,4 +1,4 @@ -// Copyright 2025 
Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-cache/src/memory.rs b/pingora-cache/src/memory.rs index 786cf453..6ab57c80 100644 --- a/pingora-cache/src/memory.rs +++ b/pingora-cache/src/memory.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-cache/src/meta.rs b/pingora-cache/src/meta.rs index 9c6bd6fc..4545ee22 100644 --- a/pingora-cache/src/meta.rs +++ b/pingora-cache/src/meta.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -93,8 +93,10 @@ mod internal_meta { // schema to decode it // After full releases, remove `skip_serializing_if` so that we can add the next extended field. 
#[serde(default)] - #[serde(skip_serializing_if = "Option::is_none")] pub(crate) variance: Option, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) epoch_override: Option, } impl Default for InternalMetaV2 { @@ -108,6 +110,7 @@ mod internal_meta { stale_while_revalidate_sec: 0, stale_if_error_sec: 0, variance: None, + epoch_override: None, } } } @@ -258,35 +261,75 @@ mod internal_meta { assert_eq!(meta2.created, meta2.updated); } - #[test] - fn test_internal_meta_serde_v2_extend_fields() { - // make sure that v2 format is backward compatible - // this is the base version of v2 without any extended fields - #[derive(Deserialize, Serialize)] - pub(crate) struct InternalMetaV2Base { - pub(crate) version: u8, - pub(crate) fresh_until: SystemTime, - pub(crate) created: SystemTime, - pub(crate) updated: SystemTime, - pub(crate) stale_while_revalidate_sec: u32, - pub(crate) stale_if_error_sec: u32, + // make sure that v2 format is backward compatible + // this is the base version of v2 without any extended fields + #[derive(Deserialize, Serialize)] + struct InternalMetaV2Base { + version: u8, + fresh_until: SystemTime, + created: SystemTime, + updated: SystemTime, + stale_while_revalidate_sec: u32, + stale_if_error_sec: u32, + } + + impl InternalMetaV2Base { + pub const VERSION: u8 = 2; + pub fn serialize(&self) -> Result> { + assert!(self.version >= Self::VERSION); + rmp_serde::encode::to_vec(self).or_err(InternalError, "failed to encode cache meta") + } + fn deserialize(buf: &[u8]) -> Result { + rmp_serde::decode::from_slice(buf) + .or_err(InternalError, "failed to decode cache meta v2") } + } - impl InternalMetaV2Base { - pub const VERSION: u8 = 2; - pub fn serialize(&self) -> Result> { - assert!(self.version >= Self::VERSION); - rmp_serde::encode::to_vec(self) - .or_err(InternalError, "failed to encode cache meta") - } - fn deserialize(buf: &[u8]) -> Result { - rmp_serde::decode::from_slice(buf) - .or_err(InternalError, "failed to 
decode cache meta v2") + // this is the base version of v2 with variance but without epoch_override + #[derive(Deserialize, Serialize)] + struct InternalMetaV2BaseWithVariance { + version: u8, + fresh_until: SystemTime, + created: SystemTime, + updated: SystemTime, + stale_while_revalidate_sec: u32, + stale_if_error_sec: u32, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + variance: Option, + } + + impl Default for InternalMetaV2BaseWithVariance { + fn default() -> Self { + let epoch = SystemTime::UNIX_EPOCH; + InternalMetaV2BaseWithVariance { + version: InternalMetaV2::VERSION, + fresh_until: epoch, + created: epoch, + updated: epoch, + stale_while_revalidate_sec: 0, + stale_if_error_sec: 0, + variance: None, } } + } + impl InternalMetaV2BaseWithVariance { + pub const VERSION: u8 = 2; + pub fn serialize(&self) -> Result> { + assert!(self.version >= Self::VERSION); + rmp_serde::encode::to_vec(self).or_err(InternalError, "failed to encode cache meta") + } + fn deserialize(buf: &[u8]) -> Result { + rmp_serde::decode::from_slice(buf) + .or_err(InternalError, "failed to decode cache meta v2") + } + } + + #[test] + fn test_internal_meta_serde_v2_extend_fields_variance() { // ext V2 to base v2 - let meta = InternalMetaV2::default(); + let meta = InternalMetaV2BaseWithVariance::default(); let binary = meta.serialize().unwrap(); let meta2 = InternalMetaV2Base::deserialize(&binary).unwrap(); assert_eq!(meta2.version, 2); @@ -305,11 +348,62 @@ mod internal_meta { stale_if_error_sec: 0, }; let binary = meta.serialize().unwrap(); + let meta2 = InternalMetaV2BaseWithVariance::deserialize(&binary).unwrap(); + assert_eq!(meta2.version, 2); + assert_eq!(meta.fresh_until, meta2.fresh_until); + assert_eq!(meta.created, meta2.created); + assert_eq!(meta.updated, meta2.updated); + } + + #[test] + fn test_internal_meta_serde_v2_extend_fields_epoch_override() { + let now = SystemTime::now(); + + // ext V2 (with epoch_override = None) to V2 with variance 
(without epoch_override field) + let meta = InternalMetaV2 { + fresh_until: now, + created: now, + updated: now, + epoch_override: None, // None means it will be skipped during serialization + ..Default::default() + }; + let binary = meta.serialize().unwrap(); + let meta2 = InternalMetaV2BaseWithVariance::deserialize(&binary).unwrap(); + assert_eq!(meta2.version, 2); + assert_eq!(meta.fresh_until, meta2.fresh_until); + assert_eq!(meta.created, meta2.created); + assert_eq!(meta.updated, meta2.updated); + assert!(meta2.variance.is_none()); + + // V2 base with variance (without epoch_override) to ext V2 (with epoch_override) + let mut meta = InternalMetaV2BaseWithVariance { + version: InternalMetaV2::VERSION, + fresh_until: now, + created: now, + updated: now, + stale_while_revalidate_sec: 0, + stale_if_error_sec: 0, + variance: None, + }; + let binary = meta.serialize().unwrap(); let meta2 = InternalMetaV2::deserialize(&binary).unwrap(); assert_eq!(meta2.version, 2); assert_eq!(meta.fresh_until, meta2.fresh_until); assert_eq!(meta.created, meta2.created); assert_eq!(meta.updated, meta2.updated); + assert!(meta2.variance.is_none()); + assert!(meta2.epoch_override.is_none()); + + // try with variance set + meta.variance = Some(*b"variance_testing"); + let binary = meta.serialize().unwrap(); + let meta2 = InternalMetaV2::deserialize(&binary).unwrap(); + assert_eq!(meta2.version, 2); + assert_eq!(meta.fresh_until, meta2.fresh_until); + assert_eq!(meta.created, meta2.created); + assert_eq!(meta.updated, meta2.updated); + assert_eq!(meta.variance, meta2.variance); + assert!(meta2.epoch_override.is_none()); } } } @@ -364,6 +458,32 @@ impl CacheMeta { self.0.internal.updated } + /// The reference point for cache age. This represents the "starting point" for `fresh_until`. + /// + /// This defaults to the `updated` timestamp but is overridden by the `epoch_override` field + /// if set. 
+ pub fn epoch(&self) -> SystemTime { + self.0.internal.epoch_override.unwrap_or(self.updated()) + } + + /// Get the epoch override for this asset + pub fn epoch_override(&self) -> Option { + self.0.internal.epoch_override + } + + /// Set the epoch override for this asset + /// + /// When set, this will be used as the reference point for calculating age and freshness + /// instead of the updated time. + pub fn set_epoch_override(&mut self, epoch: SystemTime) { + self.0.internal.epoch_override = Some(epoch); + } + + /// Remove the epoch override for this asset + pub fn remove_epoch_override(&mut self) { + self.0.internal.epoch_override = None; + } + /// Is the asset still valid pub fn is_fresh(&self, time: SystemTime) -> bool { // NOTE: HTTP cache time resolution is second @@ -372,15 +492,17 @@ impl CacheMeta { /// How long (in seconds) the asset should be fresh since its admission/revalidation /// - /// This is essentially the max-age value (or its equivalence) + /// This is essentially the max-age value (or its equivalence). + /// If an epoch override is set, it will be used as the reference point instead of the updated time. pub fn fresh_sec(&self) -> u64 { // swallow `duration_since` error, assets that are always stale have earlier `fresh_until` than `created` // practically speaking we can always treat these as 0 ttl // XXX: return Error if `fresh_until` is much earlier than expected? + let reference = self.epoch(); self.0 .internal .fresh_until - .duration_since(self.0.internal.updated) + .duration_since(reference) .map_or(0, |duration| duration.as_secs()) } @@ -390,9 +512,12 @@ impl CacheMeta { } /// How old the asset is since its admission/revalidation + /// + /// If an epoch override is set, it will be used as the reference point instead of the updated time. 
pub fn age(&self) -> Duration { + let reference = self.epoch(); SystemTime::now() - .duration_since(self.updated()) + .duration_since(reference) .unwrap_or_default() } @@ -499,6 +624,7 @@ impl CacheMeta { pub fn serialize(&self) -> Result<(Vec, Vec)> { let internal = self.0.internal.serialize()?; let header = header_serialize(&self.0.header)?; + log::debug!("header to serialize: {:?}", &self.0.header); Ok((internal, header)) } @@ -616,3 +742,93 @@ pub fn set_compression_dict_path(path: &str) -> bool { pub fn set_compression_dict_content(content: Cow<'static, [u8]>) -> bool { COMPRESSION_DICT_CONTENT.set(content).is_ok() } + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_cache_meta_age_without_override() { + let now = SystemTime::now(); + let header = ResponseHeader::build_no_case(200, None).unwrap(); + let meta = CacheMeta::new(now + Duration::from_secs(300), now, 0, 0, header); + + // Without epoch_override, age() should use updated() as reference + std::thread::sleep(Duration::from_millis(100)); + let age = meta.age(); + assert!(age.as_secs() < 1, "age should be close to 0"); + + // epoch() should return updated() when no override is set + assert_eq!(meta.epoch(), meta.updated()); + } + + #[test] + fn test_cache_meta_age_with_epoch_override_past() { + let now = SystemTime::now(); + let header = ResponseHeader::build(200, None).unwrap(); + let mut meta = CacheMeta::new(now + Duration::from_secs(300), now, 0, 0, header); + + // Set epoch_override to 10 seconds in the past + let epoch_override = now - Duration::from_secs(10); + meta.set_epoch_override(epoch_override); + + // age() should now use epoch_override as the reference + let age = meta.age(); + assert!(age.as_secs() >= 10); + assert!(age.as_secs() < 12); + + // epoch() should return the override + assert_eq!(meta.epoch(), epoch_override); + assert_eq!(meta.epoch_override(), Some(epoch_override)); + } + + #[test] + fn 
test_cache_meta_age_with_epoch_override_future() { + let now = SystemTime::now(); + let header = ResponseHeader::build(200, None).unwrap(); + let mut meta = CacheMeta::new(now + Duration::from_secs(100), now, 0, 0, header); + + // Set epoch_override to a future time + let future_epoch = now + Duration::from_secs(10); + meta.set_epoch_override(future_epoch); + + let age_with_epoch = meta.age(); + // age should be 0 since epoch_override is in the future + assert_eq!(age_with_epoch, Duration::ZERO); + } + + #[test] + fn test_cache_meta_fresh_sec() { + let header = ResponseHeader::build(StatusCode::OK, None).unwrap(); + let mut meta = CacheMeta::new( + SystemTime::now() + Duration::from_secs(100), + SystemTime::now() - Duration::from_secs(100), + 0, + 0, + header, + ); + + meta.0.internal.updated = SystemTime::UNIX_EPOCH + Duration::from_secs(1000); + meta.0.internal.fresh_until = SystemTime::UNIX_EPOCH + Duration::from_secs(1100); + + // Without epoch_override, fresh_sec should use updated as reference + let fresh_sec_without_override = meta.fresh_sec(); + assert_eq!(fresh_sec_without_override, 100); // 1100 - 1000 = 100 seconds + + // With epoch_override set to a later time (1050), fresh_sec should be calculated from that reference + let epoch_override = SystemTime::UNIX_EPOCH + Duration::from_secs(1050); + meta.set_epoch_override(epoch_override); + assert_eq!(meta.epoch_override(), Some(epoch_override)); + assert_eq!(meta.epoch(), epoch_override); + + let fresh_sec_with_override = meta.fresh_sec(); + // fresh_until - epoch_override = 1100 - 1050 = 50 seconds + assert_eq!(fresh_sec_with_override, 50); + + meta.remove_epoch_override(); + assert_eq!(meta.epoch_override(), None); + assert_eq!(meta.epoch(), meta.updated()); + assert_eq!(meta.fresh_sec(), 100); // back to normal calculation + } +} diff --git a/pingora-cache/src/predictor.rs b/pingora-cache/src/predictor.rs index 58f1315f..8c2f5a8f 100644 --- a/pingora-cache/src/predictor.rs +++ 
b/pingora-cache/src/predictor.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-cache/src/put.rs b/pingora-cache/src/put.rs index 4c82a482..fbbbb70e 100644 --- a/pingora-cache/src/put.rs +++ b/pingora-cache/src/put.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -34,6 +34,9 @@ pub trait CachePut { /// Return the [CacheMetaDefaults] fn cache_defaults() -> &'static CacheMetaDefaults; + + /// Put interesting things in the span given the parsed response header. + fn trace_header(&mut self, _response: &ResponseHeader) {} } use parse_response::ResponseParse; @@ -81,11 +84,12 @@ impl CachePutCtx { } async fn put_header(&mut self, meta: CacheMeta) -> Result<()> { - let trace = self.trace.child("cache put header", |o| o.start()).handle(); + let mut trace = self.trace.child("cache put header", |o| o.start()); let miss_handler = self .storage - .get_miss_handler(&self.key, &meta, &trace) + .get_miss_handler(&self.key, &meta, &trace.handle()) .await?; + trace::tag_span_with_meta(&mut trace, &meta); self.miss_handler = Some(miss_handler); self.meta = Some(meta); Ok(()) @@ -121,7 +125,9 @@ impl CachePutCtx { let cache_key = self.key.to_compact(); let meta = self.meta.as_ref().unwrap(); let evicted = match finish { - MissFinishType::Appended(delta) => eviction.increment_weight(cache_key, delta), + MissFinishType::Appended(delta, max_size) => { + eviction.increment_weight(&cache_key, delta, max_size) + } MissFinishType::Created(size) => { eviction.admit(cache_key, size, meta.0.internal.fresh_until) } @@ -144,29 +150,48 @@ impl CachePutCtx { Ok(()) } + fn trace_header(&mut self, header: &ResponseHeader) { 
+ self.trace.set_tag(|| { + Tag::new( + "cache-control", + header + .headers + .get_all(http::header::CACHE_CONTROL) + .into_iter() + .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()) + .collect::>() + .join(","), + ) + }); + } + async fn do_cache_put(&mut self, data: &[u8]) -> Result> { let tasks = self.parser.inject_data(data)?; for task in tasks { match task { - HttpTask::Header(header, _eos) => match self.cache_put.cacheable(*header) { - RespCacheable::Cacheable(meta) => { - if let Some(max_file_size_tracker) = &self.max_file_size_tracker { - let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH); - if let Some(content_length) = - header_value_content_length(content_length_hdr) - { - if content_length > max_file_size_tracker.max_file_size_bytes() { - return Ok(Some(NoCacheReason::ResponseTooLarge)); + HttpTask::Header(header, _eos) => { + self.trace_header(&header); + match self.cache_put.cacheable(*header) { + RespCacheable::Cacheable(meta) => { + if let Some(max_file_size_tracker) = &self.max_file_size_tracker { + let content_length_hdr = meta.headers().get(header::CONTENT_LENGTH); + if let Some(content_length) = + header_value_content_length(content_length_hdr) + { + if content_length > max_file_size_tracker.max_file_size_bytes() + { + return Ok(Some(NoCacheReason::ResponseTooLarge)); + } } } - } - self.put_header(meta).await?; - } - RespCacheable::Uncacheable(reason) => { - return Ok(Some(reason)); + self.put_header(meta).await?; + } + RespCacheable::Uncacheable(reason) => { + return Ok(Some(reason)); + } } - }, + } HttpTask::Body(data, eos) => { if let Some(data) = data { self.put_body(data, eos).await?; @@ -369,6 +394,7 @@ mod test { mod parse_response { use super::*; + use bstr::ByteSlice; use bytes::BytesMut; use httparse::Status; use pingora_error::{ @@ -475,7 +501,7 @@ mod parse_response { self.state = ParseState::Invalid(e); return Error::e_because( InvalidHTTPHeader, - format!("buf: {:?}", String::from_utf8_lossy(&self.buf)), 
+ format!("buf: {:?}", self.buf.as_bstr()), e, ); } diff --git a/pingora-cache/src/storage.rs b/pingora-cache/src/storage.rs index 6a870b43..5df1526d 100644 --- a/pingora-cache/src/storage.rs +++ b/pingora-cache/src/storage.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -117,18 +117,39 @@ pub trait HandleHit { trace: &SpanHandle, ) -> Result<()>; - /// Whether this storage allow seeking to a certain range of body + /// Whether this storage allows seeking to a certain range of body for single ranges. fn can_seek(&self) -> bool { false } - /// Try to seek to a certain range of the body + /// Whether this storage allows seeking to a certain range of body for multipart ranges. + /// + /// By default uses the `can_seek` implementation. + fn can_seek_multipart(&self) -> bool { + self.can_seek() + } + + /// Try to seek to a certain range of the body for single ranges. /// /// `end: None` means to read to the end of the body. fn seek(&mut self, _start: usize, _end: Option) -> Result<()> { // to prevent impl can_seek() without impl seek todo!("seek() needs to be implemented") } + + /// Try to seek to a certain range of the body for multipart ranges. + /// + /// Works in an identical manner to `seek()`. + /// + /// `end: None` means to read to the end of the body. + /// + /// By default uses the `seek` implementation, but hit handlers may customize the + /// implementation specifically to anticipate multipart requests. + fn seek_multipart(&mut self, start: usize, end: Option) -> Result<()> { + // to prevent impl can_seek() without impl seek + self.seek(start, end) + } + // TODO: fn is_stream_hit() /// Should we count this hit handler instance as an access in the eviction manager. 
@@ -157,12 +178,14 @@ pub trait HandleHit { } /// Hit Handler -pub type HitHandler = Box<(dyn HandleHit + Sync + Send)>; +pub type HitHandler = Box; /// MissFinishType pub enum MissFinishType { + /// A new asset was created with the given size. Created(usize), - Appended(usize), + /// Appended size to existing asset, with an optional max size param. + Appended(usize, Option), } /// Cache miss handling trait @@ -197,7 +220,7 @@ pub trait HandleMiss { } /// Miss Handler -pub type MissHandler = Box<(dyn HandleMiss + Sync + Send)>; +pub type MissHandler = Box; pub mod streaming_write { /// Portable u64 (sized) write id convenience type for use with streaming writes. diff --git a/pingora-cache/src/trace.rs b/pingora-cache/src/trace.rs index 90d4f1c3..f27929a2 100644 --- a/pingora-cache/src/trace.rs +++ b/pingora-cache/src/trace.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -33,6 +33,28 @@ pub(crate) struct CacheTraceCTX { pub hit_span: Span, } +pub fn tag_span_with_meta(span: &mut Span, meta: &CacheMeta) { + fn ts2epoch(ts: SystemTime) -> f64 { + ts.duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() // should never overflow but be safe here + .as_secs_f64() + } + let internal = &meta.0.internal; + span.set_tags(|| { + [ + Tag::new("created", ts2epoch(internal.created)), + Tag::new("fresh_until", ts2epoch(internal.fresh_until)), + Tag::new("updated", ts2epoch(internal.updated)), + Tag::new("stale_if_error_sec", internal.stale_if_error_sec as i64), + Tag::new( + "stale_while_revalidate_sec", + internal.stale_while_revalidate_sec as i64, + ), + Tag::new("variance", internal.variance.is_some()), + ] + }); +} + impl CacheTraceCTX { pub fn new() -> Self { CacheTraceCTX { @@ -82,33 +104,11 @@ impl CacheTraceCTX { self.hit_span.set_finish_time(SystemTime::now); } - fn log_meta(span: &mut Span, meta: &CacheMeta) { - fn ts2epoch(ts: SystemTime) -> f64 { - ts.duration_since(SystemTime::UNIX_EPOCH) - .unwrap_or_default() // should never overflow but be safe here - .as_secs_f64() - } - let internal = &meta.0.internal; - span.set_tags(|| { - [ - Tag::new("created", ts2epoch(internal.created)), - Tag::new("fresh_until", ts2epoch(internal.fresh_until)), - Tag::new("updated", ts2epoch(internal.updated)), - Tag::new("stale_if_error_sec", internal.stale_if_error_sec as i64), - Tag::new( - "stale_while_revalidate_sec", - internal.stale_while_revalidate_sec as i64, - ), - Tag::new("variance", internal.variance.is_some()), - ] - }); - } - pub fn log_meta_in_hit_span(&mut self, meta: &CacheMeta) { - CacheTraceCTX::log_meta(&mut self.hit_span, meta); + tag_span_with_meta(&mut self.hit_span, meta); } pub fn log_meta_in_miss_span(&mut self, meta: &CacheMeta) { - CacheTraceCTX::log_meta(&mut self.miss_span, meta); + tag_span_with_meta(&mut self.miss_span, meta); } } diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index 
e5ed2834..5822932d 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-core" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -19,16 +19,18 @@ name = "pingora_core" path = "src/lib.rs" [dependencies] -pingora-runtime = { version = "0.6.0", path = "../pingora-runtime" } -pingora-openssl = { version = "0.6.0", path = "../pingora-openssl", optional = true } -pingora-boringssl = { version = "0.6.0", path = "../pingora-boringssl", optional = true } -pingora-pool = { version = "0.6.0", path = "../pingora-pool" } -pingora-error = { version = "0.6.0", path = "../pingora-error" } -pingora-timeout = { version = "0.6.0", path = "../pingora-timeout" } -pingora-http = { version = "0.6.0", path = "../pingora-http" } -pingora-rustls = { version = "0.6.0", path = "../pingora-rustls", optional = true } -pingora-s2n = { version = "0.6.0", path = "../pingora-s2n", optional = true } +pingora-runtime = { version = "0.8.0", path = "../pingora-runtime" } +pingora-openssl = { version = "0.8.0", path = "../pingora-openssl", optional = true } +pingora-boringssl = { version = "0.8.0", path = "../pingora-boringssl", optional = true } +pingora-pool = { version = "0.8.0", path = "../pingora-pool" } +pingora-error = { version = "0.8.0", path = "../pingora-error" } +pingora-timeout = { version = "0.8.0", path = "../pingora-timeout" } +pingora-http = { version = "0.8.0", path = "../pingora-http" } +pingora-rustls = { version = "0.8.0", path = "../pingora-rustls", optional = true } +pingora-s2n = { version = "0.8.0", path = "../pingora-s2n", optional = true } +bstr = { workspace = true } tokio = { workspace = true, features = ["net", "rt-multi-thread", "signal"] } +tokio-stream = { workspace = true } futures = "0.3" async-trait = { workspace = true } httparse = { workspace = true } @@ -37,10 +39,10 @@ http = { workspace = true } log = { workspace = true } h2 = { workspace = true } 
derivative.workspace = true -clap = { version = "3.2.25", features = ["derive"] } +clap = { version = "4.5", features = ["derive"] } once_cell = { workspace = true } serde = { version = "1.0", features = ["derive"] } -serde_yaml = "0.8" +serde_yaml = "0.9" strum = "0.26.2" strum_macros = "0.26.2" libc = "0.2.70" @@ -69,7 +71,8 @@ zstd = "0" httpdate = "1" x509-parser = { version = "0.16.0", optional = true } ouroboros = { version = "0.18.4", optional = true } -lru = { version = "0.16.0", optional = true } +lru = { workspace = true, optional = true } +daggy = "0.8" proxy-protocol = {git = "https://github.com/arxignis/proxy-protocol.git"} [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" @@ -79,14 +82,15 @@ nix = "~0.24.3" windows-sys = { version = "0.59.0", features = ["Win32_Networking_WinSock"] } [dev-dependencies] -h2 = { workspace = true, features=["unstable"]} +h2 = { workspace = true, features = ["unstable"] } tokio-stream = { version = "0.1", features = ["full"] } -env_logger = "0.9" +env_logger = "0.11" reqwest = { version = "0.11", features = [ "rustls-tls", ], default-features = false } hyper = "0.14" rstest = "0.23.0" +rustls = "0.23" [target.'cfg(unix)'.dev-dependencies] hyperlocal = "0.8" @@ -102,3 +106,4 @@ patched_http1 = ["pingora-http/patched_http1"] openssl_derived = ["any_tls"] any_tls = [] sentry = ["dep:sentry"] +connection_filter = [] diff --git a/pingora-core/examples/bootstrap_as_a_service.rs b/pingora-core/examples/bootstrap_as_a_service.rs new file mode 100644 index 00000000..c49ad271 --- /dev/null +++ b/pingora-core/examples/bootstrap_as_a_service.rs @@ -0,0 +1,102 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Example demonstrating how to start a server using [`Server::bootstrap_as_a_service`] +//! instead of calling [`Server::bootstrap`] directly. +//! +//! # Why `bootstrap_as_a_service`? +//! +//! [`Server::bootstrap`] runs the bootstrap phase synchronously before any services start. +//! This means the calling thread blocks during socket FD acquisition and Sentry initialization. +//! +//! [`Server::bootstrap_as_a_service`] instead schedules bootstrap as a dependency-aware init +//! service. This allows other services to declare a dependency on the bootstrap handle and +//! ensures they only start after bootstrap completes β€” while keeping setup fully asynchronous +//! and composable with the rest of the service graph. +//! +//! Use `bootstrap_as_a_service` when: +//! - You want to integrate bootstrap into the service dependency graph +//! - You want services to wait for bootstrap without blocking the main thread +//! - You are building more complex startup sequences (e.g. multiple ordered init steps) +//! +//! # Running the example +//! +//! ```bash +//! cargo run --example bootstrap_as_a_service --package pingora-core +//! ``` +//! +//! # Expected behaviour +//! +//! Bootstrap runs as a service before `MyService` starts. `MyService` declares a dependency +//! on the bootstrap handle, so it will not be started until bootstrap has completed. 
+ +use async_trait::async_trait; +use log::info; +use pingora_core::server::configuration::Opt; +#[cfg(unix)] +use pingora_core::server::ListenFds; +use pingora_core::server::{Server, ShutdownWatch}; +use pingora_core::services::Service; + +/// A simple application service that requires bootstrap to be complete before it starts. +pub struct MyService; + +#[async_trait] +impl Service for MyService { + async fn start_service( + &mut self, + #[cfg(unix)] _fds: Option, + mut shutdown: ShutdownWatch, + _listeners_per_fd: usize, + ) { + info!("MyService: bootstrap is complete, starting up"); + + // Keep running until a shutdown signal is received. + shutdown.changed().await.ok(); + + info!("MyService: shutting down"); + } + + fn name(&self) -> &str { + "my_service" + } + + fn threads(&self) -> Option { + Some(1) + } +} + +fn main() { + env_logger::Builder::from_default_env() + .filter_level(log::LevelFilter::Info) + .init(); + + let opt = Opt::parse_args(); + let mut server = Server::new(Some(opt)).unwrap(); + + // Schedule bootstrap as a service instead of calling server.bootstrap() directly. + // The returned handle can be used to declare dependencies so that other services + // only start after bootstrap has finished. + let bootstrap_handle = server.bootstrap_as_a_service(); + + // Register our application service and get its handle. + let service_handle = server.add_service(MyService); + + // MyService will not start until the bootstrap service has signaled that it is ready. + service_handle.add_dependency(&bootstrap_handle); + + info!("Starting server β€” bootstrap will run as a service before MyService starts"); + + server.run_forever(); +} diff --git a/pingora-core/examples/client_cert.rs b/pingora-core/examples/client_cert.rs new file mode 100644 index 00000000..cbac46a1 --- /dev/null +++ b/pingora-core/examples/client_cert.rs @@ -0,0 +1,227 @@ +// Copyright 2026 Cloudflare, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg_attr(not(feature = "openssl"), allow(unused))] + +use std::any::Any; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; + +use async_trait::async_trait; +use clap::Parser; +use http::header::{CONTENT_LENGTH, CONTENT_TYPE}; +use http::{Response, StatusCode}; +use pingora_core::apps::http_app::ServeHttp; +use pingora_core::listeners::tls::TlsSettings; +use pingora_core::listeners::TlsAccept; +use pingora_core::protocols::http::ServerSession; +use pingora_core::protocols::tls::TlsRef; +use pingora_core::server::configuration::Opt; +use pingora_core::server::Server; +use pingora_core::services::listening::Service; +use pingora_core::Result; +#[cfg(feature = "openssl")] +use pingora_openssl::{ + nid::Nid, + ssl::{NameType, SslFiletype, SslVerifyMode}, + x509::{GeneralName, X509Name}, +}; + +// Custom structure to hold TLS information +struct MyTlsInfo { + // SNI (Server Name Indication) from the TLS handshake + sni: Option, + // SANs (Subject Alternative Names) from client certificate + sans: Vec, + // Common Name (CN) from client certificate + common_name: Option, +} + +struct MyApp; + +#[async_trait] +impl ServeHttp for MyApp { + async fn response(&self, session: &mut ServerSession) -> http::Response> { + static EMPTY_VEC: Vec = vec![]; + + // Extract TLS info from the session's digest extensions + let my_tls_info = session + .digest() + .and_then(|digest| digest.ssl_digest.as_ref()) + 
.and_then(|ssl_digest| ssl_digest.extension.get::()); + let sni = my_tls_info + .and_then(|my_tls_info| my_tls_info.sni.as_deref()) + .unwrap_or(""); + let sans = my_tls_info + .map(|my_tls_info| &my_tls_info.sans) + .unwrap_or(&EMPTY_VEC); + let common_name = my_tls_info + .and_then(|my_tls_info| my_tls_info.common_name.as_deref()) + .unwrap_or(""); + + // Create response message + let mut message = String::new(); + message += &format!("Your SNI was: {sni}\n"); + message += &format!("Your SANs were: {sans:?}\n"); + message += &format!("Client Common Name (CN): {}\n", common_name); + let message = message.into_bytes(); + + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "text/plain") + .header(CONTENT_LENGTH, message.len()) + .body(message) + .unwrap() + } +} + +struct MyTlsCallbacks; + +#[async_trait] +impl TlsAccept for MyTlsCallbacks { + #[cfg(feature = "openssl")] + async fn handshake_complete_callback( + &self, + tls_ref: &TlsRef, + ) -> Option> { + // Here you can inspect the TLS connection and return an extension if needed. 
+ + // Extract SNI (Server Name Indication) + let sni = tls_ref + .servername(NameType::HOST_NAME) + .map(ToOwned::to_owned); + + // Extract SAN (Subject Alternative Names) from the client certificate + let sans = tls_ref + .peer_certificate() + .and_then(|cert| cert.subject_alt_names()) + .map_or(vec![], |sans| { + sans.into_iter() + .filter_map(|san| san_to_string(&san)) + .collect::>() + }); + + // Extract Common Name (CN) from the client certificate + let common_name = tls_ref.peer_certificate().and_then(|cert| { + let cn = cert.subject_name().entries_by_nid(Nid::COMMONNAME).next()?; + Some(cn.data().as_utf8().ok()?.to_string()) + }); + + let tls_info = MyTlsInfo { + sni, + sans, + common_name, + }; + Some(Arc::new(tls_info)) + } +} + +// Convert GeneralName of SAN to String representation +#[cfg(feature = "openssl")] +fn san_to_string(san: &GeneralName) -> Option { + if let Some(dnsname) = san.dnsname() { + return Some(dnsname.to_owned()); + } + if let Some(uri) = san.uri() { + return Some(uri.to_owned()); + } + if let Some(email) = san.email() { + return Some(email.to_owned()); + } + if let Some(ip) = san.ipaddress() { + return bytes_to_ip_addr(ip).map(|addr| addr.to_string()); + } + None +} + +// Convert byte slice to IpAddr +fn bytes_to_ip_addr(bytes: &[u8]) -> Option { + match bytes.len() { + 4 => { + let addr = Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]); + Some(IpAddr::V4(addr)) + } + 16 => { + let mut octets = [0u8; 16]; + octets.copy_from_slice(bytes); + let addr = Ipv6Addr::from(octets); + Some(IpAddr::V6(addr)) + } + _ => None, + } +} + +// This example demonstrates an HTTP server that requires client certificates. +// The server extracts the SNI (Server Name Indication) from the TLS handshake and +// SANs (Subject Alternative Names) from the client certificate, then returns them +// as part of the HTTP response. 
+// +// ## How to run +// +// cargo run -F openssl --example client_cert +// +// # In another terminal, run the following command to test the server: +// cd pingora-core +// curl -k -i \ +// --cert examples/keys/clients/cert-1.pem --key examples/keys/clients/key-1.pem \ +// --resolve myapp.example.com:6196:127.0.0.1 \ +// https://myapp.example.com:6196/ +// curl -k -i \ +// --cert examples/keys/clients/cert-2.pem --key examples/keys/clients/key-2.pem \ +// --resolve myapp.example.com:6196:127.0.0.1 \ +// https://myapp.example.com:6196/ +// curl -k -i \ +// --cert examples/keys/clients/invalid-cert.pem --key examples/keys/clients/invalid-key.pem \ +// --resolve myapp.example.com:6196:127.0.0.1 \ +// https://myapp.example.com:6196/ +#[cfg(feature = "openssl")] +fn main() -> Result<(), Box> { + env_logger::init(); + + // read command line arguments + let opt = Opt::parse(); + let mut my_server = Server::new(Some(opt))?; + my_server.bootstrap(); + + let mut my_app = Service::new("my app".to_owned(), MyApp); + + // Paths to server certificate, private key, and client CA certificate + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let server_cert_path = format!("{manifest_dir}/examples/keys/server/cert.pem"); + let server_key_path = format!("{manifest_dir}/examples/keys/server/key.pem"); + let client_ca_path = format!("{manifest_dir}/examples/keys/client-ca/cert.pem"); + + // Create TLS settings with callbacks + let callbacks = Box::new(MyTlsCallbacks); + let mut tls_settings = TlsSettings::with_callbacks(callbacks)?; + // Set server certificate and private key + tls_settings.set_certificate_chain_file(&server_cert_path)?; + tls_settings.set_private_key_file(server_key_path, SslFiletype::PEM)?; + // Require client certificate + tls_settings.set_verify(SslVerifyMode::PEER | SslVerifyMode::FAIL_IF_NO_PEER_CERT); + // Set CA for client certificate verification + tls_settings.set_ca_file(&client_ca_path)?; + // Optionally, set the list of acceptable client CAs sent to the 
client + tls_settings.set_client_ca_list(X509Name::load_client_ca_file(&client_ca_path)?); + + my_app.add_tls_with_settings("0.0.0.0:6196", None, tls_settings); + my_server.add_service(my_app); + + my_server.run_forever(); +} + +#[cfg(not(feature = "openssl"))] +fn main() { + eprintln!("This example requires the 'openssl' feature to be enabled."); +} diff --git a/pingora-core/examples/keys/client-ca/cert.pem b/pingora-core/examples/keys/client-ca/cert.pem new file mode 100644 index 00000000..2025cda3 --- /dev/null +++ b/pingora-core/examples/keys/client-ca/cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICTjCCAfWgAwIBAgIULuUoq/di4EKmLyN0YwAkd6MQjv4wCgYIKoZIzj0EAwIw +dTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNh +biBGcmFuY2lzY28xGDAWBgNVBAoMD0Nsb3VkZmxhcmUsIEluYzEfMB0GA1UEAwwW +RXhhbXBsZSBDbGllbnQgUm9vdCBDQTAeFw0yNTExMTkwNDU5MjRaFw0zNTExMTcw +NDU5MjRaMHUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYD +VQQHDA1TYW4gRnJhbmNpc2NvMRgwFgYDVQQKDA9DbG91ZGZsYXJlLCBJbmMxHzAd +BgNVBAMMFkV4YW1wbGUgQ2xpZW50IFJvb3QgQ0EwWTATBgcqhkjOPQIBBggqhkjO +PQMBBwNCAARxcxOAR4zUDPilKpMLiBzNs+HxdW6ZBlHVA7/0VyJtSPw03IdlbtFs +FhgcIa8uQ9nrppHlrzploTA7cg7YWUoso2MwYTAPBgNVHRMBAf8EBTADAQH/MA4G +A1UdDwEB/wQEAwIBBjAdBgNVHQ4EFgQUL6S83l9AGZmmwHh+64YlUtMQzZcwHwYD +VR0jBBgwFoAUL6S83l9AGZmmwHh+64YlUtMQzZcwCgYIKoZIzj0EAwIDRwAwRAIg +cohFQxG22J2YKw+DGAidU5u3mxtB/BALxIusqd+OfFUCIGmT2GHVxz1FwK2pJrM1 +FTWEcEbAw3r86iIVJBYP4qX6 +-----END CERTIFICATE----- diff --git a/pingora-core/examples/keys/client-ca/key.pem b/pingora-core/examples/keys/client-ca/key.pem new file mode 100644 index 00000000..a4c54f95 --- /dev/null +++ b/pingora-core/examples/keys/client-ca/key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIJOxEQowpYL5VLNf+qaCEBhic8e26UyR0ku65Sk6gjMIoAoGCCqGSM49 +AwEHoUQDQgAEcXMTgEeM1Az4pSqTC4gczbPh8XVumQZR1QO/9FcibUj8NNyHZW7R +bBYYHCGvLkPZ66aR5a86ZaEwO3IO2FlKLA== +-----END EC PRIVATE KEY----- diff --git a/pingora-core/examples/keys/clients/cert-1.pem 
b/pingora-core/examples/keys/clients/cert-1.pem new file mode 100644 index 00000000..7d6ce13f --- /dev/null +++ b/pingora-core/examples/keys/clients/cert-1.pem @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICjjCCAjWgAwIBAgIUYUSqEzxm/oebfxxQmZEesZL2WFAwCgYIKoZIzj0EAwIw +dTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNh +biBGcmFuY2lzY28xGDAWBgNVBAoMD0Nsb3VkZmxhcmUsIEluYzEfMB0GA1UEAwwW +RXhhbXBsZSBDbGllbnQgUm9vdCBDQTAeFw0yNTExMTkwNTEyMThaFw0zNTExMTcw +NTEyMThaMG8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYD +VQQHDA1TYW4gRnJhbmNpc2NvMRgwFgYDVQQKDA9DbG91ZGZsYXJlLCBJbmMxGTAX +BgNVBAMMEGV4YW1wbGUtY2xpZW50LTEwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC +AATDe6hBwpmE4Jt//sIWGWuBDYXHezVoFeoHsDzcWo6RwyHDfm7lvnACmqWAdRUV +1GA7yfkzc1CaTqnvU8GjFdfXo4GoMIGlMAwGA1UdEwEB/wQCMAAwDgYDVR0PAQH/ +BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMDAGA1UdEQQpMCeGJXNwaWZmZTov +L2V4YW1wbGUuY29tL2V4YW1wbGUtY2xpZW50LTEwHQYDVR0OBBYEFAjfTzgX+AVh +M+BIaU0qTgINZWOdMB8GA1UdIwQYMBaAFC+kvN5fQBmZpsB4fuuGJVLTEM2XMAoG +CCqGSM49BAMCA0cAMEQCIHyJDCvYKgxVthHcLjlEGW4Pj0Y7XnQUCJARa3jAUTd9 +AiB8tSXbo6J6Jhy6nasaxT1HAZwjgMVQwdo8O8UYOXXZpQ== +-----END CERTIFICATE----- diff --git a/pingora-core/examples/keys/clients/cert-2.pem b/pingora-core/examples/keys/clients/cert-2.pem new file mode 100644 index 00000000..b209b933 --- /dev/null +++ b/pingora-core/examples/keys/clients/cert-2.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC0zCCAnmgAwIBAgIUVQlGCD9Zryvkh9G8GZXFBa2L9kQwCgYIKoZIzj0EAwIw +dTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNh +biBGcmFuY2lzY28xGDAWBgNVBAoMD0Nsb3VkZmxhcmUsIEluYzEfMB0GA1UEAwwW +RXhhbXBsZSBDbGllbnQgUm9vdCBDQTAeFw0yNTExMTkwODA5MDlaFw0zNTExMTcw +ODA5MDlaMG8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYD +VQQHDA1TYW4gRnJhbmNpc2NvMRgwFgYDVQQKDA9DbG91ZGZsYXJlLCBJbmMxGTAX +BgNVBAMMEGV4YW1wbGUtY2xpZW50LTIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC +AAS2J10rq5Rt4TjhqEjHED0UPdceuzHUcw8doLC4StBIxJIrFk9Ag0g5ti9vN4fG +kK6J11GXk/pBmu3O3s48Gsfgo4HsMIHpMAwGA1UdEwEB/wQCMAAwDgYDVR0PAQH/ 
+BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMHQGA1UdEQRtMGuGJXNwaWZmZTov +L2V4YW1wbGUuY29tL2V4YW1wbGUtY2xpZW50LTKCFGNsaWVudC0yLmV4YW1wbGUu +Y29thwR/AAABhxAAAAAAAAAAAAAAAAAAAAABgRRjbGllbnQtMkBleGFtcGxlLmNv +bTAdBgNVHQ4EFgQUGHwnr7Ube1hqsodgcxJkfYuCKE8wHwYDVR0jBBgwFoAUL6S8 +3l9AGZmmwHh+64YlUtMQzZcwCgYIKoZIzj0EAwIDSAAwRQIgK4JL1OO2nB7MqvGW +y2nbH4yYMu2jUkYhw9HFLUG2B6MCIQC4iDWKXp7R977LvuaaQaNcMmbGysrmfo8V +wOmp1JGOtA== +-----END CERTIFICATE----- diff --git a/pingora-core/examples/keys/clients/invalid-cert.pem b/pingora-core/examples/keys/clients/invalid-cert.pem new file mode 100644 index 00000000..27ae7c93 --- /dev/null +++ b/pingora-core/examples/keys/clients/invalid-cert.pem @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICjzCCAjWgAwIBAgIUHYIVFYFooGVi2bNlk5R6GsbDKqUwCgYIKoZIzj0EAwIw +dTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNh +biBGcmFuY2lzY28xGDAWBgNVBAoMD0Nsb3VkZmxhcmUsIEluYzEfMB0GA1UEAwwW +RXhhbXBsZSBDbGllbnQgUm9vdCBDQTAeFw0yNTExMTkwODEzNDJaFw0zNTExMTcw +ODEzNDJaMG8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYD +VQQHDA1TYW4gRnJhbmNpc2NvMRgwFgYDVQQKDA9DbG91ZGZsYXJlLCBJbmMxGTAX +BgNVBAMMEGV4YW1wbGUtY2xpZW50LTMwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC +AATGKppMkUDsNvpzPPPiKmz53bbyIJPemIq5OdgJli8XZUFozxroJuFKhUuJOuFF +Jns2pzLHewIDzFXgErPqPxA/o4GoMIGlMAwGA1UdEwEB/wQCMAAwDgYDVR0PAQH/ +BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMDAGA1UdEQQpMCeGJXNwaWZmZTov +L2V4YW1wbGUuY29tL2V4YW1wbGUtY2xpZW50LTMwHQYDVR0OBBYEFDV/v0zsiC/t +aomzxKa0jJ4SlmSzMB8GA1UdIwQYMBaAFK04aCtyumAb4PEMnh9OXLW7EIJSMAoG +CCqGSM49BAMCA0gAMEUCIH/wxvS0ae8DF1QteE+2FDOd/G2WeBMjsS8A6VyebAru +AiEAl2vjq0KePvM2X0jTZ/+RMJO33HOpYr0+PZw6FAa+aaw= +-----END CERTIFICATE----- diff --git a/pingora-core/examples/keys/clients/invalid-key.pem b/pingora-core/examples/keys/clients/invalid-key.pem new file mode 100644 index 00000000..343688aa --- /dev/null +++ b/pingora-core/examples/keys/clients/invalid-key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- 
+MHcCAQEEIFyLneOGHgjTBS8I2GB8kF0LHgDS/eTJBSDNS4PAkJ0JoAoGCCqGSM49 +AwEHoUQDQgAExiqaTJFA7Db6czzz4ips+d228iCT3piKuTnYCZYvF2VBaM8a6Cbh +SoVLiTrhRSZ7Nqcyx3sCA8xV4BKz6j8QPw== +-----END EC PRIVATE KEY----- diff --git a/pingora-core/examples/keys/clients/key-1.pem b/pingora-core/examples/keys/clients/key-1.pem new file mode 100644 index 00000000..e5a27feb --- /dev/null +++ b/pingora-core/examples/keys/clients/key-1.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIFNioASifzPy0Fcp+qmMoMUhFOJGLki20ygISqZb+HY1oAoGCCqGSM49 +AwEHoUQDQgAEw3uoQcKZhOCbf/7CFhlrgQ2Fx3s1aBXqB7A83FqOkcMhw35u5b5w +ApqlgHUVFdRgO8n5M3NQmk6p71PBoxXX1w== +-----END EC PRIVATE KEY----- diff --git a/pingora-core/examples/keys/clients/key-2.pem b/pingora-core/examples/keys/clients/key-2.pem new file mode 100644 index 00000000..8d4063c7 --- /dev/null +++ b/pingora-core/examples/keys/clients/key-2.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEICd8DwjvpvE6nIKKKH2smrnLBM5zQyIkAKwBCiiRZGGsoAoGCCqGSM49 +AwEHoUQDQgAEtiddK6uUbeE44ahIxxA9FD3XHrsx1HMPHaCwuErQSMSSKxZPQINI +ObYvbzeHxpCuiddRl5P6QZrtzt7OPBrH4A== +-----END EC PRIVATE KEY----- diff --git a/pingora-core/examples/keys/server/cert.pem b/pingora-core/examples/keys/server/cert.pem new file mode 100644 index 00000000..4e927ce4 --- /dev/null +++ b/pingora-core/examples/keys/server/cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICVzCCAf6gAwIBAgIUYGbx/r4kY40a+zNq7IW/1lsvzk0wCgYIKoZIzj0EAwIw +bDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNh +biBGcmFuY2lzY28xGDAWBgNVBAoMD0Nsb3VkZmxhcmUsIEluYzEWMBQGA1UEAwwN +b3BlbnJ1c3R5Lm9yZzAeFw0yNTExMTkwNDUxMzdaFw0zNTExMTcwNDUxMzdaMGwx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4g +RnJhbmNpc2NvMRgwFgYDVQQKDA9DbG91ZGZsYXJlLCBJbmMxFjAUBgNVBAMMDW9w +ZW5ydXN0eS5vcmcwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAT9EuNEw3e3syHW +SNnyJw7QVtOzDlILlt6F+jXT8UMBoMn4OnwC7AFlV8XzR9UpYSf1yq7Raps7c8TU +W9YF6ee4o34wfDAdBgNVHQ4EFgQU6B2YXLmWaboIZsf9YOCePRQXrO4wHwYDVR0j 
+BBgwFoAU6B2YXLmWaboIZsf9YOCePRQXrO4wDwYDVR0TAQH/BAUwAwEB/zApBgNV +HREEIjAggg8qLm9wZW5ydXN0eS5vcmeCDW9wZW5ydXN0eS5vcmcwCgYIKoZIzj0E +AwIDRwAwRAIgcSThJ5CWjuyWKfHbR+RuJ/9DtH1ag/47OolMQAvOczsCIDKVgPO/ +A69bTOk4sq0y92YBBbe3hF82KrsgTR3nlkKF +-----END CERTIFICATE----- diff --git a/pingora-core/examples/keys/server/key.pem b/pingora-core/examples/keys/server/key.pem new file mode 100644 index 00000000..5781629a --- /dev/null +++ b/pingora-core/examples/keys/server/key.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgTAnVhDuKvV5epzX4 +uuC8kEZL2vUPI49gUmS5kM+j5VWhRANCAAT9EuNEw3e3syHWSNnyJw7QVtOzDlIL +lt6F+jXT8UMBoMn4OnwC7AFlV8XzR9UpYSf1yq7Raps7c8TUW9YF6ee4 +-----END PRIVATE KEY----- diff --git a/pingora-core/examples/service_dependencies.rs b/pingora-core/examples/service_dependencies.rs new file mode 100644 index 00000000..d5f5e392 --- /dev/null +++ b/pingora-core/examples/service_dependencies.rs @@ -0,0 +1,234 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Example demonstrating service dependency management. +//! +//! This example shows how services can declare dependencies on other services using +//! a fluent API with [`ServiceHandle`] references, ensuring they start in the correct +//! order and wait for dependencies to be ready. +//! +//! # Running the example +//! +//! ```bash +//! cargo run --example service_dependencies --package pingora-core +//! ``` +//! +//! 
Expected output: +//! - DatabaseService starts and initializes (takes 2 seconds) +//! - CacheService starts and initializes (takes 1 second) +//! - ApiService waits for both dependencies, then starts +//! +//! # Key Features Demonstrated +//! +//! - Fluent API for declaring dependencies via [`ServiceHandle::add_dependency()`] +//! - Type-safe dependency declaration (no strings) +//! - Multiple ways to implement services based on readiness needs: +//! - **DatabaseService**: Custom readiness timing (uses `ServiceWithDependencies`) +//! - **CacheService**: Ready immediately (uses `Service`) +//! - **ApiService**: Ready immediately (uses `Service`) +//! - Automatic dependency ordering and validation +//! - Prevention of typos in service names (compile-time safety) + +use async_trait::async_trait; +use log::info; +use pingora_core::server::configuration::Opt; +#[cfg(unix)] +use pingora_core::server::ListenFds; +use pingora_core::server::{Server, ShutdownWatch}; +use pingora_core::services::{Service, ServiceWithDependents}; +// DatabaseService needs to control readiness timing +use pingora_core::services::ServiceReadyNotifier; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::time::{sleep, Duration}; + +/// A custom service that delays signaling ready until initialization is complete +pub struct DatabaseService { + connection_string: Arc>>, +} + +impl DatabaseService { + fn new() -> Self { + Self { + connection_string: Arc::new(Mutex::new(None)), + } + } + + fn get_connection_string(&self) -> Arc>> { + self.connection_string.clone() + } +} + +#[async_trait] +impl ServiceWithDependents for DatabaseService { + async fn start_service( + &mut self, + #[cfg(unix)] _fds: Option, + mut shutdown: ShutdownWatch, + _listeners_per_fd: usize, + ready_notifier: ServiceReadyNotifier, + ) { + info!("DatabaseService: Starting initialization..."); + + // Simulate database connection setup + sleep(Duration::from_secs(2)).await; + + // Store the connection string + { + let mut conn 
= self.connection_string.lock().await; + *conn = Some("postgresql://localhost:5432/mydb".to_string()); + } + + info!("DatabaseService: Initialization complete, signaling ready"); + + // Signal that the service is ready + ready_notifier.notify_ready(); + + // Keep running until shutdown + shutdown.changed().await.ok(); + info!("DatabaseService: Shutting down"); + } + + fn name(&self) -> &str { + "database" + } + + fn threads(&self) -> Option { + Some(1) + } +} + +/// A cache service that uses the simplified API +/// Signals ready immediately (using default implementation) +pub struct CacheService; + +#[async_trait] +impl Service for CacheService { + // Uses default start_service implementation which signals ready immediately + + async fn start_service( + &mut self, + #[cfg(unix)] _fds: Option, + mut shutdown: ShutdownWatch, + _listeners_per_fd: usize, + ) { + info!("CacheService: Starting (ready immediately)..."); + + // Simulate cache warmup + sleep(Duration::from_secs(1)).await; + info!("CacheService: Warmup complete"); + + // Keep running until shutdown + shutdown.changed().await.ok(); + info!("CacheService: Shutting down"); + } + + fn name(&self) -> &str { + "cache" + } + + fn threads(&self) -> Option { + Some(1) + } +} + +/// An API service that depends on both database and cache +/// Uses the simplest API - signals ready immediately and just implements [Service] +pub struct ApiService { + db_connection: Arc>>, +} + +impl ApiService { + fn new(db_connection: Arc>>) -> Self { + Self { db_connection } + } +} + +#[async_trait] +impl Service for ApiService { + // Uses default start_service - signals ready immediately + + async fn start_service( + &mut self, + #[cfg(unix)] _fds: Option, + mut shutdown: ShutdownWatch, + _listeners_per_fd: usize, + ) { + info!("ApiService: Starting (dependencies should be ready)..."); + + // Verify database connection is available + { + let conn = self.db_connection.lock().await; + if let Some(conn_str) = &*conn { + info!("ApiService: 
Using database connection: {}", conn_str); + } else { + panic!("ApiService: Database connection not available!"); + } + } + + info!("ApiService: Ready to serve requests"); + + // Keep running until shutdown + shutdown.changed().await.ok(); + info!("ApiService: Shutting down"); + } + + fn name(&self) -> &str { + "api" + } + + fn threads(&self) -> Option { + Some(1) + } +} + +fn main() { + env_logger::Builder::from_default_env() + .filter_level(log::LevelFilter::Info) + .init(); + + info!("Starting server with service dependencies..."); + + let opt = Opt::parse_args(); + let mut server = Server::new(Some(opt)).unwrap(); + server.bootstrap(); + + // Create the database service + let db_service = DatabaseService::new(); + let db_connection = db_service.get_connection_string(); + + // Create services + let cache_service = CacheService; + let api_service = ApiService::new(db_connection); + + // Add services and get their handles + let db_handle = server.add_service(db_service); + let cache_handle = server.add_service(cache_service); + let api_handle = server.add_service(api_service); + + // Declare dependencies using the fluent API + // The API service will not start until both dependencies signal ready + api_handle.add_dependency(db_handle); + api_handle.add_dependency(&cache_handle); + + info!("Services configured. Starting server..."); + info!("Expected startup order:"); + info!(" 1. database (will initialize for 2 seconds)"); + info!(" 2. cache (will initialize for 1 second)"); + info!(" 3. api (will wait for both, then start)"); + info!(""); + info!("Press Ctrl+C to shut down"); + + server.run_forever(); +} diff --git a/pingora-core/src/apps/http_app.rs b/pingora-core/src/apps/http_app.rs index d2c59513..f511012c 100644 --- a/pingora-core/src/apps/http_app.rs +++ b/pingora-core/src/apps/http_app.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 461084e4..d751fbcc 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,24 +61,57 @@ pub trait ServerApp { #[derive(Default)] /// HTTP Server options that control how the server handles some transport types. pub struct HttpServerOptions { - /// Use HTTP/2 for plaintext. + /// Allow HTTP/2 for plaintext. pub h2c: bool, + + /// Allow proxying CONNECT requests when handling HTTP traffic. + /// + /// When disabled, CONNECT requests are rejected with 405 by proxy services. + pub allow_connect_method_proxying: bool, + + #[doc(hidden)] + pub force_custom: bool, + + /// Maximum number of requests that this connection will handle. This is + /// equivalent to [Nginx's keepalive requests](https://nginx.org/en/docs/http/ngx_http_core_module.html#keepalive_requests) + /// which says: + /// + /// > Closing connections periodically is necessary to free per-connection + /// > memory allocations. Therefore, using too high maximum number of + /// > requests could result in excessive memory usage and not recommended. + /// + /// Unlike nginx, the default behavior here is _no limit_.
+ pub keepalive_request_limit: Option, } #[derive(Debug, Clone)] pub struct HttpPersistentSettings { keepalive_timeout: Option, + keepalive_reuses_remaining: Option, } impl HttpPersistentSettings { pub fn for_session(session: &ServerSession) -> Self { HttpPersistentSettings { keepalive_timeout: session.get_keepalive(), + keepalive_reuses_remaining: session.get_keepalive_reuses_remaining(), } } - pub fn apply_to_session(&self, session: &mut ServerSession) { - session.set_keepalive(self.keepalive_timeout); + pub fn apply_to_session(self, session: &mut ServerSession) { + let Self { + keepalive_timeout, + mut keepalive_reuses_remaining, + } = self; + + // Reduce the number of times the connection for this session can be + // reused by one. A session with reuse count of zero won't be reused + if let Some(reuses) = keepalive_reuses_remaining.as_mut() { + *reuses = reuses.saturating_sub(1); + } + + session.set_keepalive(keepalive_timeout); + session.set_keepalive_reuses_remaining(keepalive_reuses_remaining); } } @@ -133,6 +166,15 @@ pub trait HttpServerApp { } async fn http_cleanup(&self) {} + + #[doc(hidden)] + async fn process_custom_session( + self: Arc, + _stream: Stream, + _shutdown: &ShutdownWatch, + ) -> Option { + None + } } #[async_trait] @@ -146,9 +188,13 @@ where shutdown: &ShutdownWatch, ) -> Option { let mut h2c = self.server_options().as_ref().map_or(false, |o| o.h2c); + let custom = self + .server_options() + .as_ref() + .map_or(false, |o| o.force_custom); // try to read h2 preface - if h2c { + if h2c && !custom { let mut buf = [0u8; H2_PREFACE.len()]; let peeked = stream .try_peek(&mut buf) @@ -215,6 +261,8 @@ where .await; }); } + } else if custom || matches!(stream.selected_alpn_proto(), Some(ALPN::Custom(_))) { + return self.clone().process_custom_session(stream, shutdown).await; } else { // No ALPN or ALPN::H1 and h2c was not configured, fallback to HTTP/1.1 let mut session = ServerSession::new_http1(stream); @@ -225,6 +273,10 @@ where // default 60s 
session.set_keepalive(Some(60)); } + session.set_keepalive_reuses_remaining( + self.server_options() + .and_then(|opts| opts.keepalive_request_limit), + ); let mut result = self.process_new_http(session, shutdown).await; while let Some((stream, persistent_settings)) = result.map(|r| r.consume()) { @@ -232,10 +284,6 @@ where if let Some(persistent_settings) = persistent_settings { persistent_settings.apply_to_session(&mut session); } - if *shutdown.borrow() { - // stop downstream from reusing if this service is shutting down soon - session.set_keepalive(None); - } result = self.process_new_http(session, shutdown).await; } diff --git a/pingora-core/src/apps/prometheus_http_app.rs b/pingora-core/src/apps/prometheus_http_app.rs index 963d5a9e..ed8a217a 100644 --- a/pingora-core/src/apps/prometheus_http_app.rs +++ b/pingora-core/src/apps/prometheus_http_app.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/connectors/http/custom/mod.rs b/pingora-core/src/connectors/http/custom/mod.rs new file mode 100644 index 00000000..e1e8a11d --- /dev/null +++ b/pingora-core/src/connectors/http/custom/mod.rs @@ -0,0 +1,80 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use async_trait::async_trait; +use std::time::Duration; + +use pingora_error::Result; + +use crate::{ + protocols::{http::custom::client::Session, Stream}, + upstreams::peer::Peer, +}; + +// Either returns a Custom Session or the Stream for creating a new H1 session as a fallback. +pub enum Connection { + Session(S), + Stream(Stream), +} +#[doc(hidden)] +#[async_trait] +pub trait Connector: Send + Sync + Unpin + 'static { + type Session: Session; + + async fn get_http_session( + &self, + peer: &P, + ) -> Result<(Connection, bool)>; + + async fn reused_http_session( + &self, + peer: &P, + ) -> Option; + + async fn release_http_session( + &self, + mut session: Self::Session, + peer: &P, + idle_timeout: Option, + ); +} + +#[doc(hidden)] +#[async_trait] +impl Connector for () { + type Session = (); + + async fn get_http_session( + &self, + _peer: &P, + ) -> Result<(Connection, bool)> { + unreachable!("connector: get_http_session") + } + + async fn reused_http_session( + &self, + _peer: &P, + ) -> Option { + unreachable!("connector: reused_http_session") + } + + async fn release_http_session( + &self, + _session: Self::Session, + _peer: &P, + _idle_timeout: Option, + ) { + unreachable!("connector: release_http_session") + } +} diff --git a/pingora-core/src/connectors/http/mod.rs b/pingora-core/src/connectors/http/mod.rs index 45b14f44..2545cf7c 100644 --- a/pingora-core/src/connectors/http/mod.rs +++ b/pingora-core/src/connectors/http/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,25 +14,47 @@ //! 
Connecting to HTTP servers +use crate::connectors::http::custom::Connection; use crate::connectors::ConnectorOptions; +use crate::listeners::ALPN; use crate::protocols::http::client::HttpSession; +use crate::protocols::http::v1::client::HttpSession as Http1Session; use crate::upstreams::peer::Peer; use pingora_error::Result; use std::time::Duration; +pub mod custom; pub mod v1; pub mod v2; -pub struct Connector { +pub struct Connector +where + C: custom::Connector, +{ h1: v1::Connector, h2: v2::Connector, + custom: C, } -impl Connector { +impl Connector<()> { pub fn new(options: Option) -> Self { Connector { h1: v1::Connector::new(options.clone()), - h2: v2::Connector::new(options), + h2: v2::Connector::new(options.clone()), + custom: Default::default(), + } + } +} + +impl Connector +where + C: custom::Connector, +{ + pub fn new_custom(options: Option, custom: C) -> Self { + Connector { + h1: v1::Connector::new(options.clone()), + h2: v2::Connector::new(options.clone()), + custom, } } @@ -42,14 +64,46 @@ impl Connector { pub async fn get_http_session( &self, peer: &P, - ) -> Result<(HttpSession, bool)> { + ) -> Result<(HttpSession, bool)> { + let peer_opts = peer.get_peer_options(); + + // Switch to custom protocol as early as possible + if peer_opts.is_some_and(|o| matches!(o.alpn, ALPN::Custom(_))) { + // We create the Connector before TLS, so we need to make sure that the server also supports the same custom protocol. 
+ // We will first check for sessions that we can reuse, if not we will create a new one based on the negotiated protocol + + // Step 1: Look for reused Custom Session + if let Some(session) = self.custom.reused_http_session(peer).await { + return Ok((HttpSession::Custom(session), true)); + } + // Step 2: Check reuse pool for reused H1 session + if let Some(h1) = self.h1.reused_http_session(peer).await { + return Ok((HttpSession::H1(h1), true)); + } + // Step 3: Try and create a new Custom session + let (connection, reused) = self.custom.get_http_session(peer).await?; + // We create the Connector before TLS, so we need to make sure that the server also supports the same custom protocol + match connection { + Connection::Session(s) => { + return Ok((HttpSession::Custom(s), reused)); + } + // Negotiated ALPN is not custom, create a new H1 session + Connection::Stream(s) => { + return Ok(( + HttpSession::H1(Http1Session::new_with_options(s, peer)), + false, + )); + } + } + } + // NOTE: maybe TODO: we do not yet enforce that only TLS traffic can use h2, which is the // de facto requirement for h2, because non TLS traffic lack the negotiation mechanism. 
// We assume no peer option == no ALPN == h1 only let h1_only = peer .get_peer_options() - .map_or(true, |o| o.alpn.get_max_http_version() == 1); + .is_none_or(|o| o.alpn.get_max_http_version() == 1); if h1_only { let (h1, reused) = self.h1.get_http_session(peer).await?; Ok((HttpSession::H1(h1), reused)) @@ -78,13 +132,18 @@ impl Connector { pub async fn release_http_session( &self, - session: HttpSession, + session: HttpSession, peer: &P, idle_timeout: Option, ) { match session { HttpSession::H1(h1) => self.h1.release_http_session(h1, peer, idle_timeout).await, HttpSession::H2(h2) => self.h2.release_http_session(h2, peer, idle_timeout), + HttpSession::Custom(c) => { + self.custom + .release_http_session(c, peer, idle_timeout) + .await; + } } } @@ -98,9 +157,21 @@ impl Connector { #[cfg(feature = "any_tls")] mod tests { use super::*; + use crate::connectors::TransportConnector; + use crate::listeners::tls::TlsSettings; + use crate::listeners::{Listeners, TransportStack, ALPN}; use crate::protocols::http::v1::client::HttpSession as Http1Session; + use crate::protocols::tls::CustomALPN; use crate::upstreams::peer::HttpPeer; + use crate::upstreams::peer::PeerOptions; + use async_trait::async_trait; use pingora_http::RequestHeader; + use std::sync::Arc; + use std::sync::Mutex; + use tokio::io::AsyncWriteExt; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tokio::time::sleep; async fn get_http(http: &mut Http1Session, expected_status: u16) { let mut req = Box::new(RequestHeader::build("GET", b"/", None).unwrap()); @@ -123,6 +194,7 @@ mod tests { match &h2 { HttpSession::H1(_) => panic!("expect h2"), HttpSession::H2(h2_stream) => assert!(!h2_stream.ping_timedout()), + HttpSession::Custom(_) => panic!("expect h2"), } connector.release_http_session(h2, &peer, None).await; @@ -133,6 +205,7 @@ mod tests { match &h2 { HttpSession::H1(_) => panic!("expect h2"), HttpSession::H2(h2_stream) => assert!(!h2_stream.ping_timedout()), + HttpSession::Custom(_) => 
panic!("expect h2"), } } @@ -148,6 +221,7 @@ mod tests { get_http(http, 200).await; } HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } connector.release_http_session(h1, &peer, None).await; @@ -157,6 +231,7 @@ mod tests { match &mut h1 { HttpSession::H1(_) => {} HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } } @@ -178,6 +253,7 @@ mod tests { get_http(http, 200).await; } HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } connector.release_http_session(h1, &peer, None).await; @@ -190,6 +266,7 @@ mod tests { match &mut h1 { HttpSession::H1(_) => {} HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } } @@ -207,6 +284,7 @@ mod tests { get_http(http, 200).await; } HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } connector.release_http_session(h1, &peer, None).await; @@ -217,6 +295,314 @@ mod tests { match &mut h1 { HttpSession::H1(_) => {} HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), + } + } + // Track the flow of calls when using a custom protocol. 
For this we need to create a Mock Connector + struct MockConnector { + transport: TransportConnector, + reusable: Arc>, // Mock for tracking reusable sessions + } + + #[async_trait] + impl custom::Connector for MockConnector { + type Session = (); + + async fn get_http_session( + &self, + peer: &P, + ) -> Result<(Connection, bool)> { + let (stream, _) = self.transport.get_stream(peer).await?; + + match stream.selected_alpn_proto() { + Some(ALPN::Custom(_)) => Ok((custom::Connection::Session(()), false)), + _ => Ok(((custom::Connection::Stream(stream)), false)), + } + } + + async fn reused_http_session( + &self, + _peer: &P, + ) -> Option { + let mut flag = self.reusable.lock().unwrap(); + if *flag { + *flag = false; + Some(()) + } else { + None + } + } + + async fn release_http_session( + &self, + _session: Self::Session, + _peer: &P, + _idle_timeout: Option, + ) { + let mut flag = self.reusable.lock().unwrap(); + *flag = true; } } + + // Finds an available TCP port on localhost for test server setup. + async fn get_available_port() -> u16 { + TcpListener::bind("127.0.0.1:0") + .await + .unwrap() + .local_addr() + .unwrap() + .port() + } + // Creates a test connector for integration/unit tests. + // For rustls, only ConnectorOptions are used here; the actual dangerous verifier is patched in the TLS connector. + fn create_test_connector() -> Connector { + #[cfg(feature = "rustls")] + let custom_transport = { + let options = ConnectorOptions::new(1); + TransportConnector::new(Some(options)) + }; + #[cfg(not(feature = "rustls"))] + let custom_transport = TransportConnector::new(None); + Connector { + h1: v1::Connector::new(None), + h2: v2::Connector::new(None), + custom: MockConnector { + transport: custom_transport, + reusable: Arc::new(Mutex::new(false)), + }, + } + } + + // Creates a test peer that uses a custom ALPN protocol and disables cert/hostname verification for tests. 
+ fn create_peer_with_custom_proto(port: u16, proto: &[u8]) -> HttpPeer { + let mut peer = HttpPeer::new(("127.0.0.1", port), true, "localhost".into()); + let mut options = PeerOptions::new(); + options.alpn = ALPN::Custom(CustomALPN::new(proto.to_vec())); + // Disable cert verification for this test (self-signed or invalid certs are OK) + options.verify_cert = false; + options.verify_hostname = false; + peer.options = options; + peer + } + async fn build_custom_tls_listener(port: u16, custom_alpn: CustomALPN) -> TransportStack { + let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); + let key_path = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR")); + let addr = format!("127.0.0.1:{}", port); + let mut listeners = Listeners::new(); + let mut tls_settings = TlsSettings::intermediate(&cert_path, &key_path).unwrap(); + + tls_settings.set_alpn(ALPN::Custom(custom_alpn)); + listeners.add_tls_with_settings(&addr, None, tls_settings); + listeners + .build( + #[cfg(unix)] + None, + ) + .await + .unwrap() + .pop() + .unwrap() + } + + // Spawn a simple TLS Server + fn spawn_test_tls_server(listener: TransportStack) -> JoinHandle<()> { + tokio::spawn(async move { + loop { + let stream = match listener.accept().await { + Ok(stream) => stream, + Err(_) => break, // Exit if listener is closed + }; + let mut stream = stream.handshake().await.unwrap(); + + let _ = stream.write_all(b"CUSTOM").await; // Ignore write errors + } + }) + } + + // Both server and client are using the same custom protocol + #[tokio::test] + async fn test_custom_client_custom_upstream() { + let port = get_available_port().await; + let custom_protocol = b"custom".to_vec(); + + let listener = + build_custom_tls_listener(port, CustomALPN::new(custom_protocol.clone())).await; + let server_handle = spawn_test_tls_server(listener); + // Wait for server to start up + sleep(Duration::from_millis(100)).await; + + let connector = create_test_connector(); + let peer = 
create_peer_with_custom_proto(port, &custom_protocol); + + // Check that the agreed ALPN is custom and matches the expected value + if let Ok((stream, reused)) = connector.custom.transport.get_stream(&peer).await { + assert!(!reused); + match stream.selected_alpn_proto() { + Some(ALPN::Custom(protocol)) => { + assert_eq!( + protocol.protocol(), + custom_protocol.as_slice(), + "Negotiated custom ALPN does not match expected value" + ); + } + other => panic!("Expected custom ALPN, got {:?}", other), + } + } else { + panic!("Should be able to create a stream"); + } + + let (custom, reused) = connector.get_http_session(&peer).await.unwrap(); + assert!(!reused); + match custom { + HttpSession::H1(_) => panic!("expect custom"), + HttpSession::H2(_) => panic!("expect custom"), + HttpSession::Custom(_) => {} + } + connector.release_http_session(custom, &peer, None).await; + + // Assert it returns a reused custom session this time + let (custom, reused) = connector.get_http_session(&peer).await.unwrap(); + assert!(reused); + match custom { + HttpSession::H1(_) => panic!("expect custom"), + HttpSession::H2(_) => panic!("expect custom"), + HttpSession::Custom(_) => {} + } + + // Kill the server task + server_handle.abort(); + sleep(Duration::from_millis(100)).await; + } + + // Both client and server are using custom protocols, but different ones - we should create H1 sessions as fallback. + // For RusTLS if there is no agreed protocol, the handshake directly fails, so this won't work + // TODO: If no ALPN is matched, rustls should return None instead of failing the handshake. 
+ #[cfg(not(feature = "rustls"))] + #[tokio::test] + async fn test_incompatible_custom_client_custom_upstream() { + let port = get_available_port().await; + let custom_protocol = b"custom".to_vec(); + + let listener = + build_custom_tls_listener(port, CustomALPN::new(b"different_custom".to_vec())).await; + let server_handle = spawn_test_tls_server(listener); + // Wait for server to start up + sleep(Duration::from_millis(100)).await; + + let connector = create_test_connector(); + let peer = create_peer_with_custom_proto(port, &custom_protocol); + + // Verify that there is no agreed ALPN + if let Ok((stream, reused)) = connector.custom.transport.get_stream(&peer).await { + assert!(!reused); + assert!(stream.selected_alpn_proto().is_none()); + } else { + panic!("Should be able to create a stream"); + } + + let (h1, reused) = connector.get_http_session(&peer).await.unwrap(); + assert!(!reused); + match h1 { + HttpSession::H1(_) => {} + HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), + } + // Not testing session reuse logic here as we haven't implemented it. Next test will test this. + + // Kill the server task + server_handle.abort(); + sleep(Duration::from_millis(100)).await; + } + + // Client thinks server is custom but server is not Custom. 
Should fallback to H1 + #[tokio::test] + async fn test_custom_client_non_custom_upstream() { + let custom_proto = b"custom".to_vec(); + let connector = create_test_connector(); + // Upstream supports H1 and H2 + let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into()); + // Client sets upstream ALPN as custom protocol + peer.options.alpn = ALPN::Custom(CustomALPN::new(custom_proto)); + + // Verify that there is no agreed ALPN + if let Ok((stream, reused)) = connector.custom.transport.get_stream(&peer).await { + assert!(!reused); + assert!(stream.selected_alpn_proto().is_none()); + } else { + panic!("Should be able to create a stream"); + } + + let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap(); + // Assert it returns a new H1 session + assert!(!reused); + match &mut h1 { + HttpSession::H1(http) => { + get_http(http, 200).await; + } + HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), + } + connector.release_http_session(h1, &peer, None).await; + + // Assert it returns a reused h1 session this time + let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap(); + assert!(reused); + match &mut h1 { + HttpSession::H1(_) => {} + HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), + } + } +} + +// Used for disabling certificate/hostname verification in rustls for tests and custom ALPN/self-signed scenarios. 
+#[cfg(all(test, feature = "rustls"))] +pub mod rustls_no_verify { + use rustls::client::danger::{ServerCertVerified, ServerCertVerifier}; + use rustls::pki_types::{CertificateDer, ServerName}; + use rustls::Error as TLSError; + use std::sync::Arc; + #[derive(Debug)] + pub struct NoCertificateVerification; + + impl ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer, + _intermediates: &[CertificateDer], + _server_name: &ServerName, + _scts: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![rustls::SignatureScheme::ECDSA_NISTP256_SHA256] + } + } + + pub fn apply_no_verify(config: &mut rustls::ClientConfig) { + config + .dangerous() + .set_certificate_verifier(Arc::new(NoCertificateVerification)); + } } diff --git a/pingora-core/src/connectors/http/v1.rs b/pingora-core/src/connectors/http/v1.rs index 36026a40..62ecfcb6 100644 --- a/pingora-core/src/connectors/http/v1.rs +++ b/pingora-core/src/connectors/http/v1.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -35,7 +35,7 @@ impl Connector { peer: &P, ) -> Result<(HttpSession, bool)> { let (stream, reused) = self.transport.get_stream(peer).await?; - let http = HttpSession::new(stream); + let http = HttpSession::new_with_options(stream, peer); Ok((http, reused)) } @@ -43,10 +43,9 @@ impl Connector { &self, peer: &P, ) -> Option { - self.transport - .reused_stream(peer) - .await - .map(HttpSession::new) + let stream = self.transport.reused_stream(peer).await?; + let http = HttpSession::new_with_options(stream, peer); + Some(http) } pub async fn release_http_session( @@ -68,7 +67,9 @@ mod tests { use super::*; use crate::protocols::l4::socket::SocketAddr; use crate::upstreams::peer::HttpPeer; + use crate::upstreams::peer::Peer; use pingora_http::RequestHeader; + use std::fmt::{Display, Formatter, Result as FmtResult}; async fn get_http(http: &mut HttpSession, expected_status: u16) { let mut req = Box::new(RequestHeader::build("GET", b"/", None).unwrap()); @@ -102,6 +103,63 @@ mod tests { assert!(reused); } + #[cfg(unix)] + #[tokio::test] + async fn test_reuse_rejects_fd_mismatch() { + use std::os::unix::prelude::AsRawFd; + + #[derive(Clone)] + struct MismatchPeer { + reuse_hash: u64, + address: SocketAddr, + } + + impl Display for MismatchPeer { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{:?}", self.address) + } + } + + impl Peer for MismatchPeer { + fn address(&self) -> &SocketAddr { + &self.address + } + + fn tls(&self) -> bool { + false + } + + fn sni(&self) -> &str { + "" + } + + fn reuse_hash(&self) -> u64 { + self.reuse_hash + } + + fn matches_fd(&self, _fd: V) -> bool { + false + } + } + + let connector = Connector::new(None); + let peer = HttpPeer::new(("1.1.1.1", 80), false, "".into()); + let (mut http, reused) = connector.get_http_session(&peer).await.unwrap(); + assert!(!reused); + get_http(&mut http, 301).await; + connector.release_http_session(http, &peer, None).await; + + let mismatch_peer = MismatchPeer { + reuse_hash: 
peer.reuse_hash(), + address: peer.address().clone(), + }; + + assert!(connector + .reused_http_session(&mismatch_peer) + .await + .is_none()); + } + #[tokio::test] #[cfg(feature = "any_tls")] async fn test_connect_tls() { diff --git a/pingora-core/src/connectors/http/v2.rs b/pingora-core/src/connectors/http/v2.rs index 92cc31d5..c18914c0 100644 --- a/pingora-core/src/connectors/http/v2.rs +++ b/pingora-core/src/connectors/http/v2.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ use super::HttpSession; use crate::connectors::{ConnectorOptions, TransportConnector}; +use crate::protocols::http::custom::client::Session; use crate::protocols::http::v1::client::HttpSession as Http1Session; use crate::protocols::http::v2::client::{drive_connection, Http2Session}; use crate::protocols::{Digest, Stream, UniqueIDType}; @@ -62,7 +63,7 @@ pub(crate) struct ConnectionRefInner { } #[derive(Clone)] -pub(crate) struct ConnectionRef(Arc); +pub struct ConnectionRef(Arc); impl ConnectionRef { pub fn new( @@ -162,7 +163,7 @@ impl ConnectionRef { } } -struct InUsePool { +pub struct InUsePool { // TODO: use pingora hashmap to shard the lock contention pools: RwLock>>, } @@ -174,7 +175,7 @@ impl InUsePool { } } - fn insert(&self, reuse_hash: u64, conn: ConnectionRef) { + pub fn insert(&self, reuse_hash: u64, conn: ConnectionRef) { { let pools = self.pools.read(); if let Some(pool) = pools.get(&reuse_hash) { @@ -192,14 +193,14 @@ impl InUsePool { // retrieve a h2 conn ref to create a new stream // the caller should return the conn ref to this pool if there are still // capacity left for more streams - fn get(&self, reuse_hash: u64) -> Option { + pub fn get(&self, reuse_hash: u64) -> Option { let pools = self.pools.read(); pools.get(&reuse_hash)?.get_any().map(|v| v.1) } // release a h2_stream, 
this functional will cause an ConnectionRef to be returned (if exist) // the caller should update the ref and then decide where to put it (in use pool or idle) - fn release(&self, reuse_hash: u64, id: UniqueIDType) -> Option { + pub fn release(&self, reuse_hash: u64, id: UniqueIDType) -> Option { let pools = self.pools.read(); if let Some(pool) = pools.get(&reuse_hash) { pool.remove(id) @@ -235,13 +236,25 @@ impl Connector { } } + pub fn transport(&self) -> &TransportConnector { + &self.transport + } + + pub fn idle_pool(&self) -> &Arc> { + &self.idle_pool + } + + pub fn in_use_pool(&self) -> &InUsePool { + &self.in_use_pool + } + /// Create a new Http2 connection to the given server /// /// Either an Http2 or Http1 session can be returned depending on the server's preference. - pub async fn new_http_session( + pub async fn new_http_session( &self, peer: &P, - ) -> Result { + ) -> Result> { let stream = self.transport.new_stream(peer).await?; // check alpn @@ -249,7 +262,9 @@ impl Connector { Some(ALPN::H2) => { /* continue */ } Some(_) => { // H2 not supported - return Ok(HttpSession::H1(Http1Session::new(stream))); + return Ok(HttpSession::H1(Http1Session::new_with_options( + stream, peer, + ))); } None => { // if tls but no ALPN, default to h1 @@ -257,9 +272,11 @@ impl Connector { if peer.tls() || peer .get_peer_options() - .map_or(true, |o| o.alpn.get_min_http_version() == 1) + .is_none_or(|o| o.alpn.get_min_http_version() == 1) { - return Ok(HttpSession::H1(Http1Session::new(stream))); + return Ok(HttpSession::H1(Http1Session::new_with_options( + stream, peer, + ))); } // else: min http version=H2 over plaintext, there is no ALPN anyways, we trust // the caller that the server speaks h2c @@ -302,8 +319,28 @@ impl Connector { let maybe_conn = self .in_use_pool .get(reuse_hash) + // filter out closed, InUsePool does not have notify closed eviction like the idle pool + // and it's possible we get an in use connection that is closed and not yet released + 
.filter(|c| !c.is_closed()) .or_else(|| self.idle_pool.get(&reuse_hash)); if let Some(conn) = maybe_conn { + #[cfg(unix)] + if !peer.matches_fd(conn.id()) { + return Ok(None); + } + #[cfg(windows)] + { + use std::os::windows::io::{AsRawSocket, RawSocket}; + struct WrappedRawSocket(RawSocket); + impl AsRawSocket for WrappedRawSocket { + fn as_raw_socket(&self) -> RawSocket { + self.0 + } + } + if !peer.matches_sock(WrappedRawSocket(conn.id() as RawSocket)) { + return Ok(None); + } + } let h2_stream = conn.spawn_stream().await?; if conn.more_streams_allowed() { self.in_use_pool.insert(reuse_hash, conn); @@ -353,14 +390,12 @@ impl Connector { }; let closed = conn.0.closed.clone(); let (notify_evicted, watch_use) = self.idle_pool.put(&meta, conn); - if let Some(to) = idle_timeout { - let pool = self.idle_pool.clone(); //clone the arc - let rt = pingora_runtime::current_handle(); - rt.spawn(async move { - pool.idle_timeout(&meta, to, notify_evicted, closed, watch_use) - .await; - }); - } + let pool = self.idle_pool.clone(); //clone the arc + let rt = pingora_runtime::current_handle(); + rt.spawn(async move { + pool.idle_timeout(&meta, idle_timeout, notify_evicted, closed, watch_use) + .await; + }); } else { self.in_use_pool.insert(reuse_hash, conn); drop(locked); @@ -388,7 +423,7 @@ impl Connector { // 8 Mbytes = 80 Mbytes X 100ms, which should be enough for most links. 
const H2_WINDOW_SIZE: u32 = 1 << 23; -pub(crate) async fn handshake( +pub async fn handshake( stream: Stream, max_streams: usize, h2_ping_interval: Option, @@ -457,6 +492,7 @@ pub(crate) async fn handshake( )) } +// TODO(slava): add custom unit tests #[cfg(test)] mod tests { use super::*; @@ -468,10 +504,14 @@ mod tests { let connector = Connector::new(None); let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into()); peer.options.set_http_version(2, 2); - let h2 = connector.new_http_session(&peer).await.unwrap(); + let h2 = connector + .new_http_session::(&peer) + .await + .unwrap(); match h2 { HttpSession::H1(_) => panic!("expect h2"), HttpSession::H2(h2_stream) => assert!(!h2_stream.ping_timedout()), + HttpSession::Custom(_) => panic!("expect h2"), } } @@ -482,10 +522,14 @@ mod tests { let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into()); // a hack to force h1, new_http_session() in the future might validate this setting peer.options.set_http_version(1, 1); - let h2 = connector.new_http_session(&peer).await.unwrap(); + let h2 = connector + .new_http_session::(&peer) + .await + .unwrap(); match h2 { HttpSession::H1(_) => {} HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } } @@ -494,10 +538,14 @@ mod tests { let connector = Connector::new(None); let mut peer = HttpPeer::new(("1.1.1.1", 80), false, "".into()); peer.options.set_http_version(2, 1); - let h2 = connector.new_http_session(&peer).await.unwrap(); + let h2 = connector + .new_http_session::(&peer) + .await + .unwrap(); match h2 { HttpSession::H1(_) => {} HttpSession::H2(_) => panic!("expect h1"), + HttpSession::Custom(_) => panic!("expect h1"), } } @@ -508,10 +556,14 @@ mod tests { let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into()); peer.options.set_http_version(2, 2); peer.options.max_h2_streams = 1; - let h2 = connector.new_http_session(&peer).await.unwrap(); + let h2 = connector + 
.new_http_session::(&peer) + .await + .unwrap(); let h2_1 = match h2 { HttpSession::H1(_) => panic!("expect h2"), HttpSession::H2(h2_stream) => h2_stream, + HttpSession::Custom(_) => panic!("expect h2"), }; let id = h2_1.conn.id(); @@ -540,10 +592,14 @@ mod tests { let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into()); peer.options.set_http_version(2, 2); peer.options.max_h2_streams = 3; - let h2 = connector.new_http_session(&peer).await.unwrap(); + let h2 = connector + .new_http_session::(&peer) + .await + .unwrap(); let h2_1 = match h2 { HttpSession::H1(_) => panic!("expect h2"), HttpSession::H2(h2_stream) => h2_stream, + HttpSession::Custom(_) => panic!("expect h2"), }; let id = h2_1.conn.id(); @@ -573,4 +629,75 @@ mod tests { let h2_5 = connector.reused_http_session(&peer).await.unwrap().unwrap(); assert_eq!(id, h2_5.conn.id()); } + + #[cfg(all(feature = "any_tls", unix))] + #[tokio::test] + async fn test_h2_reuse_rejects_fd_mismatch() { + use crate::protocols::l4::socket::SocketAddr; + use crate::upstreams::peer::Peer; + use std::fmt::{Display, Formatter, Result as FmtResult}; + use std::os::unix::prelude::AsRawFd; + + #[derive(Clone)] + struct MismatchPeer { + reuse_hash: u64, + address: SocketAddr, + } + + impl Display for MismatchPeer { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{:?}", self.address) + } + } + + impl Peer for MismatchPeer { + fn address(&self) -> &SocketAddr { + &self.address + } + + fn tls(&self) -> bool { + true + } + + fn sni(&self) -> &str { + "" + } + + fn reuse_hash(&self) -> u64 { + self.reuse_hash + } + + fn matches_fd(&self, _fd: V) -> bool { + false + } + } + + let connector = Connector::new(None); + let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into()); + peer.options.set_http_version(2, 2); + peer.options.max_h2_streams = 1; + + let h2 = connector + .new_http_session::(&peer) + .await + .unwrap(); + let h2_stream = match h2 { + HttpSession::H1(_) => 
panic!("expect h2"), + HttpSession::H2(h2_stream) => h2_stream, + HttpSession::Custom(_) => panic!("expect h2"), + }; + + connector.release_http_session(h2_stream, &peer, None); + + let mismatch_peer = MismatchPeer { + reuse_hash: peer.reuse_hash(), + address: peer.address().clone(), + }; + + assert!(connector + .reused_http_session(&mismatch_peer) + .await + .unwrap() + .is_none()); + } } diff --git a/pingora-core/src/connectors/l4.rs b/pingora-core/src/connectors/l4.rs index dc442644..bd7439d4 100644 --- a/pingora-core/src/connectors/l4.rs +++ b/pingora-core/src/connectors/l4.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -280,7 +280,7 @@ async fn proxy_connect(peer: &P) -> Result { ); let req_header = raw_connect::generate_connect_header(&proxy.host, proxy.port, &mut headers)?; - let fut = raw_connect::connect(stream, &req_header); + let fut = raw_connect::connect(stream, &req_header, peer); let (mut stream, digest) = match peer.connection_timeout() { Some(t) => pingora_timeout::timeout(t, fut) .await @@ -314,8 +314,6 @@ mod tests { use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::io::AsyncWriteExt; - #[cfg(unix)] - use tokio::net::UnixListener; use tokio::time::sleep; /// Some of the tests below are flaky when making new connections to mock @@ -465,31 +463,20 @@ mod tests { } #[cfg(unix)] - const MOCK_UDS_PATH: &str = "/tmp/test_unix_connect_proxy.sock"; - - // one-off mock server - #[cfg(unix)] - async fn mock_connect_server() { - let _ = std::fs::remove_file(MOCK_UDS_PATH); - let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap(); - if let Ok((mut stream, _addr)) = listener.accept().await { - stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); - // wait a bit so that the client can read - 
tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - let _ = std::fs::remove_file(MOCK_UDS_PATH); - } - #[tokio::test(flavor = "multi_thread")] async fn test_connect_proxy_work() { - tokio::spawn(async { - mock_connect_server().await; - }); - // wait for the server to start - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + use crate::connectors::test_utils; + + let socket_path = test_utils::unique_uds_path("connect_proxy_work"); + let (ready_rx, shutdown_tx, server_handle) = + test_utils::spawn_mock_uds_server(socket_path.clone(), b"HTTP/1.1 200 OK\r\n\r\n"); + + // Wait for the server to be ready + ready_rx.await.unwrap(); + let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string()); let mut path = PathBuf::new(); - path.push(MOCK_UDS_PATH); + path.push(&socket_path); peer.proxy = Some(Proxy { next_hop: path.into(), host: "1.1.1.1".into(), @@ -498,35 +485,27 @@ mod tests { }); let new_session = connect(&peer, None).await; assert!(new_session.is_ok()); - } - - #[cfg(unix)] - const MOCK_BAD_UDS_PATH: &str = "/tmp/test_unix_bad_connect_proxy.sock"; - // one-off mock bad proxy - // closes connection upon accepting - #[cfg(unix)] - async fn mock_connect_bad_server() { - let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH); - let listener = UnixListener::bind(MOCK_BAD_UDS_PATH).unwrap(); - if let Ok((mut stream, _addr)) = listener.accept().await { - stream.shutdown().await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH); + // Clean up + let _ = shutdown_tx.send(()); + server_handle.await.unwrap(); } #[cfg(unix)] #[tokio::test(flavor = "multi_thread")] async fn test_connect_proxy_conn_closed() { - tokio::spawn(async { - mock_connect_bad_server().await; - }); - // wait for the server to start - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + use crate::connectors::test_utils; + + let socket_path = 
test_utils::unique_uds_path("connect_proxy_conn_closed"); + let (ready_rx, shutdown_tx, server_handle) = + test_utils::spawn_mock_uds_server_close_immediate(socket_path.clone()); + + // Wait for the server to be ready + ready_rx.await.unwrap(); + let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string()); let mut path = PathBuf::new(); - path.push(MOCK_BAD_UDS_PATH); + path.push(&socket_path); peer.proxy = Some(Proxy { next_hop: path.into(), host: "1.1.1.1".into(), @@ -537,6 +516,10 @@ mod tests { let err = new_session.unwrap_err(); assert_eq!(err.etype(), &ConnectionClosed); assert!(!err.retry()); + + // Clean up + let _ = shutdown_tx.send(()); + server_handle.await.unwrap(); } #[cfg(target_os = "linux")] diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index 1e6c08dc..3e3c1c46 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -399,6 +399,86 @@ fn test_reusable_stream(stream: &mut Stream) -> bool { } } +/// Test utilities for creating mock acceptors. 
+#[cfg(all(test, unix))] +pub(crate) mod test_utils { + use tokio::io::AsyncWriteExt; + use tokio::net::UnixListener; + + /// Generates a unique socket path for testing to avoid conflicts when running in parallel + pub fn unique_uds_path(test_name: &str) -> String { + format!( + "/tmp/test_{test_name}_{:?}_{}.sock", + std::thread::current().id(), + std::process::id() + ) + } + + /// A mock UDS server that accepts one connection, sends data, and waits for shutdown signal + /// + /// Returns: (ready_rx, shutdown_tx, server_handle) + /// - ready_rx: Wait on this to know when server is ready to accept connections + /// - shutdown_tx: Send on this to tell server to shut down + /// - server_handle: Join handle for the server task + pub fn spawn_mock_uds_server( + socket_path: String, + response: &'static [u8], + ) -> ( + tokio::sync::oneshot::Receiver<()>, + tokio::sync::oneshot::Sender<()>, + tokio::task::JoinHandle<()>, + ) { + let (ready_tx, ready_rx) = tokio::sync::oneshot::channel(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + + let server_handle = tokio::spawn(async move { + let _ = std::fs::remove_file(&socket_path); + let listener = UnixListener::bind(&socket_path).unwrap(); + // Signal that the server is ready to accept connections + let _ = ready_tx.send(()); + + if let Ok((mut stream, _addr)) = listener.accept().await { + let _ = stream.write_all(response).await; + // Keep the connection open until the test tells us to shutdown + let _ = shutdown_rx.await; + } + let _ = std::fs::remove_file(&socket_path); + }); + + (ready_rx, shutdown_tx, server_handle) + } + + /// A mock UDS server that immediately closes connections (for testing error handling) + /// + /// Returns: (ready_rx, shutdown_tx, server_handle) + pub fn spawn_mock_uds_server_close_immediate( + socket_path: String, + ) -> ( + tokio::sync::oneshot::Receiver<()>, + tokio::sync::oneshot::Sender<()>, + tokio::task::JoinHandle<()>, + ) { + let (ready_tx, ready_rx) = 
tokio::sync::oneshot::channel(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + + let server_handle = tokio::spawn(async move { + let _ = std::fs::remove_file(&socket_path); + let listener = UnixListener::bind(&socket_path).unwrap(); + // Signal that the server is ready to accept connections + let _ = ready_tx.send(()); + + if let Ok((mut stream, _addr)) = listener.accept().await { + let _ = stream.shutdown().await; + // Wait for shutdown signal before cleaning up + let _ = shutdown_rx.await; + } + let _ = std::fs::remove_file(&socket_path); + }); + + (ready_rx, shutdown_tx, server_handle) + } +} + #[cfg(test)] #[cfg(feature = "any_tls")] mod tests { @@ -407,9 +487,6 @@ mod tests { use super::*; use crate::upstreams::peer::BasicPeer; - use tokio::io::AsyncWriteExt; - #[cfg(unix)] - use tokio::net::UnixListener; // 192.0.2.1 is effectively a black hole const BLACK_HOLE: &str = "192.0.2.1:79"; @@ -440,38 +517,34 @@ mod tests { assert!(reused); } - #[cfg(unix)] - const MOCK_UDS_PATH: &str = "/tmp/test_unix_transport_connector.sock"; - - // one-off mock server - #[cfg(unix)] - async fn mock_connect_server() { - let _ = std::fs::remove_file(MOCK_UDS_PATH); - let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap(); - if let Ok((mut stream, _addr)) = listener.accept().await { - stream.write_all(b"it works!").await.unwrap(); - // wait a bit so that the client can read - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - let _ = std::fs::remove_file(MOCK_UDS_PATH); - } #[tokio::test(flavor = "multi_thread")] + #[cfg(unix)] async fn test_connect_uds() { - tokio::spawn(async { - mock_connect_server().await; - }); + let socket_path = test_utils::unique_uds_path("transport_connector"); + let (ready_rx, shutdown_tx, server_handle) = + test_utils::spawn_mock_uds_server(socket_path.clone(), b"it works!"); + + // Wait for the server to be ready before connecting + ready_rx.await.unwrap(); + // create a new service at /tmp let connector = 
TransportConnector::new(None); - let peer = BasicPeer::new_uds(MOCK_UDS_PATH).unwrap(); + let peer = BasicPeer::new_uds(&socket_path).unwrap(); // make a new connection to mock uds let mut stream = connector.new_stream(&peer).await.unwrap(); let mut buf = [0; 9]; let _ = stream.read(&mut buf).await.unwrap(); assert_eq!(&buf, b"it works!"); - connector.release_stream(stream, peer.reuse_hash(), None); - let (_, reused) = connector.get_stream(&peer).await.unwrap(); + // Test connection reuse by releasing and getting the stream back + connector.release_stream(stream, peer.reuse_hash(), None); + let (stream, reused) = connector.get_stream(&peer).await.unwrap(); assert!(reused); + + // Clean up: drop the stream, tell server to shutdown, and wait for it + drop(stream); + let _ = shutdown_tx.send(()); + server_handle.await.unwrap(); } async fn do_test_conn_timeout(conf: Option) { diff --git a/pingora-core/src/connectors/offload.rs b/pingora-core/src/connectors/offload.rs index 06fc0895..fe2d1c72 100644 --- a/pingora-core/src/connectors/offload.rs +++ b/pingora-core/src/connectors/offload.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs b/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs index f9b8c3f1..9bb3a5a6 100644 --- a/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs +++ b/pingora-core/src/connectors/tls/boringssl_openssl/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -198,7 +198,7 @@ where } // second_keyshare is default true - if !peer.get_peer_options().map_or(true, |o| o.second_keyshare) { + if !peer.get_peer_options().is_none_or(|o| o.second_keyshare) { ssl_use_second_key_share(&mut ssl_conf, false); } @@ -246,7 +246,11 @@ where } clear_error_stack(); - let connect_future = handshake(ssl_conf, peer.sni(), stream); + + let complete_hook = peer + .get_peer_options() + .and_then(|o| o.upstream_tls_handshake_complete_hook.clone()); + let connect_future = handshake(ssl_conf, peer.sni(), stream, complete_hook); match peer.connection_timeout() { Some(t) => match pingora_timeout::timeout(t, connect_future).await { diff --git a/pingora-core/src/connectors/tls/mod.rs b/pingora-core/src/connectors/tls/mod.rs index 4c41dfa5..c49be80b 100644 --- a/pingora-core/src/connectors/tls/mod.rs +++ b/pingora-core/src/connectors/tls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/connectors/tls/rustls/mod.rs b/pingora-core/src/connectors/tls/rustls/mod.rs index 530d50cb..ff375929 100644 --- a/pingora-core/src/connectors/tls/rustls/mod.rs +++ b/pingora-core/src/connectors/tls/rustls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,8 +22,14 @@ use pingora_error::{ }; use pingora_rustls::{ load_ca_file_into_store, load_certs_and_key_files, load_platform_certs_incl_env_into_store, - version, CertificateDer, ClientConfig as RusTlsClientConfig, PrivateKeyDer, RootCertStore, - TlsConnector as RusTlsConnector, + version, CertificateDer, CertificateError, ClientConfig as RusTlsClientConfig, + DigitallySignedStruct, KeyLogFile, PrivateKeyDer, RootCertStore, RusTlsError, ServerName, + SignatureScheme, TlsConnector as RusTlsConnector, UnixTime, WebPkiServerVerifier, +}; + +// Uses custom certificate verification from rustls's 'danger' module. +use pingora_rustls::{ + HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier as RusTlsServerCertVerifier, }; use crate::protocols::tls::{client::handshake, TlsStream}; @@ -75,7 +81,6 @@ impl TlsConnector { if let Some((cert, key)) = conf.cert_key_file.as_ref() { certs_key = load_certs_and_key_files(cert, key)?; } - // TODO: support SSLKEYLOGFILE } else { load_platform_certs_incl_env_into_store(&mut ca_certs)?; } @@ -88,7 +93,7 @@ impl TlsConnector { RusTlsClientConfig::builder_with_protocol_versions(&[&version::TLS12, &version::TLS13]) .with_root_certificates(ca_certs.clone()); - let config = match certs_key { + let mut config = match certs_key { Some((certs, key)) => { match builder.with_client_auth_cert(certs.clone(), key.clone_key()) { Ok(config) => config, @@ -102,6 +107,13 @@ impl TlsConnector { None => builder.with_no_client_auth(), }; + // Enable SSLKEYLOGFILE support for debugging TLS traffic + if let Some(options) = options.as_ref() { + if options.debug_ssl_keylog { + config.key_log = Arc::new(KeyLogFile::new()); + } + } + Ok(Connector { ctx: Arc::new(TlsConnector { config: Arc::new(config), @@ -155,10 +167,12 @@ where .with_root_certificates(Arc::clone(&tls_ctx.ca_certs)); debug!("added root ca certificates"); - let updated_config = builder.with_client_auth_cert(certs, private_key).or_err( + let mut updated_config = 
builder.with_client_auth_cert(certs, private_key).or_err( InvalidCert, "Failed to use peer cert/key to update Rustls config", )?; + // Preserve keylog setting from original config + updated_config.key_log = Arc::clone(&config.key_log); Some(updated_config) } }; @@ -174,30 +188,64 @@ where } } + let mut domain = peer.sni().to_string(); + + if let Some(updated_config) = updated_config_opt.as_mut() { + let verification_mode = if peer.sni().is_empty() { + updated_config.enable_sni = false; + /* NOTE: technically we can still verify who signs the cert but turn it off to be + consistent with nginx's behavior */ + Some(VerificationMode::SkipAll) // disable verification if sni does not exist + } else if !peer.verify_cert() { + Some(VerificationMode::SkipAll) + } else if !peer.verify_hostname() { + Some(VerificationMode::SkipHostname) + } else { + // if sni had underscores in leftmost label replace and add + if let Some(sni_s) = replace_leftmost_underscore(peer.sni()) { + domain = sni_s; + } + None + // to use the custom verifier for the full verify: + // Some(VerificationMode::Full) + }; + + // Builds the custom_verifier when verification_mode is set. + if let Some(mode) = verification_mode { + let delegate = WebPkiServerVerifier::builder(Arc::clone(&tls_ctx.ca_certs)) + .build() + .or_err(InvalidCert, "Failed to build WebPkiServerVerifier")?; + + let custom_verifier = Arc::new(CustomServerCertVerifier::new(delegate, mode)); + + updated_config + .dangerous() + .set_certificate_verifier(custom_verifier); + } + } + // TODO: curve setup from peer // - second key share from peer, currently only used in boringssl with PQ features + // Patch config for dangerous verifier if needed, but only in test builds. 
+ #[cfg(test)] + if !peer.verify_cert() || !peer.verify_hostname() { + use crate::connectors::http::rustls_no_verify::apply_no_verify; + if let Some(cfg) = updated_config_opt.as_mut() { + apply_no_verify(cfg); + } else { + let mut tmp = RusTlsClientConfig::clone(config); + apply_no_verify(&mut tmp); + updated_config_opt = Some(tmp); + } + } + let tls_conn = if let Some(cfg) = updated_config_opt { RusTlsConnector::from(Arc::new(cfg)) } else { RusTlsConnector::from(Arc::clone(config)) }; - // TODO: for consistent behavior between TLS providers some additions are required - // - allowing to disable verification - // - the validation/replace logic would need adjustments to match the boringssl/openssl behavior - // implementing a custom certificate_verifier could be used to achieve matching behavior - //let d_conf = config.dangerous(); - //d_conf.set_certificate_verifier(...); - - let mut domain = peer.sni().to_string(); - if peer.verify_cert() && peer.verify_hostname() { - // TODO: streamline logic with replacing first underscore within TLS implementations - if let Some(sni_s) = replace_leftmost_underscore(peer.sni()) { - domain = sni_s; - } - } - let connect_future = handshake(&tls_conn, &domain, stream); match peer.connection_timeout() { @@ -211,3 +259,95 @@ where None => connect_future.await, } } + +#[allow(dead_code)] +#[derive(Debug)] +pub enum VerificationMode { + SkipHostname, + SkipAll, + Full, + // Note: "Full" Included for completeness, making this verifier self-contained + // and explicit about all possible verification modes, not just exceptions. 
+} + +#[derive(Debug)] +pub struct CustomServerCertVerifier { + delegate: Arc, + verification_mode: VerificationMode, +} + +impl CustomServerCertVerifier { + pub fn new(delegate: Arc, verification_mode: VerificationMode) -> Self { + Self { + delegate, + verification_mode, + } + } +} + +// CustomServerCertVerifier delegates TLS signature verification and allows 3 VerificationMode: +// Full: delegates all verification to the original WebPkiServerVerifier +// SkipHostname: same as "Full" but ignores "NotValidForName" certificate errors +// SkipAll: all certificate verification checks are skipped. +impl RusTlsServerCertVerifier for CustomServerCertVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + match self.verification_mode { + VerificationMode::Full => self.delegate.verify_server_cert( + _end_entity, + _intermediates, + _server_name, + _ocsp, + _now, + ), + VerificationMode::SkipHostname => { + match self.delegate.verify_server_cert( + _end_entity, + _intermediates, + _server_name, + _ocsp, + _now, + ) { + Ok(scv) => Ok(scv), + Err(RusTlsError::InvalidCertificate(cert_error)) => { + if let CertificateError::NotValidForNameContext { .. 
} = cert_error { + Ok(ServerCertVerified::assertion()) + } else { + Err(RusTlsError::InvalidCertificate(cert_error)) + } + } + Err(e) => Err(e), + } + } + VerificationMode::SkipAll => Ok(ServerCertVerified::assertion()), + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.delegate.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.delegate.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.delegate.supported_verify_schemes() + } +} diff --git a/pingora-core/src/connectors/tls/s2n/mod.rs b/pingora-core/src/connectors/tls/s2n/mod.rs index 36f931d2..fbfdd7e7 100644 --- a/pingora-core/src/connectors/tls/s2n/mod.rs +++ b/pingora-core/src/connectors/tls/s2n/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/lib.rs b/pingora-core/src/lib.rs index 544a8669..a4450632 100644 --- a/pingora-core/src/lib.rs +++ b/pingora-core/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -35,7 +35,57 @@ //! If looking to build a (reverse) proxy, see [`pingora-proxy`](https://docs.rs/pingora-proxy) crate. //! //! # Optional features -//! `boringssl`: Switch the internal TLS library from OpenSSL to BoringSSL. +//! +//! ## TLS backends (mutually exclusive) +//! - `openssl`: Use OpenSSL as the TLS library (default if no TLS feature is specified) +//! - `boringssl`: Use BoringSSL as the TLS library (FIPS compatible) +//! 
- `rustls`: Use Rustls as the TLS library
+//!
+//! ## Additional features
+//! - `connection_filter`: Enable early TCP connection filtering before TLS handshake.
+//!   This allows implementing custom logic to accept/reject connections based on peer address
+//!   with zero overhead when disabled.
+//! - `sentry`: Enable Sentry error reporting integration
+//! - `patched_http1`: Enable patched HTTP/1 parser
+//!
+//! # Connection Filtering
+//!
+//! With the `connection_filter` feature enabled, you can implement early connection filtering
+//! at the TCP level, before any TLS handshake or HTTP processing occurs. This is useful for:
+//! - IP-based access control
+//! - Rate limiting at the connection level
+//! - Geographic restrictions
+//! - DDoS mitigation
+//!
+//! ## Example
+//!
+//! ```rust,ignore
+//! # #[cfg(feature = "connection_filter")]
+//! # {
+//! use async_trait::async_trait;
+//! use pingora_core::listeners::ConnectionFilter;
+//! use std::net::SocketAddr;
+//! use std::sync::Arc;
+//!
+//! #[derive(Debug)]
+//! struct MyFilter;
+//!
+//! #[async_trait]
+//! impl ConnectionFilter for MyFilter {
+//!     async fn should_accept(&self, addr: Option<&SocketAddr>) -> bool {
+//!         // Custom logic to filter connections
+//!         addr.is_none_or(|a| !is_blocked_ip(a.ip()))
+//!     }
+//! }
+//!
+//! // Apply the filter to a service
+//! let mut service = my_service();
+//! service.set_connection_filter(Arc::new(MyFilter));
+//! # }
+//! ```
+//!
+//! When the `connection_filter` feature is disabled, the filter API remains available
+//! but becomes a no-op, ensuring zero overhead for users who don't need this functionality.
// This enables the feature that labels modules that are only available with
// certain pingora features
diff --git a/pingora-core/src/listeners/connection_filter.rs b/pingora-core/src/listeners/connection_filter.rs
new file mode 100644
index 00000000..10ae642f
--- /dev/null
+++ b/pingora-core/src/listeners/connection_filter.rs
@@ -0,0 +1,147 @@
+// Copyright 2026 Cloudflare, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Connection filtering trait for early connection filtering
+//!
+//! This module provides the [`ConnectionFilter`] trait which allows filtering
+//! incoming connections at the TCP level, before the TLS handshake occurs.
+//!
+//! # Feature Flag
+//!
+//! This functionality requires the `connection_filter` feature to be enabled:
+//! ```toml
+//! [dependencies]
+//! pingora-core = { version = "0.8", features = ["connection_filter"] }
+//! ```
+//!
+//! When the feature is disabled, a no-op implementation is provided for API compatibility.
+
+use async_trait::async_trait;
+use std::fmt::Debug;
+use std::net::SocketAddr;
+
+/// A trait for filtering incoming connections at the TCP level.
+///
+/// Implementations of this trait can inspect the peer address of incoming
+/// connections and decide whether to accept or reject them before any
+/// further processing (including TLS handshake) occurs.
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use async_trait::async_trait;
+/// use pingora_core::listeners::ConnectionFilter;
+/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+///
+/// #[derive(Debug)]
+/// struct BlocklistFilter {
+///     blocked_ips: Vec<IpAddr>,
+/// }
+///
+/// #[async_trait]
+/// impl ConnectionFilter for BlocklistFilter {
+///     async fn should_accept(&self, addr: Option<&SocketAddr>) -> bool {
+///         addr.map_or(true, |a| !self.blocked_ips.contains(&a.ip()))
+///     }
+/// }
+/// ```
+///
+/// # Performance Considerations
+///
+/// This filter is called for every incoming connection, so implementations
+/// should be efficient. Consider caching or pre-computing data structures
+/// for IP filtering rather than doing expensive operations per connection.
+#[async_trait]
+pub trait ConnectionFilter: Debug + Send + Sync {
+    /// Determines whether an incoming connection should be accepted.
+    ///
+    /// This method is called after a TCP connection is accepted but before
+    /// any further processing (including TLS handshake).
+    ///
+    /// # Arguments
+    ///
+    /// * `addr` - The socket address of the incoming connection, if known
+    ///
+    /// # Returns
+    ///
+    /// * `true` - Accept the connection and continue processing
+    /// * `false` - Drop the connection immediately
+    ///
+    /// # Example
+    ///
+    /// ```rust,no_run
+    /// async fn should_accept(&self, addr: Option<&SocketAddr>) -> bool {
+    ///     // Accept only connections from private IP ranges
+    ///     match addr.map(|a| a.ip()) {
+    ///         Some(IpAddr::V4(ip)) => ip.is_private(),
+    ///         _ => true,
+    ///     }
+    /// }
+    /// ```
+    async fn should_accept(&self, _addr: Option<&SocketAddr>) -> bool {
+        true
+    }
+}
+
+/// Default implementation that accepts all connections.
+///
+/// This filter accepts all incoming connections without any filtering.
+/// It's used as the default when no custom filter is specified.
+#[derive(Debug, Clone)] +pub struct AcceptAllFilter; + +#[async_trait] +impl ConnectionFilter for AcceptAllFilter { + // Uses default implementation +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + + #[derive(Debug, Clone)] + struct BlockListFilter { + blocked_ips: Vec, + } + + #[async_trait] + impl ConnectionFilter for BlockListFilter { + async fn should_accept(&self, addr_opt: Option<&SocketAddr>) -> bool { + addr_opt + .map(|addr| !self.blocked_ips.contains(&addr.ip())) + .unwrap_or(true) + } + } + + #[tokio::test] + async fn test_accept_all_filter() { + let filter = AcceptAllFilter; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); + assert!(filter.should_accept(Some(&addr)).await); + } + + #[tokio::test] + async fn test_blocklist_filter() { + let filter = BlockListFilter { + blocked_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))], + }; + + let blocked_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080); + let allowed_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)), 8080); + + assert!(!filter.should_accept(Some(&blocked_addr)).await); + assert!(filter.should_accept(Some(&allowed_addr)).await); + } +} diff --git a/pingora-core/src/listeners/l4.rs b/pingora-core/src/listeners/l4.rs index 4dc07bce..1fee7437 100644 --- a/pingora-core/src/listeners/l4.rs +++ b/pingora-core/src/listeners/l4.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#[cfg(feature = "connection_filter")] +use log::debug; use log::warn; use pingora_error::{ ErrorType::{AcceptError, BindError}, @@ -29,9 +31,16 @@ use std::time::Duration; use std::{fs::Permissions, sync::Arc}; use tokio::net::TcpSocket; +#[cfg(feature = "connection_filter")] +use super::connection_filter::ConnectionFilter; +#[cfg(feature = "connection_filter")] +use crate::listeners::AcceptAllFilter; + use crate::protocols::l4::ext::{set_dscp, set_tcp_fastopen_backlog}; use crate::protocols::l4::listener::Listener; pub use crate::protocols::l4::stream::Stream; +#[cfg(feature = "connection_filter")] +use crate::protocols::GetSocketDigest; use crate::protocols::TcpKeepalive; #[cfg(unix)] use crate::server::ListenFds; @@ -271,16 +280,24 @@ async fn bind(addr: &ServerAddress) -> Result { pub struct ListenerEndpoint { listen_addr: ServerAddress, listener: Arc, + #[cfg(feature = "connection_filter")] + connection_filter: Arc, } #[derive(Default)] pub struct ListenerEndpointBuilder { listen_addr: Option, + #[cfg(feature = "connection_filter")] + connection_filter: Option>, } impl ListenerEndpointBuilder { pub fn new() -> ListenerEndpointBuilder { - Self { listen_addr: None } + Self { + listen_addr: None, + #[cfg(feature = "connection_filter")] + connection_filter: None, + } } pub fn listen_addr(&mut self, addr: ServerAddress) -> &mut Self { @@ -288,6 +305,12 @@ impl ListenerEndpointBuilder { self } + #[cfg(feature = "connection_filter")] + pub fn connection_filter(&mut self, filter: Arc) -> &mut Self { + self.connection_filter = Some(filter); + self + } + #[cfg(unix)] pub async fn listen(self, fds: Option) -> Result { let listen_addr = self @@ -313,9 +336,16 @@ impl ListenerEndpointBuilder { bind(&listen_addr).await? 
}; + #[cfg(feature = "connection_filter")] + let connection_filter = self + .connection_filter + .unwrap_or_else(|| Arc::new(AcceptAllFilter)); + Ok(ListenerEndpoint { listen_addr, listener: Arc::new(listener), + #[cfg(feature = "connection_filter")] + connection_filter, }) } @@ -324,11 +354,19 @@ impl ListenerEndpointBuilder { let listen_addr = self .listen_addr .expect("Tried to listen with no addr specified"); + let listener = bind(&listen_addr).await?; + #[cfg(feature = "connection_filter")] + let connection_filter = self + .connection_filter + .unwrap_or_else(|| Arc::new(AcceptAllFilter)); + Ok(ListenerEndpoint { listen_addr, listener: Arc::new(listener), + #[cfg(feature = "connection_filter")] + connection_filter, }) } } @@ -361,13 +399,50 @@ impl ListenerEndpoint { } pub async fn accept(&self) -> Result { - let mut stream = self - .listener - .accept() - .await - .or_err(AcceptError, "Fail to accept()")?; - self.apply_stream_settings(&mut stream)?; - Ok(stream) + #[cfg(feature = "connection_filter")] + { + loop { + let mut stream = self + .listener + .accept() + .await + .or_err(AcceptError, "Fail to accept()")?; + + // Performance: nested if-let avoids cloning/allocations on each connection accept + let should_accept = if let Some(digest) = stream.get_socket_digest() { + if let Some(peer_addr) = digest.peer_addr() { + self.connection_filter + .should_accept(peer_addr.as_inet()) + .await + } else { + // No peer address available - accept by default + true + } + } else { + // No socket digest available - accept by default + true + }; + + if !should_accept { + debug!("Connection rejected by filter"); + drop(stream); + continue; + } + + self.apply_stream_settings(&mut stream)?; + return Ok(stream); + } + } + #[cfg(not(feature = "connection_filter"))] + { + let mut stream = self + .listener + .accept() + .await + .or_err(AcceptError, "Fail to accept()")?; + self.apply_stream_settings(&mut stream)?; + Ok(stream) + } } } @@ -507,4 +582,146 @@ mod test { // Verify 
the first listener still works assert_eq!(listener1.as_str(), addr); } + + #[cfg(feature = "connection_filter")] + #[tokio::test] + async fn test_connection_filter_accept() { + use crate::listeners::ConnectionFilter; + use async_trait::async_trait; + use std::sync::atomic::{AtomicUsize, Ordering}; + + #[derive(Debug)] + struct CountingFilter { + accept_count: Arc, + reject_count: Arc, + } + + #[async_trait] + impl ConnectionFilter for CountingFilter { + async fn should_accept(&self, _addr: Option<&SocketAddr>) -> bool { + let count = self.accept_count.fetch_add(1, Ordering::SeqCst); + if count % 2 == 0 { + true + } else { + self.reject_count.fetch_add(1, Ordering::SeqCst); + false + } + } + } + + let addr = "127.0.0.1:7300"; + let accept_count = Arc::new(AtomicUsize::new(0)); + let reject_count = Arc::new(AtomicUsize::new(0)); + + let filter = Arc::new(CountingFilter { + accept_count: accept_count.clone(), + reject_count: reject_count.clone(), + }); + + let mut builder = ListenerEndpoint::builder(); + builder + .listen_addr(ServerAddress::Tcp(addr.into(), None)) + .connection_filter(filter); + + #[cfg(unix)] + let listener = builder.listen(None).await.unwrap(); + #[cfg(windows)] + let listener = builder.listen().await.unwrap(); + + let listener_clone = listener.clone(); + tokio::spawn(async move { + let _stream1 = listener_clone.accept().await.unwrap(); + let _stream2 = listener_clone.accept().await.unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + + let _conn1 = tokio::net::TcpStream::connect(addr).await.unwrap(); + let _conn2 = tokio::net::TcpStream::connect(addr).await.unwrap(); + let _conn3 = tokio::net::TcpStream::connect(addr).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(50)).await; + + assert_eq!(accept_count.load(Ordering::SeqCst), 3); + assert_eq!(reject_count.load(Ordering::SeqCst), 1); + } + + #[cfg(feature = "connection_filter")] + #[tokio::test] + async fn test_connection_filter_blocks_all() { + use 
crate::listeners::ConnectionFilter; + use async_trait::async_trait; + use std::sync::atomic::{AtomicUsize, Ordering}; + + #[derive(Debug)] + struct RejectAllFilter { + reject_count: Arc, + } + + #[async_trait] + impl ConnectionFilter for RejectAllFilter { + async fn should_accept(&self, _addr: Option<&SocketAddr>) -> bool { + self.reject_count.fetch_add(1, Ordering::SeqCst); + false + } + } + + let addr = "127.0.0.1:7301"; + let reject_count = Arc::new(AtomicUsize::new(0)); + + let mut builder = ListenerEndpoint::builder(); + builder + .listen_addr(ServerAddress::Tcp(addr.into(), None)) + .connection_filter(Arc::new(RejectAllFilter { + reject_count: reject_count.clone(), + })); + + #[cfg(unix)] + let listener = builder.listen(None).await.unwrap(); + #[cfg(windows)] + let listener = builder.listen().await.unwrap(); + + let listener_clone = listener.clone(); + let _accept_handle = tokio::spawn(async move { + // This will never return since all connections are rejected + let _ = listener_clone.accept().await; + }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut handles = vec![]; + for _ in 0..3 { + let handle = tokio::spawn(async move { + if let Ok(stream) = tokio::net::TcpStream::connect(addr).await { + drop(stream); + } + }); + handles.push(handle); + } + + for handle in handles { + let _ = handle.await; + } + + // Wait for rejections to be counted with timeout + let start = tokio::time::Instant::now(); + let timeout = Duration::from_secs(2); + + loop { + let rejected = reject_count.load(Ordering::SeqCst); + if rejected >= 3 { + assert_eq!(rejected, 3, "Should reject exactly 3 connections"); + break; + } + + if start.elapsed() > timeout { + panic!( + "Timeout waiting for rejections, got {} expected 3", + rejected + ); + } + + tokio::time::sleep(Duration::from_millis(10)).await; + } + } } diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index b8a45bf9..abc65ea1 100644 --- a/pingora-core/src/listeners/mod.rs +++ 
b/pingora-core/src/listeners/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,31 +13,80 @@ // limitations under the License. //! The listening endpoints (TCP and TLS) and their configurations. +//! +//! This module provides the infrastructure for setting up network listeners +//! that accept incoming connections. It supports TCP, Unix domain sockets, +//! and TLS endpoints. +//! +//! # Connection Filtering +//! +//! With the `connection_filter` feature enabled, this module also provides +//! early connection filtering capabilities through the [`ConnectionFilter`] trait. +//! This allows dropping unwanted connections at the TCP level before any +//! expensive operations like TLS handshakes. +//! +//! ## Example with Connection Filtering +//! +//! ```rust,no_run +//! # #[cfg(feature = "connection_filter")] +//! # { +//! use pingora_core::listeners::{Listeners, ConnectionFilter}; +//! use std::sync::Arc; +//! +//! // Create a custom filter +//! let filter = Arc::new(MyCustomFilter::new()); +//! +//! // Apply to listeners +//! let mut listeners = Listeners::new(); +//! listeners.set_connection_filter(filter); +//! listeners.add_tcp("0.0.0.0:8080"); +//! # } +//! 
``` mod l4; +#[cfg(feature = "connection_filter")] +pub mod connection_filter; + +#[cfg(feature = "connection_filter")] +pub use connection_filter::{AcceptAllFilter, ConnectionFilter}; + +#[cfg(not(feature = "connection_filter"))] +#[derive(Debug, Clone)] +pub struct AcceptAllFilter; + +#[cfg(not(feature = "connection_filter"))] +pub trait ConnectionFilter: std::fmt::Debug + Send + Sync { + fn should_accept(&self, _addr: &std::net::SocketAddr) -> bool { + true + } +} + +#[cfg(not(feature = "connection_filter"))] +impl ConnectionFilter for AcceptAllFilter { + fn should_accept(&self, _addr: &std::net::SocketAddr) -> bool { + true + } +} #[cfg(feature = "any_tls")] pub mod tls; #[cfg(not(feature = "any_tls"))] pub use crate::tls::listeners as tls; -use crate::protocols::{ - l4::socket::SocketAddr, - proxy_protocol, - tls::TlsRef, - Stream, -}; +use crate::protocols::{l4::socket::SocketAddr, proxy_protocol, tls::TlsRef, Stream}; use log::{debug, warn}; -use pingora_error::{OrErr, ErrorType::*}; +use pingora_error::{ErrorType::*, OrErr}; /// Callback function type for ClientHello extraction /// This allows external code (like moat) to generate fingerprints from ClientHello -pub type ClientHelloCallback = Option)>; +pub type ClientHelloCallback = + Option)>; /// Global callback for ClientHello extraction /// This is set by moat to generate fingerprints -static CLIENT_HELLO_CALLBACK: std::sync::OnceLock> = std::sync::OnceLock::new(); +static CLIENT_HELLO_CALLBACK: std::sync::OnceLock> = + std::sync::OnceLock::new(); /// Set the ClientHello callback function /// This is called by moat to register fingerprint generation @@ -61,7 +110,10 @@ pub fn set_client_hello_callback(callback: ClientHelloCallback) { } /// Call the ClientHello callback if registered -fn call_client_hello_callback(hello: &crate::protocols::tls::client_hello::ClientHello, peer_addr: Option) { +fn call_client_hello_callback( + hello: &crate::protocols::tls::client_hello::ClientHello, + peer_addr: Option, 
+) { if let Some(cb_guard) = CLIENT_HELLO_CALLBACK.get() { if let Ok(cb) = cb_guard.lock() { if let Some(callback) = *cb { @@ -82,7 +134,7 @@ use crate::server::ListenFds; use async_trait::async_trait; use pingora_error::Result; -use std::{fs::Permissions, sync::Arc}; +use std::{any::Any, fs::Permissions, sync::Arc}; use l4::{ListenerEndpoint, Stream as L4Stream}; use tls::{Acceptor, TlsSettings}; @@ -102,6 +154,19 @@ pub trait TlsAccept { async fn certificate_callback(&self, _ssl: &mut TlsRef) -> () { // does nothing by default } + + /// This function is called after the TLS handshake is complete. + /// + /// Any value returned from this function (other than `None`) will be stored in the + /// `extension` field of `SslDigest`. This allows you to attach custom application-specific + /// data to the TLS connection, which will be accessible from the HTTP layer via the + /// `SslDigest` attached to the session digest. + async fn handshake_complete_callback( + &self, + _ssl: &TlsRef, + ) -> Option> { + None + } } pub type TlsAcceptCallbacks = Box; @@ -109,6 +174,8 @@ pub type TlsAcceptCallbacks = Box; struct TransportStackBuilder { l4: ServerAddress, tls: Option, + #[cfg(feature = "connection_filter")] + connection_filter: Option>, } impl TransportStackBuilder { @@ -120,6 +187,11 @@ impl TransportStackBuilder { builder.listen_addr(self.l4.clone()); + #[cfg(feature = "connection_filter")] + if let Some(filter) = &self.connection_filter { + builder.connection_filter(filter.clone()); + } + #[cfg(unix)] let l4 = builder.listen(upgrade_listeners).await?; @@ -196,10 +268,14 @@ impl UninitializedStream { Err(e) => { // Check if this is a connection error that should abort the handshake match e.kind() { - std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => { + std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted => { debug!("Connection closed during ClientHello extraction: {:?}", e); // Return error to abort the connection 
instead of proceeding to TLS handshake - return Err(e).or_err(AcceptError, "Connection closed during ClientHello extraction"); + return Err(e).or_err( + AcceptError, + "Connection closed during ClientHello extraction", + ); } _ => { debug!("Non-fatal error extracting ClientHello: {:?}", e); @@ -237,10 +313,14 @@ impl UninitializedStream { // Process the extracted ClientHello if available if let Some(hello) = extracted_hello { // Get peer address if available - let peer_addr = wrapper.get_socket_digest() + let peer_addr = wrapper + .get_socket_digest() .and_then(|d| d.peer_addr().cloned()); - debug!("Extracted ClientHello: SNI={:?}, ALPN={:?}, Peer={:?}", hello.sni, hello.alpn, peer_addr); + debug!( + "Extracted ClientHello: SNI={:?}, ALPN={:?}, Peer={:?}", + hello.sni, hello.alpn, peer_addr + ); // Call the callback to generate fingerprint (registered by moat) call_client_hello_callback(&hello, peer_addr); @@ -276,7 +356,8 @@ impl UninitializedStream { return Ok(()); } - let peer_addr = self.l4 + let peer_addr = self + .l4 .get_socket_digest() .and_then(|d| d.transport_peer_addr().cloned()); let peer_str = peer_addr @@ -296,10 +377,7 @@ impl UninitializedStream { proxy_addr, client_addr ); } else { - debug!( - "PROXY protocol detected downstream client {}", - client_addr - ); + debug!("PROXY protocol detected downstream client {}", client_addr); } } } else if proxy_protocol::header_has_source_addr(&header) { @@ -327,14 +405,19 @@ impl UninitializedStream { /// The struct to hold one more multiple listening endpoints pub struct Listeners { stacks: Vec, + #[cfg(feature = "connection_filter")] + connection_filter: Option>, } impl Listeners { /// Create a new [`Listeners`] with no listening endpoints. pub fn new() -> Self { - Listeners { stacks: vec![] } + Listeners { + stacks: vec![], + #[cfg(feature = "connection_filter")] + connection_filter: None, + } } - /// Create a new [`Listeners`] with a TCP server endpoint from the given string. 
pub fn tcp(addr: &str) -> Self { let mut listeners = Self::new(); @@ -399,9 +482,28 @@ impl Listeners { self.add_endpoint(addr, None); } + /// Set a connection filter for all endpoints in this listener collection + #[cfg(feature = "connection_filter")] + pub fn set_connection_filter(&mut self, filter: Arc) { + log::debug!("Setting connection filter on Listeners"); + + // Store the filter for future endpoints + self.connection_filter = Some(filter.clone()); + + // Apply to existing stacks + for stack in &mut self.stacks { + stack.connection_filter = Some(filter.clone()); + } + } + /// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided pub fn add_endpoint(&mut self, l4: ServerAddress, tls: Option) { - self.stacks.push(TransportStackBuilder { l4, tls }) + self.stacks.push(TransportStackBuilder { + l4, + tls, + #[cfg(feature = "connection_filter")] + connection_filter: self.connection_filter.clone(), + }) } pub(crate) async fn build( @@ -432,6 +534,8 @@ impl Listeners { #[cfg(test)] mod test { use super::*; + #[cfg(feature = "connection_filter")] + use std::sync::atomic::{AtomicUsize, Ordering}; #[cfg(feature = "any_tls")] use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; @@ -509,4 +613,53 @@ mod test { let res = client.get(format!("https://{addr}")).send().await.unwrap(); assert_eq!(res.status(), reqwest::StatusCode::OK); } + + #[cfg(feature = "connection_filter")] + #[test] + fn test_connection_filter_inheritance() { + #[derive(Debug, Clone)] + struct TestFilter { + counter: Arc, + } + + #[async_trait] + impl ConnectionFilter for TestFilter { + async fn should_accept(&self, _addr: Option<&std::net::SocketAddr>) -> bool { + self.counter.fetch_add(1, Ordering::SeqCst); + true + } + } + + let mut listeners = Listeners::new(); + + // Add an endpoint before setting filter + listeners.add_tcp("127.0.0.1:7104"); + + // Set the connection filter + let filter = Arc::new(TestFilter { + counter: Arc::new(AtomicUsize::new(0)), + }); + 
listeners.set_connection_filter(filter.clone()); + + // Add endpoints after setting filter + listeners.add_tcp("127.0.0.1:7105"); + #[cfg(feature = "any_tls")] + { + // Only test TLS if the feature is enabled + if let Ok(tls_settings) = TlsSettings::intermediate( + &format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")), + &format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR")), + ) { + listeners.add_tls_with_settings("127.0.0.1:7106", None, tls_settings); + } + } + + // Verify all stacks have the filter (only when feature is enabled) + for stack in &listeners.stacks { + assert!( + stack.connection_filter.is_some(), + "All stacks should have the connection filter set" + ); + } + } } diff --git a/pingora-core/src/listeners/tls/boringssl_openssl/mod.rs b/pingora-core/src/listeners/tls/boringssl_openssl/mod.rs index ef1eeafb..a1e757da 100644 --- a/pingora-core/src/listeners/tls/boringssl_openssl/mod.rs +++ b/pingora-core/src/listeners/tls/boringssl_openssl/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -16,8 +16,10 @@ use log::debug; use pingora_error::{ErrorType, OrErr, Result}; use std::ops::{Deref, DerefMut}; +use crate::listeners::tls::boringssl_openssl::alpn::valid_alpn; pub use crate::protocols::tls::ALPN; use crate::protocols::{GetSocketDigest, IO}; +use crate::tls::ssl::AlpnError; use crate::tls::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod}; use crate::{ listeners::TlsAcceptCallbacks, @@ -26,7 +28,6 @@ use crate::{ SslStream, }, }; - pub const TLS_CONF_ERR: ErrorType = ErrorType::Custom("TLSConfigError"); pub(crate) struct Acceptor { @@ -113,6 +114,18 @@ impl TlsSettings { .set_alpn_select_callback(alpn::prefer_h2), ALPN::H1 => self.accept_builder.set_alpn_select_callback(alpn::h1_only), ALPN::H2 => self.accept_builder.set_alpn_select_callback(alpn::h2_only), + ALPN::Custom(custom) => { + self.accept_builder + .set_alpn_select_callback(move |_, alpn_in| { + if !valid_alpn(alpn_in) { + return Err(AlpnError::NOACK); + } + match alpn::select_protocol(alpn_in, custom.protocol()) { + Some(p) => Ok(p), + None => Err(AlpnError::NOACK), + } + }); + } } } @@ -138,7 +151,9 @@ impl Acceptor { /// Perform TLS handshake with ClientHello extraction /// This wraps the stream with ClientHelloWrapper before TLS handshake #[cfg(unix)] - pub async fn tls_handshake_with_client_hello( + pub async fn tls_handshake_with_client_hello< + S: IO + GetSocketDigest + std::os::unix::io::AsRawFd + 'static, + >( &self, stream: S, ) -> Result>> { @@ -150,10 +165,14 @@ impl Acceptor { // Extract ClientHello before TLS handshake (sync version blocks until data is available) if let Ok(Some(hello)) = wrapper.extract_client_hello() { // Get peer address if available - let peer_addr = wrapper.get_socket_digest() + let peer_addr = wrapper + .get_socket_digest() .and_then(|d| d.peer_addr().cloned()); - debug!("Extracted ClientHello: SNI={:?}, ALPN={:?}, Peer={:?}", hello.sni, hello.alpn, peer_addr); + debug!( + "Extracted ClientHello: SNI={:?}, ALPN={:?}, Peer={:?}", + 
hello.sni, hello.alpn, peer_addr + ); // Generate fingerprint from raw ClientHello bytes // This will be handled by moat's tls_client_hello module @@ -173,7 +192,7 @@ mod alpn { use super::*; use crate::tls::ssl::{select_next_proto, AlpnError, SslRef}; - fn valid_alpn(alpn_in: &[u8]) -> bool { + pub(super) fn valid_alpn(alpn_in: &[u8]) -> bool { if alpn_in.is_empty() { return false; } @@ -181,6 +200,27 @@ mod alpn { true } + /// Finds the first protocol in the client-offered ALPN list that matches the given protocol. + /// + /// This is a helper for ALPN negotiation. It iterates over the client's protocol list + /// (in wire format) and returns the first protocol that matches proto + /// The returned reference always points into `client_protocols`, so lifetimes are correct. + pub(super) fn select_protocol<'a>( + client_protocols: &'a [u8], + proto: &[u8], + ) -> Option<&'a [u8]> { + let mut bytes = client_protocols; + while !bytes.is_empty() { + let len = bytes[0] as usize; + bytes = &bytes[1..]; + if len == proto.len() && &bytes[..len] == proto { + return Some(&bytes[..len]); + } + bytes = &bytes[len..]; + } + None + } + // A standard implementation provided by the SSL lib is used below pub fn prefer_h2<'a>(_ssl: &mut SslRef, alpn_in: &'a [u8]) -> Result<&'a [u8], AlpnError> { diff --git a/pingora-core/src/listeners/tls/mod.rs b/pingora-core/src/listeners/tls/mod.rs index 887293b3..c345073e 100644 --- a/pingora-core/src/listeners/tls/mod.rs +++ b/pingora-core/src/listeners/tls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/listeners/tls/rustls/mod.rs b/pingora-core/src/listeners/tls/rustls/mod.rs index 40babeb6..0ca94d51 100644 --- a/pingora-core/src/listeners/tls/rustls/mod.rs +++ b/pingora-core/src/listeners/tls/rustls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ use log::debug; use pingora_error::ErrorType::InternalError; use pingora_error::{Error, OrErr, Result}; use pingora_rustls::load_certs_and_key_files; +use pingora_rustls::ClientCertVerifier; use pingora_rustls::ServerConfig; use pingora_rustls::{version, TlsAcceptor as RusTlsAcceptor}; @@ -30,6 +31,7 @@ pub struct TlsSettings { alpn_protocols: Option>>, cert_path: String, key_path: String, + client_cert_verifier: Option>, } pub struct Acceptor { @@ -54,15 +56,19 @@ impl TlsSettings { ) }; - // TODO - Add support for client auth & custom CA support - let mut config = - ServerConfig::builder_with_protocol_versions(&[&version::TLS12, &version::TLS13]) - .with_no_client_auth() - .with_single_cert(certs, key) - .explain_err(InternalError, |e| { - format!("Failed to create server listener config: {e}") - }) - .unwrap(); + let builder = + ServerConfig::builder_with_protocol_versions(&[&version::TLS12, &version::TLS13]); + let builder = if let Some(verifier) = self.client_cert_verifier { + builder.with_client_cert_verifier(verifier) + } else { + builder.with_no_client_auth() + }; + let mut config = builder + .with_single_cert(certs, key) + .explain_err(InternalError, |e| { + format!("Failed to create server listener config: {e}") + }) + .unwrap(); if let Some(alpn_protocols) = self.alpn_protocols { config.alpn_protocols = alpn_protocols; @@ -80,10 +86,15 @@ impl TlsSettings { self.set_alpn(ALPN::H2H1); } - fn set_alpn(&mut self, alpn: ALPN) { + pub fn set_alpn(&mut self, alpn: ALPN) { 
self.alpn_protocols = Some(alpn.to_wire_protocols()); } + /// Configure mTLS by providing a rustls client certificate verifier. + pub fn set_client_cert_verifier(&mut self, verifier: Arc) { + self.client_cert_verifier = Some(verifier); + } + pub fn intermediate(cert_path: &str, key_path: &str) -> Result where Self: Sized, @@ -92,6 +103,7 @@ impl TlsSettings { alpn_protocols: None, cert_path: cert_path.to_string(), key_path: key_path.to_string(), + client_cert_verifier: None, }) } diff --git a/pingora-core/src/listeners/tls/s2n/mod.rs b/pingora-core/src/listeners/tls/s2n/mod.rs index 2598e829..ed689445 100644 --- a/pingora-core/src/listeners/tls/s2n/mod.rs +++ b/pingora-core/src/listeners/tls/s2n/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/modules/http/compression.rs b/pingora-core/src/modules/http/compression.rs index 1906bd66..fa64d3c1 100644 --- a/pingora-core/src/modules/http/compression.rs +++ b/pingora-core/src/modules/http/compression.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/modules/http/grpc_web.rs b/pingora-core/src/modules/http/grpc_web.rs index b248e233..fd1d4ad2 100644 --- a/pingora-core/src/modules/http/grpc_web.rs +++ b/pingora-core/src/modules/http/grpc_web.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/modules/http/mod.rs b/pingora-core/src/modules/http/mod.rs index d220e6b0..04084258 100644 --- a/pingora-core/src/modules/http/mod.rs +++ b/pingora-core/src/modules/http/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/modules/mod.rs b/pingora-core/src/modules/mod.rs index 359b9ef4..c4a1c4a6 100644 --- a/pingora-core/src/modules/mod.rs +++ b/pingora-core/src/modules/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/client_hello_wrapper.rs b/pingora-core/src/protocols/client_hello_wrapper.rs index e61d80a9..46efbde6 100644 --- a/pingora-core/src/protocols/client_hello_wrapper.rs +++ b/pingora-core/src/protocols/client_hello_wrapper.rs @@ -168,10 +168,11 @@ impl ClientHelloWrapper { Err(e) => { wrapper.hello_extracted = true; match e.kind() { - io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => { + io::ErrorKind::ConnectionReset + | io::ErrorKind::ConnectionAborted => { Poll::Ready(Err(e)) } - _ => Poll::Ready(Ok(None)) + _ => Poll::Ready(Ok(None)), } } } @@ -180,10 +181,9 @@ impl ClientHelloWrapper { wrapper.hello_extracted = true; match e.kind() { io::ErrorKind::WouldBlock => Poll::Pending, - io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => { - Poll::Ready(Err(e)) - } - _ => Poll::Ready(Ok(None)) + io::ErrorKind::ConnectionReset + | io::ErrorKind::ConnectionAborted => Poll::Ready(Err(e)), + _ => Poll::Ready(Ok(None)), } } Poll::Pending => Poll::Pending, @@ -373,4 +373,3 @@ mod tests { assert_eq!(inner.into_inner(), data); } } - diff --git 
a/pingora-core/src/protocols/digest.rs b/pingora-core/src/protocols/digest.rs index 64fe15e9..405c6698 100644 --- a/pingora-core/src/protocols/digest.rs +++ b/pingora-core/src/protocols/digest.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/body_buffer.rs b/pingora-core/src/protocols/http/body_buffer.rs index f3c46df9..a122df20 100644 --- a/pingora-core/src/protocols/http/body_buffer.rs +++ b/pingora-core/src/protocols/http/body_buffer.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ use bytes::{Bytes, BytesMut}; /// A buffer with size limit. When the total amount of data written to the buffer is below the limit /// all the data will be held in the buffer. Otherwise, the buffer will report to be truncated. -pub(crate) struct FixedBuffer { +pub struct FixedBuffer { buffer: BytesMut, capacity: usize, truncated: bool, diff --git a/pingora-core/src/protocols/http/bridge/grpc_web.rs b/pingora-core/src/protocols/http/bridge/grpc_web.rs index 63d19727..8a091d27 100644 --- a/pingora-core/src/protocols/http/bridge/grpc_web.rs +++ b/pingora-core/src/protocols/http/bridge/grpc_web.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/protocols/http/bridge/mod.rs b/pingora-core/src/protocols/http/bridge/mod.rs index fa1f58ca..6d295d0b 100644 --- a/pingora-core/src/protocols/http/bridge/mod.rs +++ b/pingora-core/src/protocols/http/bridge/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/client.rs b/pingora-core/src/protocols/http/client.rs index 2d1278d9..54fc367f 100644 --- a/pingora-core/src/protocols/http/client.rs +++ b/pingora-core/src/protocols/http/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,21 +17,23 @@ use pingora_error::Result; use pingora_http::{RequestHeader, ResponseHeader}; use std::time::Duration; -use super::v1::client::HttpSession as Http1Session; use super::v2::client::Http2Session; +use super::{custom::client::Session, v1::client::HttpSession as Http1Session}; use crate::protocols::{Digest, SocketAddr, Stream}; /// A type for Http client session. It can be either an Http1 connection or an Http2 stream. 
-pub enum HttpSession { +pub enum HttpSession { H1(Http1Session), H2(Http2Session), + Custom(S), } -impl HttpSession { +impl HttpSession { pub fn as_http1(&self) -> Option<&Http1Session> { match self { Self::H1(s) => Some(s), Self::H2(_) => None, + Self::Custom(_) => None, } } @@ -39,8 +41,26 @@ impl HttpSession { match self { Self::H1(_) => None, Self::H2(s) => Some(s), + Self::Custom(_) => None, } } + + pub fn as_custom(&self) -> Option<&S> { + match self { + Self::H1(_) => None, + Self::H2(_) => None, + Self::Custom(c) => Some(c), + } + } + + pub fn as_custom_mut(&mut self) -> Option<&mut S> { + match self { + Self::H1(_) => None, + Self::H2(_) => None, + Self::Custom(c) => Some(c), + } + } + /// Write the request header to the server /// After the request header is sent. The caller can either start reading the response or /// sending request body if any. @@ -51,6 +71,7 @@ impl HttpSession { Ok(()) } HttpSession::H2(h2) => h2.write_request_header(req, false), + HttpSession::Custom(c) => c.write_request_header(req, false).await, } } @@ -63,6 +84,7 @@ impl HttpSession { Ok(()) } HttpSession::H2(h2) => h2.write_request_body(data, end).await, + HttpSession::Custom(c) => c.write_request_body(data, end).await, } } @@ -74,6 +96,7 @@ impl HttpSession { Ok(()) } HttpSession::H2(h2) => h2.finish_request_body(), + HttpSession::Custom(c) => c.finish_request_body().await, } } @@ -84,6 +107,7 @@ impl HttpSession { match self { HttpSession::H1(h1) => h1.read_timeout = timeout, HttpSession::H2(h2) => h2.read_timeout = timeout, + HttpSession::Custom(c) => c.set_read_timeout(timeout), } } @@ -94,6 +118,7 @@ impl HttpSession { match self { HttpSession::H1(h1) => h1.write_timeout = timeout, HttpSession::H2(h2) => h2.write_timeout = timeout, + HttpSession::Custom(c) => c.set_write_timeout(timeout), } } @@ -107,6 +132,7 @@ impl HttpSession { Ok(()) } HttpSession::H2(h2) => h2.read_response_header().await, + HttpSession::Custom(c) => c.read_response_header().await, } } @@ -117,6 
+143,7 @@ impl HttpSession { match self { HttpSession::H1(h1) => h1.read_body_bytes().await, HttpSession::H2(h2) => h2.read_response_body().await, + HttpSession::Custom(c) => c.read_response_body().await, } } @@ -125,6 +152,7 @@ impl HttpSession { match self { HttpSession::H1(h1) => h1.is_body_done(), HttpSession::H2(h2) => h2.response_finished(), + HttpSession::Custom(c) => c.response_finished(), } } @@ -135,6 +163,7 @@ impl HttpSession { match self { Self::H1(s) => s.shutdown().await, Self::H2(s) => s.shutdown(), + Self::Custom(c) => c.shutdown(0, "shutdown").await, } } @@ -145,6 +174,7 @@ impl HttpSession { match self { Self::H1(s) => s.resp_header(), Self::H2(s) => s.response_header(), + Self::Custom(c) => c.response_header(), } } @@ -156,6 +186,7 @@ impl HttpSession { match self { Self::H1(s) => Some(s.digest()), Self::H2(s) => s.digest(), + Self::Custom(c) => c.digest(), } } @@ -166,6 +197,7 @@ impl HttpSession { match self { Self::H1(s) => Some(s.digest_mut()), Self::H2(s) => s.digest_mut(), + Self::Custom(s) => s.digest_mut(), } } @@ -174,6 +206,7 @@ impl HttpSession { match self { Self::H1(s) => s.server_addr(), Self::H2(s) => s.server_addr(), + Self::Custom(s) => s.server_addr(), } } @@ -182,6 +215,7 @@ impl HttpSession { match self { Self::H1(s) => s.client_addr(), Self::H2(s) => s.client_addr(), + Self::Custom(s) => s.client_addr(), } } @@ -191,6 +225,7 @@ impl HttpSession { match self { Self::H1(s) => Some(s.stream()), Self::H2(_) => None, + Self::Custom(_) => None, } } } diff --git a/pingora-core/src/protocols/http/compression/brotli.rs b/pingora-core/src/protocols/http/compression/brotli.rs index c4bb36a5..fa8a3bae 100644 --- a/pingora-core/src/protocols/http/compression/brotli.rs +++ b/pingora-core/src/protocols/http/compression/brotli.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/compression/gzip.rs b/pingora-core/src/protocols/http/compression/gzip.rs index 46678df6..97f7b636 100644 --- a/pingora-core/src/protocols/http/compression/gzip.rs +++ b/pingora-core/src/protocols/http/compression/gzip.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/compression/mod.rs b/pingora-core/src/protocols/http/compression/mod.rs index 2f86efce..9e84ab3c 100644 --- a/pingora-core/src/protocols/http/compression/mod.rs +++ b/pingora-core/src/protocols/http/compression/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -333,6 +333,8 @@ pub enum Algorithm { Gzip, Brotli, Zstd, + Dcb, + Dcz, // TODO: Identity, // TODO: Deflate Other, // anything unknown @@ -344,6 +346,8 @@ impl Algorithm { Algorithm::Gzip => "gzip", Algorithm::Brotli => "br", Algorithm::Zstd => "zstd", + Algorithm::Dcb => "dcb", + Algorithm::Dcz => "dcz", Algorithm::Any => "*", Algorithm::Other => "other", } @@ -390,6 +394,10 @@ impl From<&str> for Algorithm { Algorithm::Brotli } else if coding == UniCase::ascii("zstd") { Algorithm::Zstd + } else if coding == UniCase::ascii("dcb") { + Algorithm::Dcb + } else if coding == UniCase::ascii("dcz") { + Algorithm::Dcz } else if s.is_empty() { Algorithm::Any } else { @@ -614,6 +622,36 @@ fn test_decide_action() { let mut header = ResponseHeader::build(200, None).unwrap(); header.insert_header("content-encoding", "gzip").unwrap(); assert_eq!(decide_action(&header, &[Brotli, Gzip]), Noop); + + // dcb passthrough: client accepts dcb, response has dcb + let mut header = ResponseHeader::build(200, None).unwrap(); + header.insert_header("content-encoding", "dcb").unwrap(); + assert_eq!(decide_action(&header, &[Dcb, Brotli]), Noop); + + // dcz passthrough: client accepts dcz, response has dcz + let mut header = ResponseHeader::build(200, None).unwrap(); + header.insert_header("content-encoding", "dcz").unwrap(); + assert_eq!(decide_action(&header, &[Dcz, Zstd]), Noop); + + // Client wants dcz but response has brotli, decompress brotli + let mut header = ResponseHeader::build(200, None).unwrap(); + header.insert_header("content-encoding", "br").unwrap(); + assert_eq!(decide_action(&header, &[Dcz]), Decompress(Brotli)); + + // Client wants dcz but response has zstd, decompress zstd + let mut header = ResponseHeader::build(200, None).unwrap(); + header.insert_header("content-encoding", "zstd").unwrap(); + assert_eq!(decide_action(&header, &[Dcz]), Decompress(Zstd)); + + // Client wants dcb but response has gzip, decompress gzip + let mut header = ResponseHeader::build(200, 
None).unwrap(); + header.insert_header("content-encoding", "gzip").unwrap(); + assert_eq!(decide_action(&header, &[Dcb]), Decompress(Gzip)); + + // Client wants dcb but response has brotli, decompress brotli + let mut header = ResponseHeader::build(200, None).unwrap(); + header.insert_header("content-encoding", "br").unwrap(); + assert_eq!(decide_action(&header, &[Dcb]), Decompress(Brotli)); } use once_cell::sync::Lazy; diff --git a/pingora-core/src/protocols/http/compression/zstd.rs b/pingora-core/src/protocols/http/compression/zstd.rs index b8a45b41..39465918 100644 --- a/pingora-core/src/protocols/http/compression/zstd.rs +++ b/pingora-core/src/protocols/http/compression/zstd.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/conditional_filter.rs b/pingora-core/src/protocols/http/conditional_filter.rs index 49daebc9..10aee2f2 100644 --- a/pingora-core/src/protocols/http/conditional_filter.rs +++ b/pingora-core/src/protocols/http/conditional_filter.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/custom/client.rs b/pingora-core/src/protocols/http/custom/client.rs new file mode 100644 index 00000000..994ddf04 --- /dev/null +++ b/pingora-core/src/protocols/http/custom/client.rs @@ -0,0 +1,176 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::time::Duration; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::Stream; +use http::HeaderMap; +use pingora_error::Result; +use pingora_http::{RequestHeader, ResponseHeader}; + +use crate::protocols::{l4::socket::SocketAddr, Digest, UniqueIDType}; + +use super::{BodyWrite, CustomMessageWrite}; + +#[doc(hidden)] +#[async_trait] +pub trait Session: Send + Sync + Unpin + 'static { + async fn write_request_header(&mut self, req: Box, end: bool) -> Result<()>; + + async fn write_request_body(&mut self, data: Bytes, end: bool) -> Result<()>; + + async fn finish_request_body(&mut self) -> Result<()>; + + fn set_read_timeout(&mut self, timeout: Option); + + fn set_write_timeout(&mut self, timeout: Option); + + async fn read_response_header(&mut self) -> Result<()>; + + async fn read_response_body(&mut self) -> Result>; + + fn response_finished(&self) -> bool; + + async fn shutdown(&mut self, code: u32, ctx: &str); + + fn response_header(&self) -> Option<&ResponseHeader>; + + fn was_upgraded(&self) -> bool; + + fn digest(&self) -> Option<&Digest>; + + fn digest_mut(&mut self) -> Option<&mut Digest>; + + fn server_addr(&self) -> Option<&SocketAddr>; + + fn client_addr(&self) -> Option<&SocketAddr>; + + async fn read_trailers(&mut self) -> Result>; + + fn fd(&self) -> UniqueIDType; + + async fn check_response_end_or_error(&mut self, headers: bool) -> Result; + + fn take_request_body_writer(&mut self) -> Option>; + + async fn finish_custom(&mut self) -> Result<()>; + + fn take_custom_message_reader( + &mut self, + ) -> Option> + 
Unpin + Send + Sync + 'static>>; + + async fn drain_custom_messages(&mut self) -> Result<()>; + + fn take_custom_message_writer(&mut self) -> Option>; +} + +#[doc(hidden)] +#[async_trait] +impl Session for () { + async fn write_request_header(&mut self, _req: Box, _end: bool) -> Result<()> { + unreachable!("client session: write_request_header") + } + + async fn write_request_body(&mut self, _data: Bytes, _end: bool) -> Result<()> { + unreachable!("client session: write_request_body") + } + + async fn finish_request_body(&mut self) -> Result<()> { + unreachable!("client session: finish_request_body") + } + + fn set_read_timeout(&mut self, _timeout: Option) { + unreachable!("client session: set_read_timeout") + } + + fn set_write_timeout(&mut self, _timeout: Option) { + unreachable!("client session: set_write_timeout") + } + + async fn read_response_header(&mut self) -> Result<()> { + unreachable!("client session: read_response_header") + } + + async fn read_response_body(&mut self) -> Result> { + unreachable!("client session: read_response_body") + } + + fn response_finished(&self) -> bool { + unreachable!("client session: response_finished") + } + + async fn shutdown(&mut self, _code: u32, _ctx: &str) { + unreachable!("client session: shutdown") + } + + fn response_header(&self) -> Option<&ResponseHeader> { + unreachable!("client session: response_header") + } + + fn was_upgraded(&self) -> bool { + unreachable!("client session: was upgraded") + } + + fn digest(&self) -> Option<&Digest> { + unreachable!("client session: digest") + } + + fn digest_mut(&mut self) -> Option<&mut Digest> { + unreachable!("client session: digest_mut") + } + + fn server_addr(&self) -> Option<&SocketAddr> { + unreachable!("client session: server_addr") + } + + fn client_addr(&self) -> Option<&SocketAddr> { + unreachable!("client session: client_addr") + } + + async fn finish_custom(&mut self) -> Result<()> { + unreachable!("client session: finish_custom") + } + + async fn 
read_trailers(&mut self) -> Result> { + unreachable!("client session: read_trailers") + } + + fn fd(&self) -> UniqueIDType { + unreachable!("client session: fd") + } + + async fn check_response_end_or_error(&mut self, _headers: bool) -> Result { + unreachable!("client session: check_response_end_or_error") + } + + fn take_custom_message_reader( + &mut self, + ) -> Option> + Unpin + Send + Sync + 'static>> { + unreachable!("client session: get_custom_message_reader") + } + + async fn drain_custom_messages(&mut self) -> Result<()> { + unreachable!("client session: drain_custom_messages") + } + + fn take_custom_message_writer(&mut self) -> Option> { + unreachable!("client session: get_custom_message_writer") + } + + fn take_request_body_writer(&mut self) -> Option> { + unreachable!("client session: take_request_body_writer") + } +} diff --git a/pingora-core/src/protocols/http/custom/mod.rs b/pingora-core/src/protocols/http/custom/mod.rs new file mode 100644 index 00000000..cac4a755 --- /dev/null +++ b/pingora-core/src/protocols/http/custom/mod.rs @@ -0,0 +1,90 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::time::Duration; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::Stream; +use log::debug; +use pingora_error::Result; +use tokio_stream::StreamExt; + +pub mod client; +pub mod server; + +pub const CUSTOM_MESSAGE_QUEUE_SIZE: usize = 128; + +pub fn is_informational_except_101>(code: T) -> bool { + // excluding `101 Switching Protocols`, because it's not followed by any other + // response and it's final + // The WebSocket Protocol https://datatracker.ietf.org/doc/html/rfc6455 + code > 99 && code < 200 && code != 101 +} + +#[async_trait] +pub trait CustomMessageWrite: Send + Sync + Unpin + 'static { + fn set_write_timeout(&mut self, timeout: Option); + async fn write_custom_message(&mut self, msg: Bytes) -> Result<()>; + async fn finish_custom(&mut self) -> Result<()>; +} + +#[doc(hidden)] +#[async_trait] +impl CustomMessageWrite for () { + fn set_write_timeout(&mut self, _timeout: Option) {} + + async fn write_custom_message(&mut self, msg: Bytes) -> Result<()> { + debug!("write_custom_message: {:?}", msg); + Ok(()) + } + + async fn finish_custom(&mut self) -> Result<()> { + debug!("finish_custom"); + Ok(()) + } +} + +#[async_trait] +pub trait BodyWrite: Send + Sync + Unpin + 'static { + async fn write_all_buf(&mut self, data: &mut Bytes) -> Result<()>; + async fn finish(&mut self) -> Result<()>; + async fn cleanup(&mut self) -> Result<()>; + fn upgrade_body_writer(&mut self); +} + +pub async fn drain_custom_messages( + reader: Option> + Unpin + Send + Sync + 'static>>, +) -> Result<()> { + let Some(mut reader) = reader else { + return Ok(()); + }; + + while let Some(res) = reader.next().await { + let msg = res?; + debug!("consume_custom_messages: {msg:?}"); + } + + Ok(()) +} + +#[macro_export] +macro_rules! custom_session { + ($base_obj:ident . 
$($method_tokens:tt)+) => { + if let Some(custom_session) = $base_obj.as_custom_mut() { + #[allow(clippy::semicolon_if_nothing_returned)] + custom_session.$($method_tokens)+; + } + }; +} diff --git a/pingora-core/src/protocols/http/custom/server.rs b/pingora-core/src/protocols/http/custom/server.rs new file mode 100644 index 00000000..fc9e4c48 --- /dev/null +++ b/pingora-core/src/protocols/http/custom/server.rs @@ -0,0 +1,299 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::time::Duration; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::Stream; +use http::HeaderMap; +use pingora_error::Result; +use pingora_http::{RequestHeader, ResponseHeader}; + +use crate::protocols::{http::HttpTask, l4::socket::SocketAddr, Digest}; + +use super::CustomMessageWrite; + +#[doc(hidden)] +#[async_trait] +pub trait Session: Send + Sync + Unpin + 'static { + fn req_header(&self) -> &RequestHeader; + + fn req_header_mut(&mut self) -> &mut RequestHeader; + + async fn read_body_bytes(&mut self) -> Result>; + + async fn drain_request_body(&mut self) -> Result<()>; + + async fn write_response_header(&mut self, resp: Box, end: bool) -> Result<()>; + + async fn write_response_header_ref(&mut self, resp: &ResponseHeader, end: bool) -> Result<()>; + + async fn write_body(&mut self, data: Bytes, end: bool) -> Result<()>; + + async fn write_trailers(&mut self, trailers: HeaderMap) -> Result<()>; + + async fn response_duplex_vec(&mut self, tasks: Vec) -> Result; + + fn set_read_timeout(&mut self, timeout: Option); + + fn get_read_timeout(&self) -> Option; + + fn set_write_timeout(&mut self, timeout: Option); + + fn get_write_timeout(&self) -> Option; + + fn set_total_drain_timeout(&mut self, timeout: Option); + + fn get_total_drain_timeout(&self) -> Option; + + fn request_summary(&self) -> String; + + fn response_written(&self) -> Option<&ResponseHeader>; + + async fn shutdown(&mut self, code: u32, ctx: &str); + + fn is_body_done(&mut self) -> bool; + + async fn finish(&mut self) -> Result<()>; + + fn is_body_empty(&mut self) -> bool; + + async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result>; + + fn body_bytes_sent(&self) -> usize; + + fn body_bytes_read(&self) -> usize; + + fn digest(&self) -> Option<&Digest>; + + fn digest_mut(&mut self) -> Option<&mut Digest>; + + fn client_addr(&self) -> Option<&SocketAddr>; + + fn server_addr(&self) -> Option<&SocketAddr>; + + fn pseudo_raw_h1_request_header(&self) -> Bytes; + + 
fn enable_retry_buffering(&mut self); + + fn retry_buffer_truncated(&self) -> bool; + + fn get_retry_buffer(&self) -> Option; + + async fn finish_custom(&mut self) -> Result<()>; + + fn take_custom_message_reader( + &mut self, + ) -> Option> + Unpin + Send + Sync + 'static>>; + + fn restore_custom_message_reader( + &mut self, + reader: Box> + Unpin + Send + Sync + 'static>, + ) -> Result<()>; + + fn take_custom_message_writer(&mut self) -> Option>; + + fn restore_custom_message_writer(&mut self, writer: Box) -> Result<()>; + + /// Whether this request is for upgrade (e.g., websocket). + /// + /// Returns `true` if the request has HTTP/1.1 version and contains an Upgrade header. + fn is_upgrade_req(&self) -> bool { + false + } + + /// Whether this session was fully upgraded (completed Upgrade handshake). + /// + /// Returns `true` if the request was an upgrade request and a 101 response was sent. + fn was_upgraded(&self) -> bool { + false + } +} + +#[doc(hidden)] +#[async_trait] +impl Session for () { + fn req_header(&self) -> &RequestHeader { + unreachable!("server session: req_header") + } + + fn req_header_mut(&mut self) -> &mut RequestHeader { + unreachable!("server session: req_header_mut") + } + + async fn read_body_bytes(&mut self) -> Result> { + unreachable!("server session: read_body_bytes") + } + + async fn drain_request_body(&mut self) -> Result<()> { + unreachable!("server session: drain_request_body") + } + + async fn write_response_header( + &mut self, + _resp: Box, + _end: bool, + ) -> Result<()> { + unreachable!("server session: write_response_header") + } + + async fn write_response_header_ref( + &mut self, + _resp: &ResponseHeader, + _end: bool, + ) -> Result<()> { + unreachable!("server session: write_response_header_ref") + } + + async fn write_body(&mut self, _data: Bytes, _end: bool) -> Result<()> { + unreachable!("server session: write_body") + } + + async fn write_trailers(&mut self, _trailers: HeaderMap) -> Result<()> { + 
unreachable!("server session: write_trailers") + } + + async fn response_duplex_vec(&mut self, _tasks: Vec) -> Result { + unreachable!("server session: response_duplex_vec") + } + + fn set_read_timeout(&mut self, _timeout: Option) { + unreachable!("server session: set_read_timeout") + } + + fn get_read_timeout(&self) -> Option { + unreachable!("server_session: get_read_timeout") + } + + fn set_write_timeout(&mut self, _timeout: Option) { + unreachable!("server session: set_write_timeout") + } + + fn get_write_timeout(&self) -> Option { + unreachable!("server_session: get_write_timeout") + } + + fn set_total_drain_timeout(&mut self, _timeout: Option) { + unreachable!("server session: set_total_drain_timeout") + } + + fn get_total_drain_timeout(&self) -> Option { + unreachable!("server_session: get_total_drain_timeout") + } + + fn request_summary(&self) -> String { + unreachable!("server session: request_summary") + } + + fn response_written(&self) -> Option<&ResponseHeader> { + unreachable!("server session: response_written") + } + + async fn shutdown(&mut self, _code: u32, _ctx: &str) { + unreachable!("server session: shutdown") + } + + fn is_body_done(&mut self) -> bool { + unreachable!("server session: is_body_done") + } + + async fn finish(&mut self) -> Result<()> { + unreachable!("server session: finish") + } + + fn is_body_empty(&mut self) -> bool { + unreachable!("server session: is_body_empty") + } + + async fn read_body_or_idle(&mut self, _no_body_expected: bool) -> Result> { + unreachable!("server session: read_body_or_idle") + } + + fn body_bytes_sent(&self) -> usize { + unreachable!("server session: body_bytes_sent") + } + + fn body_bytes_read(&self) -> usize { + unreachable!("server session: body_bytes_read") + } + + fn digest(&self) -> Option<&Digest> { + unreachable!("server session: digest") + } + + fn digest_mut(&mut self) -> Option<&mut Digest> { + unreachable!("server session: digest_mut") + } + + fn client_addr(&self) -> Option<&SocketAddr> { + 
unreachable!("server session: client_addr") + } + + fn server_addr(&self) -> Option<&SocketAddr> { + unreachable!("server session: server_addr") + } + + fn pseudo_raw_h1_request_header(&self) -> Bytes { + unreachable!("server session: pseudo_raw_h1_request_header") + } + + fn enable_retry_buffering(&mut self) { + unreachable!("server session: enable_retry_bufferings") + } + + fn retry_buffer_truncated(&self) -> bool { + unreachable!("server session: retry_buffer_truncated") + } + + fn get_retry_buffer(&self) -> Option { + unreachable!("server session: get_retry_buffer") + } + + async fn finish_custom(&mut self) -> Result<()> { + unreachable!("server session: finish_custom") + } + + fn take_custom_message_reader( + &mut self, + ) -> Option> + Unpin + Send + Sync + 'static>> { + unreachable!("server session: get_custom_message_reader") + } + + fn restore_custom_message_reader( + &mut self, + _reader: Box> + Unpin + Send + Sync + 'static>, + ) -> Result<()> { + unreachable!("server session: get_custom_message_reader") + } + + fn take_custom_message_writer(&mut self) -> Option> { + unreachable!("server session: get_custom_message_writer") + } + + fn restore_custom_message_writer( + &mut self, + _writer: Box, + ) -> Result<()> { + unreachable!("server session: restore_custom_message_writer") + } + + fn is_upgrade_req(&self) -> bool { + unreachable!("server session: is_upgrade_req") + } + + fn was_upgraded(&self) -> bool { + unreachable!("server session: was_upgraded") + } +} diff --git a/pingora-core/src/protocols/http/date.rs b/pingora-core/src/protocols/http/date.rs index 87d49489..610c9386 100644 --- a/pingora-core/src/protocols/http/date.rs +++ b/pingora-core/src/protocols/http/date.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/protocols/http/error_resp.rs b/pingora-core/src/protocols/http/error_resp.rs index f802d4d0..e58f66fe 100644 --- a/pingora-core/src/protocols/http/error_resp.rs +++ b/pingora-core/src/protocols/http/error_resp.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/mod.rs b/pingora-core/src/protocols/http/mod.rs index 488cffb6..f5bc729d 100644 --- a/pingora-core/src/protocols/http/mod.rs +++ b/pingora-core/src/protocols/http/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,12 +14,13 @@ //! HTTP/1.x and HTTP/2 implementation APIs -mod body_buffer; +pub mod body_buffer; pub mod bridge; pub mod client; pub mod compression; pub mod conditional_filter; -pub(crate) mod date; +pub mod custom; +pub mod date; pub mod error_resp; pub mod server; pub mod subrequest; @@ -36,8 +37,10 @@ pub const SERVER_NAME: &[u8; 7] = b"Pingora"; pub enum HttpTask { /// the response header and the boolean end of response flag Header(Box, bool), - /// A piece of response body and the end of response boolean flag + /// A piece of request or response body and the end of request/response boolean flag. Body(Option, bool), + /// Request or response body bytes that have been upgraded on H1.1, and EOF bool flag. 
+ UpgradedBody(Option, bool), /// HTTP response trailer Trailer(Option>), /// Signal that the response is already finished @@ -52,6 +55,7 @@ impl HttpTask { match self { HttpTask::Header(_, end) => *end, HttpTask::Body(_, end) => *end, + HttpTask::UpgradedBody(_, end) => *end, HttpTask::Trailer(_) => true, HttpTask::Done => true, HttpTask::Failed(_) => true, @@ -63,6 +67,7 @@ impl HttpTask { match self { HttpTask::Header(..) => "Header", HttpTask::Body(..) => "Body", + HttpTask::UpgradedBody(..) => "UpgradedBody", HttpTask::Trailer(_) => "Trailer", HttpTask::Done => "Done", HttpTask::Failed(_) => "Failed", diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index bc5964d6..035a65cc 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,16 +14,18 @@ //! 
HTTP server session APIs +use super::custom::server::Session as SessionCustom; use super::error_resp; use super::subrequest::server::HttpSession as SessionSubrequest; use super::v1::server::HttpSession as SessionV1; use super::v2::server::HttpSession as SessionV2; use super::HttpTask; +use crate::custom_session; use crate::protocols::{Digest, SocketAddr, Stream}; use bytes::Bytes; use http::HeaderValue; use http::{header::AsHeaderName, HeaderMap}; -use pingora_error::Result; +use pingora_error::{Error, Result}; use pingora_http::{RequestHeader, ResponseHeader}; use std::time::Duration; @@ -32,6 +34,7 @@ pub enum Session { H1(SessionV1), H2(SessionV2), Subrequest(SessionSubrequest), + Custom(Box), } impl Session { @@ -50,6 +53,11 @@ impl Session { Self::Subrequest(session) } + /// Create a new [`Session`] from a custom session + pub fn new_custom(session: Box) -> Self { + Self::Custom(session) + } + /// Whether the session is HTTP/2. If not it is HTTP/1.x pub fn is_http2(&self) -> bool { matches!(self, Self::H2(_)) @@ -60,6 +68,11 @@ impl Session { matches!(self, Self::Subrequest(_)) } + /// Whether the session is Custom + pub fn is_custom(&self) -> bool { + matches!(self, Self::Custom(_)) + } + /// Read the request header. This method is required to be called first before doing anything /// else with the session. 
/// - `Ok(true)`: successful @@ -77,6 +90,7 @@ impl Session { let read = s.read_request().await?; Ok(read.is_some()) } + Self::Custom(_) => Ok(true), } } @@ -88,6 +102,7 @@ impl Session { Self::H1(s) => s.req_header(), Self::H2(s) => s.req_header(), Self::Subrequest(s) => s.req_header(), + Self::Custom(s) => s.req_header(), } } @@ -99,6 +114,7 @@ impl Session { Self::H1(s) => s.req_header_mut(), Self::H2(s) => s.req_header_mut(), Self::Subrequest(s) => s.req_header_mut(), + Self::Custom(s) => s.req_header_mut(), } } @@ -122,6 +138,7 @@ impl Session { Self::H1(s) => s.read_body_bytes().await, Self::H2(s) => s.read_body_bytes().await, Self::Subrequest(s) => s.read_body_bytes().await, + Self::Custom(s) => s.read_body_bytes().await, } } @@ -134,6 +151,7 @@ impl Session { Self::H1(s) => s.drain_request_body().await, Self::H2(s) => s.drain_request_body().await, Self::Subrequest(s) => s.drain_request_body().await, + Self::Custom(s) => s.drain_request_body().await, } } @@ -151,6 +169,7 @@ impl Session { s.write_response_header(resp).await?; Ok(()) } + Self::Custom(s) => s.write_response_header(resp, false).await, } } @@ -166,6 +185,7 @@ impl Session { s.write_response_header_ref(resp).await?; Ok(()) } + Self::Custom(s) => s.write_response_header_ref(resp, false).await, } } @@ -192,6 +212,7 @@ impl Session { s.write_body(data).await?; Ok(()) } + Self::Custom(s) => s.write_body(data, end).await, } } @@ -201,6 +222,7 @@ impl Session { Self::H1(_) => Ok(()), // TODO: support trailers for h1 Self::H2(s) => s.write_trailers(trailers), Self::Subrequest(s) => s.write_trailers(Some(Box::new(trailers))).await, + Self::Custom(s) => s.write_trailers(trailers).await, } } @@ -223,6 +245,25 @@ impl Session { s.finish().await?; Ok(None) } + Self::Custom(mut s) => { + s.finish().await?; + Ok(None) + } + } + } + + /// Callback for cleanup logic on downstream specifically when we fail to proxy the session + /// other than cleanup via finish(). 
+ /// + /// If caching the downstream failure may be independent of (and precede) an upstream error in + /// which case this function may be called more than once. + pub fn on_proxy_failure(&mut self, e: Box) { + match self { + Self::H1(_) | Self::H2(_) | Self::Custom(_) => { + // all cleanup logic handled in finish(), + // stream and resources dropped when session dropped + } + Self::Subrequest(ref mut s) => s.on_proxy_failure(e), } } @@ -231,6 +272,7 @@ impl Session { Self::H1(s) => s.response_duplex_vec(tasks).await, Self::H2(s) => s.response_duplex_vec(tasks).await, Self::Subrequest(s) => s.response_duplex_vec(tasks).await, + Self::Custom(s) => s.response_duplex_vec(tasks).await, } } @@ -241,6 +283,7 @@ impl Session { Self::H1(s) => s.set_server_keepalive(duration), Self::H2(_) => {} Self::Subrequest(_) => {} + Self::Custom(_) => {} } } @@ -251,6 +294,26 @@ impl Session { Self::H1(s) => s.get_keepalive_timeout(), Self::H2(_) => None, Self::Subrequest(_) => None, + Self::Custom(_) => None, + } + } + + /// Set the number of times the upstream connection connection for this + /// session can be reused via keepalive. Noop for h2 and subrequest + pub fn set_keepalive_reuses_remaining(&mut self, reuses: Option) { + if let Self::H1(s) = self { + s.set_keepalive_reuses_remaining(reuses); + } + } + + /// Get the number of times the upstream connection connection for this + /// session can be reused via keepalive. 
Not applicable for h2 or + /// subrequest + pub fn get_keepalive_reuses_remaining(&self) -> Option { + if let Self::H1(s) = self { + s.get_keepalive_reuses_remaining() + } else { + None } } @@ -263,6 +326,7 @@ impl Session { Self::H1(s) => s.set_read_timeout(timeout), Self::H2(_) => {} Self::Subrequest(s) => s.set_read_timeout(timeout), + Self::Custom(c) => c.set_read_timeout(timeout), } } @@ -272,6 +336,7 @@ impl Session { Self::H1(s) => s.get_read_timeout(), Self::H2(_) => None, Self::Subrequest(s) => s.get_read_timeout(), + Self::Custom(s) => s.get_read_timeout(), } } @@ -283,6 +348,7 @@ impl Session { Self::H1(s) => s.set_write_timeout(timeout), Self::H2(s) => s.set_write_timeout(timeout), Self::Subrequest(s) => s.set_write_timeout(timeout), + Self::Custom(c) => c.set_write_timeout(timeout), } } @@ -292,6 +358,7 @@ impl Session { Self::H1(s) => s.get_write_timeout(), Self::H2(s) => s.get_write_timeout(), Self::Subrequest(s) => s.get_write_timeout(), + Self::Custom(s) => s.get_write_timeout(), } } @@ -306,6 +373,7 @@ impl Session { Self::H1(s) => s.set_total_drain_timeout(timeout), Self::H2(s) => s.set_total_drain_timeout(timeout), Self::Subrequest(s) => s.set_total_drain_timeout(timeout), + Self::Custom(c) => c.set_total_drain_timeout(timeout), } } @@ -315,6 +383,7 @@ impl Session { Self::H1(s) => s.get_total_drain_timeout(), Self::H2(s) => s.get_total_drain_timeout(), Self::Subrequest(s) => s.get_total_drain_timeout(), + Self::Custom(s) => s.get_total_drain_timeout(), } } @@ -333,6 +402,7 @@ impl Session { Self::H1(s) => s.set_min_send_rate(rate), Self::H2(_) => {} Self::Subrequest(_) => {} + Self::Custom(_) => {} } } @@ -349,6 +419,7 @@ impl Session { Self::H1(s) => s.set_ignore_info_resp(ignore), Self::H2(_) => {} // always ignored Self::Subrequest(_) => {} + Self::Custom(_) => {} // always ignored } } @@ -361,6 +432,7 @@ impl Session { Self::H1(s) => s.set_close_on_response_before_downstream_finish(close), Self::H2(_) => {} // always ignored 
Self::Subrequest(_) => {} // always ignored + Self::Custom(_) => {} // always ignored } } @@ -371,6 +443,7 @@ impl Session { Self::H1(s) => s.request_summary(), Self::H2(s) => s.request_summary(), Self::Subrequest(s) => s.request_summary(), + Self::Custom(s) => s.request_summary(), } } @@ -381,6 +454,7 @@ impl Session { Self::H1(s) => s.response_written(), Self::H2(s) => s.response_written(), Self::Subrequest(s) => s.response_written(), + Self::Custom(s) => s.response_written(), } } @@ -393,6 +467,7 @@ impl Session { Self::H1(s) => s.shutdown().await, Self::H2(s) => s.shutdown(), Self::Subrequest(s) => s.shutdown(), + Self::Custom(s) => s.shutdown(0, "shutdown").await, } } @@ -401,6 +476,7 @@ impl Session { Self::H1(s) => s.get_headers_raw_bytes(), Self::H2(s) => s.pseudo_raw_h1_request_header(), Self::Subrequest(s) => s.get_headers_raw_bytes(), + Self::Custom(c) => c.pseudo_raw_h1_request_header(), } } @@ -410,6 +486,7 @@ impl Session { Self::H1(s) => s.is_body_done(), Self::H2(s) => s.is_body_done(), Self::Subrequest(s) => s.is_body_done(), + Self::Custom(s) => s.is_body_done(), } } @@ -423,6 +500,7 @@ impl Session { Self::H1(s) => s.finish_body().await.map(|_| ()), Self::H2(s) => s.finish(), Self::Subrequest(s) => s.finish().await.map(|_| ()), + Self::Custom(s) => s.finish().await, } } @@ -477,6 +555,8 @@ impl Session { self.finish_body().await?; } + custom_session!(self.finish_custom().await?); + Ok(()) } @@ -486,6 +566,7 @@ impl Session { Self::H1(s) => s.is_body_empty(), Self::H2(s) => s.is_body_empty(), Self::Subrequest(s) => s.is_body_empty(), + Self::Custom(s) => s.is_body_empty(), } } @@ -494,6 +575,7 @@ impl Session { Self::H1(s) => s.retry_buffer_truncated(), Self::H2(s) => s.retry_buffer_truncated(), Self::Subrequest(s) => s.retry_buffer_truncated(), + Self::Custom(s) => s.retry_buffer_truncated(), } } @@ -502,6 +584,7 @@ impl Session { Self::H1(s) => s.enable_retry_buffering(), Self::H2(s) => s.enable_retry_buffering(), Self::Subrequest(s) => 
s.enable_retry_buffering(), + Self::Custom(s) => s.enable_retry_buffering(), } } @@ -510,6 +593,7 @@ impl Session { Self::H1(s) => s.get_retry_buffer(), Self::H2(s) => s.get_retry_buffer(), Self::Subrequest(s) => s.get_retry_buffer(), + Self::Custom(s) => s.get_retry_buffer(), } } @@ -520,6 +604,7 @@ impl Session { Self::H1(s) => s.read_body_or_idle(no_body_expected).await, Self::H2(s) => s.read_body_or_idle(no_body_expected).await, Self::Subrequest(s) => s.read_body_or_idle(no_body_expected).await, + Self::Custom(s) => s.read_body_or_idle(no_body_expected).await, } } @@ -528,6 +613,7 @@ impl Session { Self::H1(s) => Some(s), Self::H2(_) => None, Self::Subrequest(_) => None, + Self::Custom(_) => None, } } @@ -536,6 +622,7 @@ impl Session { Self::H1(_) => None, Self::H2(s) => Some(s), Self::Subrequest(_) => None, + Self::Custom(_) => None, } } @@ -544,6 +631,7 @@ impl Session { Self::H1(_) => None, Self::H2(_) => None, Self::Subrequest(s) => Some(s), + Self::Custom(_) => None, } } @@ -552,6 +640,25 @@ impl Session { Self::H1(_) => None, Self::H2(_) => None, Self::Subrequest(s) => Some(s), + Self::Custom(_) => None, + } + } + + pub fn as_custom(&self) -> Option<&dyn SessionCustom> { + match self { + Self::H1(_) => None, + Self::H2(_) => None, + Self::Subrequest(_) => None, + Self::Custom(c) => Some(c.as_ref()), + } + } + + pub fn as_custom_mut(&mut self) -> Option<&mut Box> { + match self { + Self::H1(_) => None, + Self::H2(_) => None, + Self::Subrequest(_) => None, + Self::Custom(c) => Some(c), } } @@ -564,15 +671,34 @@ impl Session { false, ), Self::Subrequest(s) => s.write_continue_response().await, + // TODO(slava): is there any write_continue_response calls? + Self::Custom(s) => { + s.write_response_header( + Box::new(ResponseHeader::build(100, Some(0)).unwrap()), + false, + ) + .await + } } } - /// Whether this request is for upgrade (e.g., websocket) + /// Whether this request is for upgrade (e.g., websocket). 
pub fn is_upgrade_req(&self) -> bool { match self { Self::H1(s) => s.is_upgrade_req(), Self::H2(_) => false, Self::Subrequest(s) => s.is_upgrade_req(), + Self::Custom(s) => s.is_upgrade_req(), + } + } + + /// Whether this session was fully upgraded (completed Upgrade handshake). + pub fn was_upgraded(&self) -> bool { + match self { + Self::H1(s) => s.was_upgraded(), + Self::H2(_) => false, + Self::Subrequest(s) => s.was_upgraded(), + Self::Custom(s) => s.was_upgraded(), } } @@ -582,6 +708,7 @@ impl Session { Self::H1(s) => s.body_bytes_sent(), Self::H2(s) => s.body_bytes_sent(), Self::Subrequest(s) => s.body_bytes_sent(), + Self::Custom(s) => s.body_bytes_sent(), } } @@ -591,6 +718,7 @@ impl Session { Self::H1(s) => s.body_bytes_read(), Self::H2(s) => s.body_bytes_read(), Self::Subrequest(s) => s.body_bytes_read(), + Self::Custom(s) => s.body_bytes_read(), } } @@ -600,6 +728,7 @@ impl Session { Self::H1(s) => Some(s.digest()), Self::H2(s) => s.digest(), Self::Subrequest(s) => s.digest(), + Self::Custom(s) => s.digest(), } } @@ -611,6 +740,7 @@ impl Session { Self::H1(s) => Some(s.digest_mut()), Self::H2(s) => s.digest_mut(), Self::Subrequest(s) => s.digest_mut(), + Self::Custom(s) => s.digest_mut(), } } @@ -620,6 +750,7 @@ impl Session { Self::H1(s) => s.client_addr(), Self::H2(s) => s.client_addr(), Self::Subrequest(s) => s.client_addr(), + Self::Custom(s) => s.client_addr(), } } @@ -629,6 +760,7 @@ impl Session { Self::H1(s) => s.server_addr(), Self::H2(s) => s.server_addr(), Self::Subrequest(s) => s.server_addr(), + Self::Custom(s) => s.server_addr(), } } @@ -639,6 +771,7 @@ impl Session { Self::H1(s) => Some(s.stream()), Self::H2(_) => None, Self::Subrequest(_) => None, + Self::Custom(_) => None, } } } diff --git a/pingora-core/src/protocols/http/subrequest/body.rs b/pingora-core/src/protocols/http/subrequest/body.rs index acfef4b5..183e6c12 100644 --- a/pingora-core/src/protocols/http/subrequest/body.rs +++ b/pingora-core/src/protocols/http/subrequest/body.rs 
@@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -72,7 +72,19 @@ impl BodyReader { } } - pub fn init_until_close(&mut self) { + pub fn init_close_delimited(&mut self) { + self.body_state = PS::UntilClose(0); + } + + /// Convert how we interpret the remainder of the body to read until close. + /// This is used for responses without explicit framing. + pub fn convert_to_close_delimited(&mut self) { + if matches!(self.body_state, PS::UntilClose(_)) { + // nothing to do, already in close-delimited mode + return; + } + + // reset body counter self.body_state = PS::UntilClose(0); } @@ -214,7 +226,7 @@ impl BodyWriter { } } - pub fn init_until_close(&mut self) { + pub fn init_close_delimited(&mut self) { self.body_mode = BM::UntilClose(0); } @@ -494,7 +506,7 @@ mod tests { let input2 = b""; // zero length body but not actually close let (tx, mut rx) = mpsc::channel::(TASK_BUFFER_SIZE); let mut body_reader = BodyReader::new(None); - body_reader.init_until_close(); + body_reader.init_close_delimited(); tx.send(HttpTask::Body(Some(Bytes::from(&input1[..])), false)) .await @@ -566,7 +578,7 @@ mod tests { let data = b"a"; let (mut tx, mut rx) = mpsc::channel::(TASK_BUFFER_SIZE); let mut body_writer = BodyWriter::new(); - body_writer.init_until_close(); + body_writer.init_close_delimited(); assert_eq!(body_writer.body_mode, BodyMode::UntilClose(0)); let res = body_writer .write_body(&mut tx, Bytes::from(&data[..])) diff --git a/pingora-core/src/protocols/http/subrequest/dummy.rs b/pingora-core/src/protocols/http/subrequest/dummy.rs index 9df9c2cb..93973448 100644 --- a/pingora-core/src/protocols/http/subrequest/dummy.rs +++ b/pingora-core/src/protocols/http/subrequest/dummy.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/subrequest/server.rs b/pingora-core/src/protocols/http/subrequest/server.rs index 691532d8..c91dbf91 100644 --- a/pingora-core/src/protocols/http/subrequest/server.rs +++ b/pingora-core/src/protocols/http/subrequest/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ use bytes::Bytes; use http::HeaderValue; -use http::{header, header::AsHeaderName, HeaderMap, Method, Version}; +use http::{header, header::AsHeaderName, HeaderMap, Method}; use log::{debug, trace, warn}; use pingora_error::{Error, ErrorType::*, OkOrErr, Result}; use pingora_http::{RequestHeader, ResponseHeader}; @@ -47,7 +47,7 @@ use crate::protocols::http::{ body_buffer::FixedBuffer, server::Session as GenericHttpSession, subrequest::dummy::DummyIO, - v1::common::{header_value_content_length, is_header_value_chunked_encoding, BODY_BUF_LIMIT}, + v1::common::{header_value_content_length, is_chunked_encoding_from_headers, BODY_BUF_LIMIT}, v1::server::HttpSession as SessionV1, HttpTask, }; @@ -61,6 +61,7 @@ pub struct HttpSession { // Currently subrequest session is initialized via a dummy SessionV1 only // TODO: need to be able to indicate H2 / other HTTP versions here v1_inner: Box, + proxy_error: Option>>, // option to consume the sender read_req_header: bool, response_written: Option, read_timeout: Option, @@ -84,8 +85,9 @@ pub struct SubrequestHandle { /// Channel receiver (for subrequest output) pub rx: mpsc::Receiver, /// Indicates when subrequest wants to start reading body input - // TODO: use when piping subrequest input/output pub subreq_wants_body: oneshot::Receiver<()>, + /// Any final or downstream error that was 
encountered while proxying + pub subreq_proxy_error: oneshot::Receiver>, } impl SubrequestHandle { @@ -111,11 +113,13 @@ impl HttpSession { let (downstream_tx, downstream_rx) = mpsc::channel(CHANNEL_BUFFER_SIZE); let (upstream_tx, upstream_rx) = mpsc::channel(CHANNEL_BUFFER_SIZE); let (wants_body_tx, wants_body_rx) = oneshot::channel(); + let (proxy_error_tx, proxy_error_rx) = oneshot::channel(); ( HttpSession { v1_inner: Box::new(v1_inner), tx: Some(upstream_tx), rx: Some(downstream_rx), + proxy_error: Some(proxy_error_tx), body_reader: BodyReader::new(Some(wants_body_tx)), body_writer: BodyWriter::new(), read_req_header: false, @@ -134,6 +138,7 @@ impl HttpSession { tx: downstream_tx, rx: upstream_rx, subreq_wants_body: wants_body_rx, + subreq_proxy_error: proxy_error_rx, }, ) } @@ -222,12 +227,11 @@ impl HttpSession { /// Read the request body. `Ok(None)` when there is no (more) body to read. pub async fn read_body_bytes(&mut self) -> Result> { let read = self.read_body().await?; - Ok(read.map(|b| { + Ok(read.inspect(|b| { self.body_bytes_read += b.len(); if let Some(buffer) = self.retry_buffer.as_mut() { - buffer.write_to_buffer(&b); + buffer.write_to_buffer(b); } - b })) } @@ -322,11 +326,25 @@ impl HttpSession { // a peer discards any further data received. // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 self.upgraded = true; + // Now that the upgrade was successful, we need to change + // how we interpret the rest of the body as pass-through. + if self.body_reader.need_init() { + self.init_body_reader(); + } else { + // already initialized + // immediately start reading the rest of the body as upgraded + // (in theory most upgraded requests shouldn't have any body) + // + // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade + // the most spec-compliant behavior is to switch interpretation + // after sending the former body. For now we immediately + // switch interpretation to match nginx behavior. 
+ // TODO: this has no effect resetting the body counter of TE chunked + self.body_reader.convert_to_close_delimited(); + } } else { debug!("bad upgrade handshake!"); - // reset request body buf and mark as done - // safe to reset an upgrade because it doesn't have body - self.body_reader.init_content_length(0); + // continue to read body as-is, this is now just a regular request } } self.init_body_writer(&header); @@ -361,6 +379,16 @@ impl HttpSession { self.v1_inner.is_upgrade(header) } + /// Was this request successfully turned into an upgraded connection? + /// + /// Both the request had to have been an `Upgrade` request + /// and the response had to have been a `101 Switching Protocols`. + // XXX: this should only be valid if subrequest is standing in for + // a v1 session. + pub fn was_upgraded(&self) -> bool { + self.upgraded + } + fn init_body_writer(&mut self, header: &ResponseHeader) { use http::StatusCode; /* the following responses don't have body 204, 304, and HEAD */ @@ -379,24 +407,21 @@ impl HttpSession { } if self.is_upgrade(header) == Some(true) { - self.body_writer.init_until_close(); + self.body_writer.init_close_delimited(); + } else if is_chunked_encoding_from_headers(&header.headers) { + // transfer-encoding takes priority over content-length + self.body_writer.init_close_delimited(); } else { - let te_value = header.headers.get(http::header::TRANSFER_ENCODING); - if is_header_value_chunked_encoding(te_value) { - // transfer-encoding takes priority over content-length - self.body_writer.init_until_close(); - } else { - let content_length = - header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH)); - match content_length { - Some(length) => { - self.body_writer.init_content_length(length); - } - None => { - /* TODO: 1. connection: keepalive cannot be used, - 2. 
mark connection must be closed */ - self.body_writer.init_until_close(); - } + let content_length = + header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH)); + match content_length { + Some(length) => { + self.body_writer.init_content_length(length); + } + None => { + /* TODO: 1. connection: keepalive cannot be used, + 2. mark connection must be closed */ + self.body_writer.init_close_delimited(); } } } @@ -454,6 +479,21 @@ impl HttpSession { Ok(res) } + /// Signal to error listener held by SubrequestHandle that a proxy error was encountered, + /// and pass along what that error was. + /// + /// This is helpful to signal what errors were encountered outside of the proxy state machine, + /// e.g. during subrequest request filters. + /// + /// Note: in the case of multiple proxy failures e.g. when caching, only the first error will + /// be propagated (i.e. downstream error first if it goes away before upstream). + pub fn on_proxy_failure(&mut self, e: Box) { + // fine if handle is gone + if let Some(sender) = self.proxy_error.take() { + let _ = sender.send(e); + } + } + /// Return how many response body bytes (application, not wire) already sent downstream pub fn body_bytes_sent(&self) -> usize { self.body_bytes_sent @@ -465,7 +505,7 @@ impl HttpSession { } fn is_chunked_encoding(&self) -> bool { - is_header_value_chunked_encoding(self.get_header(header::TRANSFER_ENCODING)) + is_chunked_encoding_from_headers(&self.req_header().headers) } /// Clear body-related subrequest headers. 
@@ -490,16 +530,15 @@ impl HttpSession { buffer.clear(); } - if self.req_header().version == Version::HTTP_11 && self.is_upgrade_req() { - self.body_reader.init_until_close(); - return; - } - - if self.is_chunked_encoding() { + if self.was_upgraded() { + // if upgraded _post_ 101 (and body was not init yet) + // treat as upgraded body (pass through until closed) + self.body_reader.init_close_delimited(); + } else if self.is_chunked_encoding() { // if chunked encoding, content-length should be ignored // TE is not visible at subrequest HttpTask level // so this means read until request closure - self.body_reader.init_until_close(); + self.body_reader.init_close_delimited(); } else { let cl = header_value_content_length(self.get_header(header::CONTENT_LENGTH)); match cl { @@ -507,15 +546,11 @@ impl HttpSession { self.body_reader.init_content_length(i); } None => { - match self.req_header().version { - Version::HTTP_11 => { - // Per RFC assume no body by default in HTTP 1.1 - self.body_reader.init_content_length(0); - } - _ => { - self.body_reader.init_until_close(); - } - } + // Per RFC 9112: "Request messages are never close-delimited because they are + // always explicitly framed by length or transfer coding, with the absence of + // both implying the request ends immediately after the header section." 
+ // All HTTP/1.x requests without Content-Length or Transfer-Encoding have 0 body + self.body_reader.init_content_length(0); } } } @@ -554,7 +589,7 @@ impl HttpSession { // just consume empty body or done messages, the downstream channel is not a real // connection and only used for this one request while matches!(&task, HttpTask::Done) - || matches!(&task, HttpTask::Body(b, _) if b.as_ref().map_or(true, |b| b.is_empty())) + || matches!(&task, HttpTask::Body(b, _) if b.as_ref().is_none_or(|b| b.is_empty())) { task = rx .recv() @@ -660,6 +695,24 @@ impl HttpSession { Ok(()) } + async fn write_non_empty_body(&mut self, data: Option, upgraded: bool) -> Result<()> { + if upgraded != self.upgraded { + if upgraded { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session (subrequest)"); + } else { + panic!("Unexpected Body task received on upgraded downstream session (subrequest)"); + } + } + let Some(d) = data else { + return Ok(()); + }; + if d.is_empty() { + return Ok(()); + } + self.write_body(d).await.map_err(|e| e.into_down())?; + Ok(()) + } + async fn response_duplex(&mut self, task: HttpTask) -> Result { let end_stream = match task { HttpTask::Header(header, end_stream) => { @@ -668,15 +721,14 @@ impl HttpSession { .map_err(|e| e.into_down())?; end_stream } - HttpTask::Body(data, end_stream) => match data { - Some(d) => { - if !d.is_empty() { - self.write_body(d).await.map_err(|e| e.into_down())?; - } - end_stream - } - None => end_stream, - }, + HttpTask::Body(data, end_stream) => { + self.write_non_empty_body(data, false).await?; + end_stream + } + HttpTask::UpgradedBody(data, end_stream) => { + self.write_non_empty_body(data, true).await?; + end_stream + } HttpTask::Trailer(trailers) => { self.write_trailers(trailers).await?; true @@ -708,15 +760,14 @@ impl HttpSession { .map_err(|e| e.into_down())?; end_stream } - HttpTask::Body(data, end_stream) => match data { - Some(d) => { - if !d.is_empty() { - 
self.write_body(d).await.map_err(|e| e.into_down())?; - } - end_stream - } - None => end_stream, - }, + HttpTask::Body(data, end_stream) => { + self.write_non_empty_body(data, false).await?; + end_stream + } + HttpTask::UpgradedBody(data, end_stream) => { + self.write_non_empty_body(data, true).await?; + end_stream + } HttpTask::Done => { // write done // we'll send HttpTask::Done at the end of this loop in finish @@ -755,7 +806,9 @@ impl HttpSession { mod tests_stream { use super::*; use crate::protocols::http::subrequest::body::{BodyMode, ParseState}; + use bytes::BufMut; use http::StatusCode; + use rstest::rstest; use std::str; use tokio_test::io::Builder; @@ -834,7 +887,7 @@ mod tests_stream { .await .unwrap(); // 100 won't affect body state - assert!(!http_stream.is_body_done()); + assert!(http_stream.is_body_done()); } #[tokio::test] @@ -1037,4 +1090,131 @@ mod tests_stream { t => panic!("unexpected task {t:?}"), } } + + async fn session_from_input_no_validate(input: &[u8]) -> (HttpSession, SubrequestHandle) { + let mock_io = Builder::new().read(input).build(); + let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); + // Read the request in v1 inner session to set up headers properly + http_stream.read_request().await.unwrap(); + let (http_stream, handle) = HttpSession::new_from_session(&http_stream); + (http_stream, handle) + } + + #[rstest] + #[case::negative("-1")] + #[case::not_a_number("abc")] + #[case::float("1.5")] + #[case::empty("")] + #[case::spaces(" ")] + #[case::mixed("123abc")] + #[tokio::test] + async fn validate_request_rejects_invalid_content_length(#[case] invalid_value: &str) { + init_log(); + let input = format!( + "POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {}\r\n\r\n", + invalid_value + ); + let mock_io = Builder::new().read(input.as_bytes()).build(); + let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); + // read_request calls validate_request internally on the v1 inner stream, so it 
should fail here + let res = http_stream.read_request().await; + assert!(res.is_err()); + assert_eq!( + res.unwrap_err().etype(), + &pingora_error::ErrorType::InvalidHTTPHeader + ); + } + + #[rstest] + #[case::valid_zero("0")] + #[case::valid_small("123")] + #[case::valid_large("999999")] + #[tokio::test] + async fn validate_request_accepts_valid_content_length(#[case] valid_value: &str) { + init_log(); + let input = format!( + "POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {}\r\n\r\n", + valid_value + ); + let (mut http_stream, _handle) = session_from_input_no_validate(input.as_bytes()).await; + let res = http_stream.read_request().await; + assert!(res.is_ok()); + } + + #[tokio::test] + async fn validate_request_accepts_no_content_length() { + init_log(); + let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n"; + let (mut http_stream, _handle) = session_from_input_no_validate(input).await; + let res = http_stream.read_request().await; + assert!(res.is_ok()); + } + + const POST_CL_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\nContent-Length: 10\r\n\r\n"; + const POST_CHUNKED_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\nTransfer-Encoding: chunked\r\n\r\n"; + const POST_BODY_DATA: &[u8] = b"abcdefghij"; + + async fn build_upgrade_req_with_body(header: &[u8]) -> (HttpSession, SubrequestHandle) { + let mock_io = Builder::new().read(header).build(); + let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream); + http_stream.read_request().await.unwrap(); + (http_stream, handle) + } + + #[rstest] + #[case::content_length(POST_CL_UPGRADE_REQ)] + #[case::chunked(POST_CHUNKED_UPGRADE_REQ)] + #[tokio::test] + async fn read_upgrade_req_with_body(#[case] header: &[u8]) { + init_log(); + let (mut 
http_stream, handle) = build_upgrade_req_with_body(header).await; + assert!(http_stream.is_upgrade_req()); + // request has body + assert!(!http_stream.is_body_done()); + + // Send body via the handle + handle + .tx + .send(HttpTask::Body(Some(Bytes::from(POST_BODY_DATA)), true)) + .await + .unwrap(); + + let mut buf = vec![]; + while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + buf.put_slice(&b); + } + assert_eq!(buf, POST_BODY_DATA); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10)); + assert_eq!(http_stream.body_bytes_read(), 10); + + assert!(http_stream.is_body_done()); + + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + .unwrap(); + // body reader type switches + assert!(!http_stream.is_body_done()); + + // now send ws data + let ws_data = b"data"; + handle + .tx + .send(HttpTask::Body(Some(Bytes::from(&ws_data[..])), false)) + .await + .unwrap(); + + let buf = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(buf, ws_data.as_slice()); + assert!(!http_stream.is_body_done()); + + // EOF ends body + drop(handle.tx); + assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + assert!(http_stream.is_body_done()); + } } diff --git a/pingora-core/src/protocols/http/v1/body.rs b/pingora-core/src/protocols/http/v1/body.rs index e118ef7e..72899257 100644 --- a/pingora-core/src/protocols/http/v1/body.rs +++ b/pingora-core/src/protocols/http/v1/body.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -29,20 +29,40 @@ use crate::utils::BufRef; const BODY_BUFFER_SIZE: usize = 1024 * 64; // limit how much incomplete chunk-size and chunk-ext to buffer const PARTIAL_CHUNK_HEAD_LIMIT: usize = 1024 * 8; - -const LAST_CHUNK: &[u8; 5] = b"0\r\n\r\n"; +// Trailers: https://datatracker.ietf.org/doc/html/rfc9112#section-7.1.2 +// TODO: proper trailer handling and parsing +// generally trailers are an uncommonly used HTTP/1.1 feature, this is a somewhat +// arbitrary cap on trailer size after the 0 chunk size (like header buf) +const TRAILER_SIZE_LIMIT: usize = 1024 * 64; + +const LAST_CHUNK: &[u8; 5] = &[b'0', CR, LF, CR, LF]; +const CR: u8 = b'\r'; +const LF: u8 = b'\n'; +const CRLF: &[u8; 2] = &[CR, LF]; +// This is really the CRLF end of the last trailer (or 0 chunk), + the last CRLF. +const TRAILERS_END: &[u8; 4] = &[CR, LF, CR, LF]; pub const INVALID_CHUNK: ErrorType = ErrorType::new("InvalidChunk"); +pub const INVALID_TRAILER_END: ErrorType = ErrorType::new("InvalidTrailerEnd"); pub const PREMATURE_BODY_END: ErrorType = ErrorType::new("PrematureBodyEnd"); #[derive(Clone, Debug, PartialEq, Eq)] pub enum ParseState { ToStart, - Complete(usize), // total size - Partial(usize, usize), // size read, remaining size - Chunked(usize, usize, usize, usize), // size read, next to read in current buf start, read in current buf start, remaining chucked size to read from IO - Done(usize), // done but there is error, size read - HTTP1_0(usize), // read until connection closed, size read + // Complete: total size (content-length) + Complete(usize), + // Partial: size read, remaining size (content-length) + Partial(usize, usize), + // Chunked: Chunked encoding, prior to the final 0\r\n chunk. + // size read, next to read in current buf start, read in current buf start, remaining chunked size to read from IO + Chunked(usize, usize, usize, usize), + // ChunkedFinal: Final section once the 0\r\n chunk is read.
+ // size read, trailer sizes parsed so far, use existing buf end, trailers end read + ChunkedFinal(usize, usize, usize, u8), + // Done: done but there is error, size read + Done(usize), + // UntilClose: read until connection closed, size read + UntilClose(usize), } type PS = ParseState; @@ -52,7 +72,8 @@ impl ParseState { match self { PS::Partial(read, to_read) => PS::Complete(read + to_read), PS::Chunked(read, _, _, _) => PS::Complete(read + additional_bytes), - PS::HTTP1_0(read) => PS::Complete(read + additional_bytes), + PS::ChunkedFinal(read, _, _, _) => PS::Complete(read + additional_bytes), + PS::UntilClose(read) => PS::Complete(read + additional_bytes), _ => self.clone(), /* invalid transaction */ } } @@ -61,7 +82,26 @@ impl ParseState { match self { PS::Partial(read, _) => PS::Done(read + additional_bytes), PS::Chunked(read, _, _, _) => PS::Done(read + additional_bytes), - PS::HTTP1_0(read) => PS::Done(read + additional_bytes), + PS::ChunkedFinal(read, _, _, _) => PS::Done(read + additional_bytes), + PS::UntilClose(read) => PS::Done(read + additional_bytes), + _ => self.clone(), /* invalid transaction */ + } + } + + pub fn read_final_chunk(&self, remaining_buf_size: usize) -> Self { + match self { + PS::Chunked(read, _, _, _) => { + // The BodyReader is currently expected to copy the remaining buf + // into self.body_buf. + // + // the 2 == the CRLF from the last chunk-size, 0 + CRLF + // because ChunkedFinal is looking for CRLF + CRLF to end + // the whole message. + // This extra 2 bytes technically ends up cutting into the max trailers size, + // which we consider fine for now until full trailers support. + PS::ChunkedFinal(*read, 0, remaining_buf_size, 2) + } + PS::ChunkedFinal(..) 
=> panic!("already read final chunk"), _ => self.clone(), /* invalid transaction */ } } @@ -69,6 +109,7 @@ impl ParseState { pub fn partial_chunk(&self, bytes_read: usize, bytes_to_read: usize) -> Self { match self { PS::Chunked(read, _, _, _) => PS::Chunked(read + bytes_read, 0, 0, bytes_to_read), + PS::ChunkedFinal(..) => panic!("chunked transactions not applicable after final chunk"), _ => self.clone(), /* invalid transaction */ } } @@ -78,6 +119,7 @@ impl ParseState { PS::Chunked(read, _, buf_end, _) => { PS::Chunked(read + bytes_read, buf_start_index, *buf_end, 0) } + PS::ChunkedFinal(..) => panic!("chunked transactions not applicable after final chunk"), _ => self.clone(), /* invalid transaction */ } } @@ -86,6 +128,7 @@ impl ParseState { match self { /* inform reader to read more to form a legal chunk */ PS::Chunked(read, _, _, _) => PS::Chunked(*read, 0, head_end, head_size), + PS::ChunkedFinal(..) => panic!("chunked transactions not applicable after final chunk"), _ => self.clone(), /* invalid transaction */ } } @@ -93,6 +136,7 @@ impl ParseState { pub fn new_buf(&self, buf_end: usize) -> Self { match self { PS::Chunked(read, _, _, _) => PS::Chunked(*read, 0, buf_end, 0), + PS::ChunkedFinal(..) 
=> panic!("chunked transactions not applicable after final chunk"), _ => self.clone(), /* invalid transaction */ } } @@ -150,7 +194,15 @@ impl BodyReader { pub fn init_content_length(&mut self, cl: usize, buf_to_rewind: &[u8]) { match cl { - 0 => self.body_state = PS::Complete(0), + 0 => { + self.body_state = PS::Complete(0); + // Store any extra bytes that were read as overread + if !buf_to_rewind.is_empty() { + let mut overread = BytesMut::with_capacity(buf_to_rewind.len()); + overread.put_slice(buf_to_rewind); + self.body_buf_overread = Some(overread); + } + } _ => { self.prepare_buf(buf_to_rewind); self.body_state = PS::Partial(0, cl); @@ -158,9 +210,29 @@ impl BodyReader { } } - pub fn init_http10(&mut self, buf_to_rewind: &[u8]) { + pub fn init_close_delimited(&mut self, buf_to_rewind: &[u8]) { self.prepare_buf(buf_to_rewind); - self.body_state = PS::HTTP1_0(0); + self.body_state = PS::UntilClose(0); + } + + /// Convert how we interpret the remainder of the body to read until close. + /// This is used for responses without explicit framing (e.g., HTTP/1.0 responses). + /// + /// Does nothing if already in close-delimited mode. 
+ pub fn convert_to_close_delimited(&mut self) { + if matches!(self.body_state, PS::UntilClose(_)) { + // nothing to do, already in close-delimited mode + return; + } + + if self.rewind_buf_len == 0 { + // take any extra bytes and send them as-is, + // reset body counter + let extra = self.body_buf_overread.take(); + let buf = extra.as_deref().unwrap_or_default(); + self.prepare_buf(buf); + } // if rewind_buf_len is not 0, body read has not yet been polled + self.body_state = PS::UntilClose(0); } pub fn get_body(&self, buf_ref: &BufRef) -> &[u8] { @@ -168,7 +240,8 @@ impl BodyReader { buf_ref.get(self.body_buf.as_ref().unwrap()) } - fn get_body_overread(&self) -> Option<&[u8]> { + #[allow(dead_code)] + pub fn get_body_overread(&self) -> Option<&[u8]> { self.body_buf_overread.as_deref() } @@ -200,8 +273,9 @@ impl BodyReader { PS::Complete(_) => Ok(None), PS::Done(_) => Ok(None), PS::Partial(_, _) => self.do_read_body(stream).await, - PS::Chunked(_, _, _, _) => self.do_read_chunked_body(stream).await, - PS::HTTP1_0(_) => self.do_read_body_until_closed(stream).await, + PS::Chunked(..) => self.do_read_chunked_body(stream).await, + PS::ChunkedFinal(..) 
=> self.do_read_chunked_body_final(stream).await, + PS::UntilClose(_) => self.do_read_body_until_closed(stream).await, PS::ToStart => panic!("need to init BodyReader first"), } } @@ -275,12 +349,12 @@ impl BodyReader { .or_err(ReadError, "when reading body")?; } match self.body_state { - PS::HTTP1_0(read) => { + PS::UntilClose(read) => { if n == 0 { self.body_state = PS::Complete(read); Ok(None) } else { - self.body_state = PS::HTTP1_0(read + n); + self.body_state = PS::UntilClose(read + n); Ok(Some(BufRef::new(0, n))) } } @@ -351,18 +425,18 @@ impl BodyReader { ) } else { if expecting_from_io > 0 { + let body_buf = self.body_buf.as_ref().unwrap(); trace!( "partial chunk payload, expecting_from_io: {}, \ existing_buf_end {}, buf: {:?}", expecting_from_io, existing_buf_end, - String::from_utf8_lossy( - &self.body_buf.as_ref().unwrap()[..existing_buf_end] - ) + self.body_buf.as_ref().unwrap()[..existing_buf_end].escape_ascii() ); + // partial chunk payload, will read more if expecting_from_io >= existing_buf_end + 2 { - // not enough + // not enough (doesn't contain CRLF end) self.body_state = self.body_state.partial_chunk( existing_buf_end, expecting_from_io - existing_buf_end, @@ -372,30 +446,76 @@ impl BodyReader { /* could be expecting DATA + CRLF or just CRLF */ let payload_size = expecting_from_io.saturating_sub(2); /* expecting_from_io < existing_buf_end + 2 */ + let need_lf_only = expecting_from_io == 1; // otherwise we need the whole CRLF + if expecting_from_io > existing_buf_end { + // potentially: + // | CR | LF | + // | | + // (existing_buf_end) + // | + // (expecting_from_io) + if payload_size < existing_buf_end { + Self::validate_crlf( + &mut self.body_state, + &body_buf[payload_size..existing_buf_end], + need_lf_only, + false, + )?; + } + } else { + // expecting_from_io <= existing_buf_end + // chunk CRLF end should end here + assert!(Self::validate_crlf( + &mut self.body_state, + &body_buf[payload_size..expecting_from_io], + need_lf_only, + false, + 
)?); + } if expecting_from_io >= existing_buf_end { self.body_state = self .body_state .partial_chunk(payload_size, expecting_from_io - existing_buf_end); + return Ok(Some(BufRef::new(0, payload_size))); } /* expecting_from_io < existing_buf_end */ self.body_state = self.body_state.multi_chunk(payload_size, expecting_from_io); + return Ok(Some(BufRef::new(0, payload_size))); } - self.parse_chunked_buf(existing_buf_start, existing_buf_end) + let (buf_res, last_chunk_size_end) = + self.parse_chunked_buf(existing_buf_start, existing_buf_end)?; + if buf_res.is_some() { + if let Some(idx) = last_chunk_size_end { + // just read the last 0 + CRLF, but not final end CRLF + // copy the rest of the buffer to the start of the body_buf + // so we can parse the remaining bytes as trailers / end + let body_buf = self.body_buf.as_deref_mut().unwrap(); + trace!( + "last chunk size end buf {:?}", + &body_buf[..existing_buf_end].escape_ascii(), + ); + body_buf.copy_within(idx..existing_buf_end, 0); + } + } + Ok(buf_res) } } _ => panic!("wrong body state: {:?}", self.body_state), } } + // Returns: BufRef of next body chunk, + // terminating chunk-size index end if read completely (0 + CRLF). + // Note input indices are absolute (to body_buf). 
fn parse_chunked_buf( &mut self, buf_index_start: usize, buf_index_end: usize, - ) -> Result<Option<BufRef>> { + ) -> Result<(Option<BufRef>, Option<usize>)> { let buf = &self.body_buf.as_ref().unwrap()[buf_index_start..buf_index_end]; let chunk_status = httparse::parse_chunk_size(buf); match chunk_status { @@ -405,13 +525,39 @@ // TODO: Check chunk_size overflow trace!( "Got size {chunk_size}, payload_index: {payload_index}, chunk: {:?}", - String::from_utf8_lossy(buf) + String::from_utf8_lossy(buf).escape_default(), ); let chunk_size = chunk_size as usize; + // https://github.com/seanmonstar/httparse/issues/149 + // httparse does not treat zero-size chunk differently, it does not check + // that terminating chunk is 0 + double CRLF if chunk_size == 0 { - /* terminating chunk. TODO: trailer */ - self.body_state = self.body_state.finish(0); - return Ok(None); + /* terminating chunk, also need to handle trailer. */ + let chunk_end_index = payload_index + 2; + return if chunk_end_index <= buf.len() + && buf[payload_index..chunk_end_index] == CRLF[..] + { + // full terminating CRLF MAY exist in current buf + // Skip ChunkedFinal state and go directly to Complete + // as optimization.
+ self.body_state = self.body_state.finish(0); + self.finish_body_buf( + buf_index_start + chunk_end_index, + buf_index_end, + ); + Ok((None, Some(buf_index_start + payload_index))) + } else { + // Indicate start of parsing final chunked trailers, + // with remaining buf to read + self.body_state = self.body_state.read_final_chunk( + buf_index_end - (buf_index_start + payload_index), + ); + + Ok(( + Some(BufRef::new(0, 0)), + Some(buf_index_start + payload_index), + )) + }; } // chunk-size CRLF [payload_index] byte*[chunk_size] CRLF let data_end_index = payload_index + chunk_size; @@ -423,22 +569,40 @@ impl BodyReader { } else { chunk_size }; + + let crlf_start = chunk_end_index.saturating_sub(2); + if crlf_start < buf.len() { + Self::validate_crlf( + &mut self.body_state, + &buf[crlf_start..], + false, + false, + )?; + } + // else need to read more to get to CRLF + self.body_state = self .body_state .partial_chunk(actual_size, chunk_end_index - buf.len()); - return Ok(Some(BufRef::new( - buf_index_start + payload_index, - actual_size, - ))); + return Ok(( + Some(BufRef::new(buf_index_start + payload_index, actual_size)), + None, + )); } /* got multiple chunks, return the first */ + assert!(Self::validate_crlf( + &mut self.body_state, + &buf[data_end_index..chunk_end_index], + false, + false, + )?); self.body_state = self .body_state .multi_chunk(chunk_size, buf_index_start + chunk_end_index); - Ok(Some(BufRef::new( - buf_index_start + payload_index, - chunk_size, - ))) + Ok(( + Some(BufRef::new(buf_index_start + payload_index, chunk_size)), + None, + )) } httparse::Status::Partial => { if buf.len() > PARTIAL_CHUNK_HEAD_LIMIT { @@ -451,19 +615,283 @@ impl BodyReader { } else { self.body_state = self.body_state.partial_chunk_head(buf_index_end, buf.len()); - Ok(Some(BufRef::new(0, 0))) + Ok((Some(BufRef::new(0, 0)), None)) } } } } Err(e) => { - let context = format!("Invalid chucked encoding: {e:?}"); - debug!("{context}, {:?}", String::from_utf8_lossy(buf)); + let 
context = format!("Invalid chunked encoding: {e:?}"); + debug!( + "{context}, {:?}", + String::from_utf8_lossy(buf).escape_default() + ); self.body_state = self.body_state.done(0); Error::e_explain(INVALID_CHUNK, context) } } } + + pub async fn do_read_chunked_body_final<S>(&mut self, stream: &mut S) -> Result<Option<BufRef>> + where + S: AsyncRead + Unpin + Send, + { + // parse section after last-chunk: https://datatracker.ietf.org/doc/html/rfc9112#section-7.1 + // This is the section after the final chunk we're trying to read, which can include + // HTTP1 trailers (currently we just discard them). + // Really we are just waiting for a consecutive CRLF + CRLF to end the body. + match self.body_state { + PS::ChunkedFinal(read, trailers_read, existing_buf_end, end_read) => { + let body_buf = self.body_buf.as_deref_mut().unwrap(); + let (buf, n) = if existing_buf_end != 0 { + // finish rest of buf that was read with Chunked state + // existing_buf_end is non-zero only once + self.body_state = PS::ChunkedFinal(read, trailers_read, 0, end_read); + (&body_buf[..existing_buf_end], existing_buf_end) + } else { + let n = stream + .read(body_buf) + .await + .or_err(ReadError, "when reading trailers end")?; + + (&body_buf[..n], n) + }; + + if n == 0 { + self.body_state = PS::Done(read); + return Error::e_explain( + ConnectionClosed, + format!( + "Connection prematurely closed without the termination chunk, \ + read {read} bytes, {trailers_read} trailer bytes" + ), + ); + } + + let mut start = 0; + // try to find end within the current IO buffer + while start < n { + // Adjusts body state through each iteration to add trailers read + // Each iteration finds the next CR or LF to advance the buf + let (trailers_read, end_read) = match self.body_state { + PS::ChunkedFinal(_, new_trailers_read, _, new_end_read) => { + (new_trailers_read, new_end_read) + } + _ => unreachable!(), + }; + + let mut buf = &buf[start..n]; + trace!( + "Parsing chunk end for buf {:?}",
String::from_utf8_lossy(buf).escape_default(), ); + + if end_read == 0 { + // find the next CRLF sequence / potential end + let (trailers_read, no_crlf) = + if let Some(p) = buf.iter().position(|b| *b == CR || *b == LF) { + buf = &buf[p..]; + start += p; + (trailers_read + p, false) + } else { + // consider this all trailer bytes + (trailers_read + (n - start), true) + }; + + if trailers_read > TRAILER_SIZE_LIMIT { + self.body_state = self.body_state.done(0); + return Error::e_explain( + INVALID_TRAILER_END, + "Trailer size over limit", + ); + } + + self.body_state = PS::ChunkedFinal(read, trailers_read, 0, 0); + + if no_crlf { + // break and allow polling read body again + break; + } + } + match Self::parse_trailers_end(&mut self.body_state, buf)? { + TrailersEndParseState::NotEnd(next_parse_index) => { + trace!( + "Parsing chunk end for buf {:?}, resume at {next_parse_index}", + String::from_utf8_lossy(buf).escape_default(), + ); + + start += next_parse_index; + } + TrailersEndParseState::Complete(end_idx) => { + trace!( + "Parsing chunk end for buf {:?}, finished at {end_idx}", + String::from_utf8_lossy(buf).escape_default(), + ); + + self.finish_body_buf(start + end_idx, n); + return Ok(None); + } + } + } + } + _ => panic!("wrong body state: {:?}", self.body_state), + } + // indicate final section is not done + Ok(Some(BufRef(0, 0))) + } + + // Parses up to one CRLF at a time to determine if, given the body state, + // we've parsed a full trailer end. + // Panics if empty buffer is given. + fn parse_trailers_end( + body_state: &mut ParseState, + buf: &[u8], + ) -> Result<TrailersEndParseState> { + assert!(!buf.is_empty(), "parse_trailers_end given empty buffer"); + + match body_state.clone() { + PS::ChunkedFinal(read, trailers_read, _, end_read) => { + // Look at the body buf we just read and see if it matches + // the ending CRLF + CRLF sequence.
+ let end_read = end_read as usize; + assert!(end_read < TRAILERS_END.len()); + let to_read = std::cmp::min(buf.len(), TRAILERS_END.len() - end_read); + let buf = &buf[..to_read]; + + // If the start of the buf is not CRLF and we are not in the middle of reading a + // valid CRLF sequence, return to let caller seek for next CRLF + if end_read % 2 == 0 && buf[0] != CR && buf[0] != LF { + trace!( + "parse trailers end {:?}, not CRLF sequence", + String::from_utf8_lossy(buf).escape_default(), + ); + *body_state = PS::ChunkedFinal(read, trailers_read + end_read, 0, 0); + return Ok(TrailersEndParseState::NotEnd(0)); + } + // Check for malformed CRLF in trailers (or final end of trailers section) + let next_parse_index = match end_read { + 0 | 2 => { + // expect start with CR + if Self::validate_crlf(body_state, buf, false, true)? { + // found CR + LF + 2 + } else { + // read CR at least + 1 + } + } + 1 | 3 => { + // assert: only way this can return false is with an empty buffer + assert!(Self::validate_crlf(body_state, buf, true, true)?); + 1 + } + _ => unreachable!(), + }; + let next_end_read = end_read + next_parse_index; + let finished = next_end_read == TRAILERS_END.len(); + if finished { + trace!( + "parse trailers end {:?}, complete {next_end_read}", + String::from_utf8_lossy(buf).escape_default(), + ); + *body_state = PS::Complete(read); + Ok(TrailersEndParseState::Complete(next_parse_index)) + } else { + // either we read the end of one trailer and another one follows, + // or trailer end CRLF sequence so far is valid but we need more bytes + // to determine if more CRLF actually follows + trace!( + "parse trailers end {:?}, resume at {next_parse_index}", + String::from_utf8_lossy(buf).escape_default(), + ); + // unwrap safety for try_into() u8: next_end_read always < + // TRAILERS_END.len() + *body_state = + PS::ChunkedFinal(read, trailers_read, 0, next_end_read.try_into().unwrap()); + Ok(TrailersEndParseState::NotEnd(next_parse_index)) + } + } + _ => 
panic!("wrong body state: {:?}", body_state), + } + } + + // Validates that the starting bytes of `buf` are the expected CRLF bytes. + // Expects: buf that starts at the indices where CRLF should be for chunked bodies. + // If need_lf_only, we will only check for LF, else we will check starting with CR. + // + // Returns Ok() if buf begins with expected bytes (CR, LF, or CRLF). + // The inner bool returned is whether the whole CRLF sequence was completed. + fn validate_crlf( + body_state: &mut ParseState, + buf: &[u8], + need_lf_only: bool, + for_trailer_end: bool, + ) -> Result { + let etype = if for_trailer_end { + INVALID_TRAILER_END + } else { + INVALID_CHUNK + }; + if need_lf_only { + if buf.is_empty() { + Ok(false) + } else { + let b = &buf[..1]; + if b == b"\n" { + // only LF left + Ok(true) + } else { + *body_state = body_state.done(0); + Error::e_explain( + etype, + format!( + "Invalid chunked encoding: {} was not LF", + String::from_utf8_lossy(b).escape_default(), + ), + ) + } + } + } else { + match buf.len() { + 0 => Ok(false), + 1 => { + let b = &buf[..1]; + if b == b"\r" { + Ok(false) + } else { + *body_state = body_state.done(0); + Error::e_explain( + etype, + format!( + "Invalid chunked encoding: {} was not CR", + String::from_utf8_lossy(b).escape_default(), + ), + ) + } + } + _ => { + let b = &buf[..2]; + if b == b"\r\n" { + Ok(true) + } else { + *body_state = body_state.done(0); + Error::e_explain( + etype, + format!( + "Invalid chunked encoding: {} was not CRLF", + String::from_utf8_lossy(b).escape_default(), + ), + ) + } + } + } + } + } +} + +pub enum TrailersEndParseState { + NotEnd(usize), // start of bytes after CR or LF bytes + Complete(usize), // index of message completion } #[derive(Clone, Debug, PartialEq, Eq)] @@ -471,7 +899,7 @@ pub enum BodyMode { ToSelect, ContentLength(usize, usize), // total length to write, bytes already written ChunkedEncoding(usize), //bytes written - HTTP1_0(usize), //bytes written + UntilClose(usize), //bytes 
written Complete(usize), //bytes written } @@ -492,14 +920,26 @@ impl BodyWriter { self.body_mode = BM::ChunkedEncoding(0); } - pub fn init_http10(&mut self) { - self.body_mode = BM::HTTP1_0(0); + pub fn init_close_delimited(&mut self) { + self.body_mode = BM::UntilClose(0); } pub fn init_content_length(&mut self, cl: usize) { self.body_mode = BM::ContentLength(cl, 0); } + pub fn convert_to_close_delimited(&mut self) { + if matches!(self.body_mode, BodyMode::UntilClose(_)) { + // nothing to do, already in close-delimited mode + return; + } + + // NOTE: any stream buffered data will be flushed in next + // close-delimited write + // reset body state to close-delimited (UntilClose) + self.body_mode = BM::UntilClose(0); + } + // NOTE on buffering/flush stream when writing the body // Buffering writes can reduce the syscalls hence improves efficiency of the system // But it hurts real time communication @@ -515,7 +955,7 @@ impl BodyWriter { BM::Complete(_) => Ok(None), BM::ContentLength(_, _) => self.do_write_body(stream, buf).await, BM::ChunkedEncoding(_) => self.do_write_chunked_body(stream, buf).await, - BM::HTTP1_0(_) => self.do_write_http1_0_body(stream, buf).await, + BM::UntilClose(_) => self.do_write_until_close_body(stream, buf).await, BM::ToSelect => Ok(None), // Error here? 
} } @@ -528,6 +968,10 @@ impl BodyWriter { } } + pub fn is_close_delimited(&self) -> bool { + matches!(self.body_mode, BM::UntilClose(_)) + } + async fn do_write_body(&mut self, stream: &mut S, buf: &[u8]) -> Result> where S: AsyncWrite + Unpin + Send, @@ -586,7 +1030,7 @@ impl BodyWriter { } } - async fn do_write_http1_0_body( + async fn do_write_until_close_body( &mut self, stream: &mut S, buf: &[u8], @@ -595,11 +1039,11 @@ impl BodyWriter { S: AsyncWrite + Unpin + Send, { match self.body_mode { - BM::HTTP1_0(written) => { + BM::UntilClose(written) => { let res = stream.write_all(buf).await; match res { Ok(()) => { - self.body_mode = BM::HTTP1_0(written + buf.len()); + self.body_mode = BM::UntilClose(written + buf.len()); stream.flush().await.or_err(WriteError, "flushing body")?; Ok(Some(buf.len())) } @@ -618,7 +1062,7 @@ impl BodyWriter { BM::Complete(_) => Ok(None), BM::ContentLength(_, _) => self.do_finish_body(stream), BM::ChunkedEncoding(_) => self.do_finish_chunked_body(stream).await, - BM::HTTP1_0(_) => self.do_finish_http1_0_body(stream), + BM::UntilClose(_) => self.do_finish_until_close_body(stream), BM::ToSelect => Ok(None), } } @@ -656,9 +1100,9 @@ impl BodyWriter { } } - fn do_finish_http1_0_body(&mut self, _stream: &mut S) -> Result> { + fn do_finish_until_close_body(&mut self, _stream: &mut S) -> Result> { match self.body_mode { - BM::HTTP1_0(written) => { + BM::UntilClose(written) => { self.body_mode = BM::Complete(written); Ok(Some(written)) } @@ -687,6 +1131,7 @@ mod tests { assert_eq!(res, BufRef::new(0, 3)); assert_eq!(body_reader.body_state, ParseState::Complete(3)); assert_eq!(input, body_reader.get_body(&res)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -705,6 +1150,7 @@ mod tests { assert_eq!(res, BufRef::new(0, 2)); assert_eq!(body_reader.body_state, ParseState::Complete(3)); assert_eq!(input2, body_reader.get_body(&res)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -722,6 +1168,7 
@@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap_err(); assert_eq!(&ConnectionClosed, res.etype()); assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -782,6 +1229,7 @@ mod tests { assert_eq!(res, BufRef::new(0, 1)); assert_eq!(body_reader.body_state, ParseState::Complete(3)); assert_eq!(input, body_reader.get_body(&res)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -791,14 +1239,15 @@ mod tests { let input2 = b""; // simulating close let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); let mut body_reader = BodyReader::new(false); - body_reader.init_http10(b""); + body_reader.init_close_delimited(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 1)); - assert_eq!(body_reader.body_state, ParseState::HTTP1_0(1)); + assert_eq!(body_reader.body_state, ParseState::UntilClose(1)); assert_eq!(input1, body_reader.get_body(&res)); let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -809,18 +1258,19 @@ mod tests { let input2 = b""; // simulating close let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); let mut body_reader = BodyReader::new(false); - body_reader.init_http10(rewind); + body_reader.init_close_delimited(rewind); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 2)); - assert_eq!(body_reader.body_state, ParseState::HTTP1_0(2)); + assert_eq!(body_reader.body_state, ParseState::UntilClose(2)); assert_eq!(rewind, body_reader.get_body(&res)); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 1)); - assert_eq!(body_reader.body_state, ParseState::HTTP1_0(3)); + 
assert_eq!(body_reader.body_state, ParseState::UntilClose(3)); assert_eq!(input1, body_reader.get_body(&res)); let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -833,6 +1283,123 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(0)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_zero_chunk_malformed() { + init_log(); + let input = b"0\r\nr\n"; + let mut mock_io = Builder::new().read(&input[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(0, 0, 2, 2)); + + // \n without leading \r + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_TRAILER_END); + assert_eq!(body_reader.body_state, ParseState::Done(0)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_zero_chunk_split() { + init_log(); + let input1 = b"0\r\n"; + let input2 = b"\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(0, 0, 0, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(0)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn 
read_with_body_zero_chunk_split_head() { + init_log(); + let input1 = b"0\r"; + let input2 = b"\n"; + let input3 = b"\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(0, 0, 0, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(0)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_zero_chunk_split_head_2() { + init_log(); + let input1 = b"0"; + let input2 = b"\r\n"; + let input3 = b"\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 1, 1)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(0, 0, 0, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(0)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_zero_chunk_split_head_3() { + init_log(); + let input1 = b"0\r"; + let input2 = b"\n"; + let input3 = b"\r"; + let input4 = b"\n"; + let 
mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .read(&input4[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(0, 0, 0, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(0, 0, 0, 3)); + + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(0)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -845,6 +1412,7 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(0)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -870,6 +1438,7 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await; assert!(res.is_err()); assert_eq!(body_reader.body_state, ParseState::Done(0)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] @@ -887,99 +1456,267 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] - async fn read_with_body_1_chunk_incomplete() { + async fn read_with_body_1_chunk_malformed() { init_log(); - let input1 = b"1\r\na\r\n"; + let input1 = b"1\r\na\rn"; let mut mock_io = Builder::new().read(&input1[..]).build(); let mut body_reader = BodyReader::new(false); 
body_reader.init_chunked(b""); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); - assert_eq!(&input1[3..4], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 0)); - let res = body_reader.read_body(&mut mock_io).await; - assert!(res.is_err()); - assert_eq!(body_reader.body_state, ParseState::Done(1)); + + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_CHUNK); + assert_eq!(body_reader.body_state, ParseState::Done(0)); + + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] - async fn read_with_body_1_chunk_rewind() { + async fn read_with_body_1_chunk_partial_end() { init_log(); - let rewind = b"1\r\nx\r\n"; - let input1 = b"1\r\na\r\n"; - let input2 = b"0\r\n\r\n"; + let input1 = b"1\r\na\r"; + let input2 = b"\n0\r\n\r\n"; let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); let mut body_reader = BodyReader::new(false); - body_reader.init_chunked(rewind); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); - assert_eq!(&rewind[3..4], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 0)); + body_reader.init_chunked(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(3, 1)); assert_eq!(&input1[3..4], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(2, 0, 0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 1)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 1, 6, 0)); let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); - assert_eq!(body_reader.body_state, ParseState::Complete(2)); + 
assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] - async fn read_with_body_multi_chunk() { + async fn read_with_body_1_chunk_partial_end_1() { init_log(); - let input1 = b"1\r\na\r\n2\r\nbc\r\n"; - let input2 = b"0\r\n\r\n"; - let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let input1 = b"3\r\n"; + let input2 = b"abc\r"; + let input3 = b"\n0\r\n\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); let mut body_reader = BodyReader::new(false); body_reader.init_chunked(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); - assert_eq!(&input1[3..4], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 13, 0)); + assert_eq!(res, BufRef::new(3, 0)); + assert_eq!(b"", body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 0, 5)); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(9, 2)); - assert_eq!(&input1[9..11], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 0)); + assert_eq!(res, BufRef::new(0, 3)); + assert_eq!(&input2[0..3], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 1)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] - async fn read_with_body_partial_chunk() { + async fn read_with_body_1_chunk_partial_end_2() { init_log(); - let input1 = b"3\r\na"; - let input2 = b"bc\r\n0\r\n\r\n"; - let mut mock_io = 
Builder::new().read(&input1[..]).read(&input2[..]).build(); + let input1 = b"3\r\n"; + let input2 = b"abc"; + let input3 = b"\r\n0\r\n\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); let mut body_reader = BodyReader::new(false); body_reader.init_chunked(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); - assert_eq!(&input1[3..4], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 4)); + assert_eq!(res, BufRef::new(3, 0)); + assert_eq!(b"", body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 0, 5)); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(0, 2)); - assert_eq!(&input2[0..2], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(3, 4, 9, 0)); + assert_eq!(res, BufRef::new(0, 3)); + assert_eq!(&input2[0..3], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); } #[tokio::test] - async fn read_with_body_partial_head_chunk() { + async fn read_with_body_1_chunk_incomplete() { init_log(); - let input1 = b"1\r"; - let input2 = b"\na\r\n0\r\n\r\n"; - let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let input1 = b"1\r\na\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).build(); let mut body_reader = BodyReader::new(false); body_reader.init_chunked(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(0, 0)); - 
assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 0)); + let res = body_reader.read_body(&mut mock_io).await; + assert!(res.is_err()); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + } + + #[tokio::test] + async fn read_with_body_1_chunk_partial_end_malformed() { + init_log(); + let input1 = b"1\r\na\r"; + let input2 = b"n0\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 1)); + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_CHUNK); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_1_chunk_rewind() { + init_log(); + let rewind = b"1\r\nx\r\n"; + let input1 = b"1\r\na\r\n"; + let input2 = b"0\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(rewind); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&rewind[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(2, 0, 0, 0)); + let res = body_reader.read_body(&mut 
mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(2)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_multi_chunk() { + init_log(); + let input1 = b"1\r\na\r\n2\r\nbc\r\n"; + let input2 = b"0\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 13, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(9, 2)); + assert_eq!(&input1[9..11], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_multi_chunk_malformed() { + init_log(); + let input1 = b"1\r\na\r\n2\r\nbcr\n"; + let mut mock_io = Builder::new().read(&input1[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 13, 0)); + + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_CHUNK); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); + + let input1 = b"1\r\nar\n2\r\nbc\rn"; + let mut mock_io = Builder::new().read(&input1[..]).build(); 
+ let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_CHUNK); + assert_eq!(body_reader.body_state, ParseState::Done(0)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_chunk() { + init_log(); + let input1 = b"3\r\na"; + let input2 = b"bc\r\n0\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 0, 4)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 2)); + assert_eq!(&input2[0..2], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 4, 9, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_chunk_end() { + init_log(); + let input1 = b"3\r\nabc"; + let input2 = b"\r\n0\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 3)); + assert_eq!(&input1[3..6], body_reader.get_body(&res)); + // \r\n (2 bytes) left to read from IO + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + 
assert_eq!(&input2[0..0], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 2, 7, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_head_chunk() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 assert_eq!(&input2[1..2], body_reader.get_body(&res)); @@ -987,6 +1724,205 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap(); assert_eq!(res, None); assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\n\r"; + let input3 = b"\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 + assert_eq!(&input2[1..2], body_reader.get_body(&res)); + 
assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 10, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); // only part of terminal crlf, one more byte to read + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 1, 2)); + // TODO: can optimize this to avoid the second read_body call + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 0, 3)); + + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf_2() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r"; + let input3 = b"\n\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 + assert_eq!(&input2[1..2], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 8, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); // only part of terminal crlf, one more byte to read + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 8, 2)); + // optimized to go right to complete state + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, 
ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf_3() { + init_log(); + let input1 = b"1\r\na\r\n0"; + let input2 = b"\r"; + let input3 = b"\n"; + let input4 = b"\r"; + let input5 = b"\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .read(&input4[..]) + .read(&input5[..]) + .build(); + + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 7, 0)); + // to 0 + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 7, 1)); + // \r + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 2, 2)); + // \n + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 0, 2)); + // \r + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 0, 3)); + // \n + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf_malformed() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\nr"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let 
mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 + assert_eq!(&input2[1..2], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 10, 0)); + + // TODO: may be able to optimize this extra read_body out + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 1, 2)); + // "r" is interpreted as a hanging trailer + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 3, 0, 0)); + + let res = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(&ConnectionClosed, res.etype()); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf_overread() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\n\r"; + let input3 = b"\nabcd"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 + assert_eq!(&input2[1..2], body_reader.get_body(&res)); + 
assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 10, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); // read only part of terminal crlf + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 1, 2)); + // TODO: can optimize this to avoid the second read_body call + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 0, 3)); + + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), Some(&b"abcd"[..])); + } + + #[tokio::test] + async fn read_with_body_multi_chunk_overread() { + init_log(); + let input1 = b"1\r\na\r\n2\r\nbc\r\n"; + let input2 = b"0\r\n\r\nabc"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 13, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(9, 2)); + assert_eq!(&input1[9..11], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), Some(&b"abc"[..])); } #[tokio::test] @@ -1004,6 +1940,322 @@ mod tests { assert_eq!(body_reader.body_state, ParseState::Done(0)); } + #[tokio::test] + async fn read_with_body_trailers() { + init_log(); + let input1 = 
b"1\r\na\r\n2\r\nbc\r\n"; + let input2 = b"0\r\nabc: hi"; + let input3 = b"\r\ndef: bye\r"; + let input4 = b"\nghi: more\r\n"; + let input5 = b"\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .read(&input4[..]) + .read(&input5[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 13, 0)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(9, 2)); + assert_eq!(&input1[9..11], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(3, 0, 0, 0)); + // abc: hi + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(3, 0, 7, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + // NOTE: 0 chunk-size CRLF counted in trailer size too + ParseState::ChunkedFinal(3, 9, 0, 0) + ); + // def: bye + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(3, 19, 0, 1) + ); + // ghi: more + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(3, 30, 0, 2) + ); + + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(3)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn 
read_with_body_trailers_2() { + init_log(); + let input1 = b"1\r\na\r\n0\r"; + let input2 = b"\nabc: hi\r\n\r\n"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 8, 0)); + // 0 \r + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 8, 2)); + // \n TODO: optimize this call out + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(1, 0, 11, 2) + ); + // abc: hi with end in same read + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_trailers_3() { + init_log(); + let input1 = b"1\r\na\r\n0\r"; + let input2 = b"\nabc: hi"; + let input3 = b"\r\n\r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 8, 0)); + // 0 \r + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 8, 2)); + // \n TODO: optimize this 
call out + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 7, 2)); + // abc: hi + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + // NOTE: 0 chunk-size CRLF counted in trailer size too + ParseState::ChunkedFinal(1, 9, 0, 0) + ); + let res = body_reader.read_body(&mut mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_trailers_4() { + init_log(); + let input1 = b"1\r\na\r\n0\r"; + let input2 = b"\nabc: hi\r\n\r"; + let input3 = b"\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 8, 0)); + // 0 \r + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 8, 2)); + // \n TODO: optimize this call out + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(1, 0, 10, 2) + ); + // abc: hi + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + // NOTE: 0 chunk-size CRLF counted in trailer size too + ParseState::ChunkedFinal(1, 9, 0, 3) + ); + let res = body_reader.read_body(&mut 
mock_io).await.unwrap(); + assert_eq!(res, None); + assert_eq!(body_reader.body_state, ParseState::Complete(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_trailers_malformed() { + init_log(); + let input1 = b"1\r\na\r\n0\r"; + let input2 = b"\nabc: hi\rn"; + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 8, 0)); + // 0 \r + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 8, 2)); + // abc: hi to \rn + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 9, 2)); + // \rn not valid + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_TRAILER_END); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_trailers_malformed_2() { + init_log(); + let input1 = b"1\r\na\r\n0\r"; + let input2 = b"\nabc: hi\r\n"; + // no end + let mut mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 8, 0)); + // 0 \r + let res = body_reader.read_body(&mut 
mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 0, 8, 2)); + // abc: hi to \r\n + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 9, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 9, 0, 2)); + // EOF + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), ConnectionClosed); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_trailers_malformed_3() { + init_log(); + let input1 = b"1\r\na\r\n0\r\n"; + let input2 = b"abc: hi\r\n"; + let input3 = b"r\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 9, 0)); + // 0 \r\n + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 0, 2)); + // abc: hi + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 9, 0, 2)); + // r\n not valid + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_TRAILER_END); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + 
assert_eq!(body_reader.get_body_overread(), None); + } + + #[tokio::test] + async fn read_with_body_trailers_overflow() { + init_log(); + let input1 = b"1\r\na\r\n0\r\n"; + let input2 = b"abc: "; + let trailer1 = [b'a'; 1024 * 60]; + let trailer2 = [b'a'; 1024 * 5]; + let input3 = b"defghi: "; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&trailer1[..]) + .read(&CRLF[..]) + .read(&input3[..]) + .read(&trailer2[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(3, 1)); + assert_eq!(&input1[3..4], body_reader.get_body(&res)); + assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 9, 0)); + // 0 \r\n + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 0, 0, 2)); + // abc: + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::ChunkedFinal(1, 7, 0, 0)); + // aaa... 
+ let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(1, 1024 * 60 + 7, 0, 0) + ); + // CRLF + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(1, 1024 * 60 + 7, 0, 2) + ); + // defghi: + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!( + body_reader.body_state, + ParseState::ChunkedFinal(1, 1024 * 60 + 17, 0, 0) + ); + // overflow + let e = body_reader.read_body(&mut mock_io).await.unwrap_err(); + assert_eq!(*e.etype(), INVALID_TRAILER_END); + assert_eq!(body_reader.body_state, ParseState::Done(1)); + assert_eq!(body_reader.get_body_overread(), None); + } + #[tokio::test] async fn write_body_cl() { init_log(); @@ -1072,22 +2324,22 @@ mod tests { let data = b"a"; let mut mock_io = Builder::new().write(&data[..]).write(&data[..]).build(); let mut body_writer = BodyWriter::new(); - body_writer.init_http10(); - assert_eq!(body_writer.body_mode, BodyMode::HTTP1_0(0)); + body_writer.init_close_delimited(); + assert_eq!(body_writer.body_mode, BodyMode::UntilClose(0)); let res = body_writer .write_body(&mut mock_io, &data[..]) .await .unwrap() .unwrap(); assert_eq!(res, 1); - assert_eq!(body_writer.body_mode, BodyMode::HTTP1_0(1)); + assert_eq!(body_writer.body_mode, BodyMode::UntilClose(1)); let res = body_writer .write_body(&mut mock_io, &data[..]) .await .unwrap() .unwrap(); assert_eq!(res, 1); - assert_eq!(body_writer.body_mode, BodyMode::HTTP1_0(2)); + assert_eq!(body_writer.body_mode, BodyMode::UntilClose(2)); let res = body_writer.finish(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, 2); assert_eq!(body_writer.body_mode, BodyMode::Complete(2)); diff --git a/pingora-core/src/protocols/http/v1/client.rs 
b/pingora-core/src/protocols/http/v1/client.rs index 44440f1c..5f9e4610 100644 --- a/pingora-core/src/protocols/http/v1/client.rs +++ b/pingora-core/src/protocols/http/v1/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -53,7 +53,18 @@ pub struct HttpSession { response_header: Option>, request_written: Option>, bytes_sent: usize, + /// Total response body payload bytes received from upstream + body_recv: usize, + // Tracks whether upgrade handshake was successfully completed upgraded: bool, + // Tracks whether downstream request body started sending upgraded bytes + received_upgrade_req_body: bool, + // Tracks whether the response read was ever close-delimited + // (even after body complete) + close_delimited_resp: bool, + // If allowed, does not fail with error on invalid content-length + // (treats as close-delimited response). + allow_h1_response_invalid_content_length: bool, } /// HTTP 1.x client session @@ -81,9 +92,25 @@ impl HttpSession { write_timeout: None, digest, bytes_sent: 0, + body_recv: 0, upgraded: false, + received_upgrade_req_body: false, + close_delimited_resp: false, + allow_h1_response_invalid_content_length: false, } } + + /// Create a new http client session and apply peer options + pub fn new_with_options(stream: Stream, peer: &P) -> Self { + let mut session = Self::new(stream); + if let Some(options) = peer.get_peer_options() { + session.set_allow_h1_response_invalid_content_length( + options.allow_h1_response_invalid_content_length, + ); + } + session + } + /// Write the request header to the server /// After the request header is sent. The caller can either start reading the response or /// sending request body if any. 
@@ -147,7 +174,7 @@ impl HttpSession { } fn maybe_force_close_body_reader(&mut self) { - if self.upgraded && !self.body_reader.body_done() { + if self.upgraded && self.received_upgrade_req_body && !self.body_reader.body_done() { // request is done, reset the response body to close self.body_reader.init_content_length(0, b""); } @@ -177,6 +204,12 @@ impl HttpSession { // ad-hoc checks super::common::check_dup_content_length(&resp_header.headers)?; + // Validate content-length value if present + // Note: Content-Length is already removed if Transfer-Encoding is present + if !self.allow_h1_response_invalid_content_length { + self.get_content_length()?; + } + Ok(()) } @@ -184,7 +217,17 @@ impl HttpSession { /// This function can be called multiple times, if the headers received are just informational /// headers. pub async fn read_response(&mut self) -> Result { - self.buf.clear(); + if self.preread_body.as_ref().is_none_or(|b| b.is_empty()) { + // preread_body is set after a completed valid response header is read + // if called multiple times (i.e. after informational responses), + // we want to parse the already read buffer bytes as more headers. + // (https://datatracker.ietf.org/doc/html/rfc9110#section-15.2 + // "A 1xx response is terminated by the end of the header section; + // it cannot contain content or trailers.") + // If this next read_response call completes successfully, + // self.buf will be reset to the last response + any body. 
+ self.buf.clear(); + } let mut buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); let mut already_read: usize = 0; loop { @@ -198,12 +241,18 @@ impl HttpSession { ); } - let read_fut = self.underlying_stream.read_buf(&mut buf); - let read_result = match self.read_timeout { - Some(t) => timeout(t, read_fut) - .await - .map_err(|_| Error::explain(ReadTimedout, "while reading response headers"))?, - None => read_fut.await, + let preread = self.preread_body.take(); + let read_result = if let Some(preread) = preread.filter(|b| !b.is_empty()) { + buf.put_slice(preread.get(&self.buf)); + Ok(preread.len()) + } else { + let read_fut = self.underlying_stream.read_buf(&mut buf); + match self.read_timeout { + Some(t) => timeout(t, read_fut).await.map_err(|_| { + Error::explain(ReadTimedout, "while reading response headers") + })?, + None => read_fut.await, + } }; let n = match read_result { Ok(n) => match n { @@ -257,6 +306,9 @@ impl HttpSession { Some(resp.headers.len()), )?); + // TODO: enforce https://datatracker.ietf.org/doc/html/rfc9110#section-15.2 + // "Since HTTP/1.0 did not define any 1xx status codes, + // a server MUST NOT send a 1xx response to an HTTP/1.0 client." response_header.set_version(match resp.version { Some(1) => Version::HTTP_11, Some(0) => Version::HTTP_10, @@ -298,17 +350,41 @@ impl HttpSession { .or_err(InvalidHTTPHeader, "while parsing request header")?; } + let contains_transfer_encoding = response_header + .headers + .contains_key(header::TRANSFER_ENCODING); + let contains_content_length = + response_header.headers.contains_key(header::CONTENT_LENGTH); + + // Transfer encoding overrides content length, so when + // both are present, we MUST remove content length. 
This is + // https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.3 + if contains_content_length && contains_transfer_encoding { + response_header.remove_header(&header::CONTENT_LENGTH); + } + self.buf = buf; - self.upgraded = self.is_upgrade(&response_header).unwrap_or(false); self.response_header = Some(response_header); self.validate_response()?; + // convert to upgrade body type + // https://datatracker.ietf.org/doc/html/rfc9110#status.101 + // as an "informational" header, this cannot have a body + self.upgraded = self + .is_upgrade(self.response_header.as_deref().expect("init above")) + .unwrap_or(false); + // init body reader if upgrade status has changed body mode + // (read_response_task will immediately try to init body afterwards anyways) + // informational headers will automatically avoid initializing body reader + self.init_body_reader(); + // note that the (request) body writer is converted to close delimit + // when the upgraded body tasks are received return Ok(s); } HeaderParseState::Partial => { /* continue the loop */ } HeaderParseState::Invalid(e) => { return Error::e_because( InvalidHTTPHeader, - format!("buf: {}", String::from_utf8_lossy(&buf).escape_default()), + format!("buf: {}", buf.escape_ascii()), e, ); } @@ -367,7 +443,13 @@ impl HttpSession { None => self.do_read_body().await, }; - result.map(|maybe_body| maybe_body.map(|body_ref| self.body_reader.get_body(&body_ref))) + result.map(|maybe_body| { + maybe_body.map(|body_ref| { + let slice = self.body_reader.get_body(&body_ref); + self.body_recv = self.body_recv.saturating_add(slice.len()); + slice + }) + }) } /// Similar to [`Self::read_body_ref`] but return `Bytes` instead of a slice reference. @@ -376,12 +458,21 @@ impl HttpSession { Ok(read.map(Bytes::copy_from_slice)) } + /// Upstream response body bytes received (payload only; excludes headers/framing). + pub fn body_bytes_received(&self) -> usize { + self.body_recv + } + /// Whether there is no more body to read. 
pub fn is_body_done(&mut self) -> bool { self.init_body_reader(); self.body_reader.body_done() } + pub fn set_allow_h1_response_invalid_content_length(&mut self, allow: bool) { + self.allow_h1_response_invalid_content_length = allow; + } + pub(super) fn get_headers_raw(&self) -> &[u8] { // TODO: these get_*() could panic. handle them better self.raw_header.as_ref().unwrap().get(&self.buf[..]) @@ -411,17 +502,36 @@ impl HttpSession { /// For HTTP 1.1, assume keepalive as long as there is no `Connection: Close` request header. /// For HTTP 1.0, only keepalive if there is an explicit header `Connection: keep-alive`. pub fn respect_keepalive(&mut self) { - if self.get_status() == Some(StatusCode::SWITCHING_PROTOCOLS) { + if self.upgraded || self.get_status() == Some(StatusCode::SWITCHING_PROTOCOLS) { // make sure the connection is closed at the end when 101/upgrade is used self.set_keepalive(None); return; } + if self.body_reader.need_init() || self.close_delimited_resp { + // Defense-in-depth: response body close-delimited (or no body interpretation + // upon reuse check) + // explicitly disable reuse + self.set_keepalive(None); + return; + } if self.body_reader.has_bytes_overread() { // if more bytes sent than expected, there are likely more bytes coming // so don't reuse this connection self.set_keepalive(None); return; } + + // Per [RFC 9112 Section 6.1-16](https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-16), + // if Transfer-Encoding is received in HTTP/1.0 response, connection MUST be closed after processing. 
+ if self.resp_header().map(|h| h.version) == Some(Version::HTTP_10) + && self + .resp_header() + .and_then(|h| h.headers.get(header::TRANSFER_ENCODING)) + .is_some() + { + self.set_keepalive(None); + return; + } if let Some(keepalive) = self.is_connection_keepalive() { if keepalive { let (timeout, _max_use) = self.get_keepalive_values(); @@ -442,8 +552,6 @@ impl HttpSession { // Whether this session will be kept alive pub fn will_keepalive(&self) -> bool { - // TODO: check self.body_writer. If it is http1.0 type then keepalive - // cannot be used because the connection close is the signal of end body !matches!(self.keepalive_timeout, KeepaliveStatus::Off) } @@ -514,7 +622,7 @@ impl HttpSession { fn init_body_reader(&mut self) { if self.body_reader.need_init() { - /* follow https://tools.ietf.org/html/rfc7230#section-3.3.3 */ + // follow https://datatracker.ietf.org/doc/html/rfc9112#section-6.3 let preread_body = self.preread_body.as_ref().unwrap().get(&self.buf[..]); if let Some(req) = self.request_written.as_ref() { @@ -543,14 +651,16 @@ impl HttpSession { }; if upgraded { - self.body_reader.init_http10(preread_body); + self.body_reader.init_close_delimited(preread_body); + self.close_delimited_resp = true; } else if self.is_chunked_encoding() { // if chunked encoding, content-length should be ignored self.body_reader.init_chunked(preread_body); - } else if let Some(cl) = self.get_content_length() { + } else if let Some(cl) = self.get_content_length().unwrap_or(None) { self.body_reader.init_content_length(cl, preread_body); } else { - self.body_reader.init_http10(preread_body); + self.body_reader.init_close_delimited(preread_body); + self.close_delimited_resp = true; } } } @@ -574,7 +684,24 @@ impl HttpSession { } } - fn get_content_length(&self) -> Option { + /// Was this request successfully turned into an upgraded connection? + /// + /// Both the request had to have been an `Upgrade` request + /// and the response had to have been a `101 Switching Protocols`. 
+ pub fn was_upgraded(&self) -> bool { + self.upgraded + } + + /// If upgraded but not yet converted, then body writer will be + /// converted to http1.0 mode (pass through bytes as-is). + pub fn maybe_upgrade_body_writer(&mut self) { + if self.was_upgraded() { + self.received_upgrade_req_body = true; + self.body_writer.convert_to_close_delimited(); + } + } + + fn get_content_length(&self) -> Result> { buf_to_content_length( self.get_header(header::CONTENT_LENGTH) .map(|v| v.as_bytes()), @@ -582,20 +709,17 @@ impl HttpSession { } fn is_chunked_encoding(&self) -> bool { - is_header_value_chunked_encoding(self.get_header(header::TRANSFER_ENCODING)) + self.resp_header() + .map(|h| is_chunked_encoding_from_headers(&h.headers)) + .unwrap_or(false) } fn init_req_body_writer(&mut self, header: &RequestHeader) { - if is_upgrade_req(header) { - self.body_writer.init_http10(); - } else { - self.init_body_writer_comm(&header.headers) - } + self.init_body_writer_comm(&header.headers) } fn init_body_writer_comm(&mut self, headers: &HMap) { - let te_value = headers.get(http::header::TRANSFER_ENCODING); - if is_header_value_chunked_encoding(te_value) { + if is_chunked_encoding_from_headers(headers) { // transfer-encoding takes priority over content-length self.body_writer.init_chunked(); } else { @@ -606,9 +730,11 @@ impl HttpSession { self.body_writer.init_content_length(length); } None => { - /* TODO: 1. connection: keepalive cannot be used, - 2. mark connection must be closed */ - self.body_writer.init_http10(); + // Per RFC 9112: "Request messages are never close-delimited because they are + // always explicitly framed by length or transfer coding, with the absence of + // both implying the request ends immediately after the header section." 
+ // Requests without Content-Length or Transfer-Encoding have 0 body + self.body_writer.init_content_length(0); } } } @@ -646,8 +772,12 @@ impl HttpSession { "Response body: {} bytes, end: {end_of_body}", body.as_ref().map_or(0, |b| b.len()) ); - trace!("Response body: {body:?}"); - Ok(HttpTask::Body(body, end_of_body)) + trace!("Response body: {body:?}, upgraded: {}", self.upgraded); + if self.upgraded { + Ok(HttpTask::UpgradedBody(body, end_of_body)) + } else { + Ok(HttpTask::Body(body, end_of_body)) + } } // TODO: support h1 trailer } @@ -716,7 +846,7 @@ fn parse_resp_buffer<'buf>( // TODO: change it to to_buf #[inline] -pub(crate) fn http_req_header_to_wire(req: &RequestHeader) -> Option { +pub fn http_req_header_to_wire(req: &RequestHeader) -> Option { let mut buf = BytesMut::with_capacity(512); // Request-Line @@ -753,8 +883,10 @@ impl UniqueID for HttpSession { #[cfg(test)] mod tests_stream { use super::*; - use crate::protocols::http::v1::body::ParseState; + use crate::protocols::http::v1::body::{BodyMode, ParseState}; + use crate::upstreams::peer::PeerOptions; use crate::ErrorType; + use rstest::rstest; use tokio_test::io::Builder; fn init_log() { @@ -802,12 +934,161 @@ mod tests_stream { assert_eq!(input_header.len(), res.unwrap()); let res = http_stream.read_body_ref().await.unwrap(); assert_eq!(res.unwrap(), input_body); - assert_eq!(http_stream.body_reader.body_state, ParseState::HTTP1_0(3)); + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(3) + ); let res = http_stream.read_body_ref().await.unwrap(); assert_eq!(res, None); assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3)); } + #[tokio::test] + async fn body_bytes_received_content_length() { + init_log(); + let input_header = b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n"; + let input_body = b"abc"; + let input_close = b""; // simulating close + let mock_io = Builder::new() + .read(&input_header[..]) + .read(&input_body[..]) + .read(&input_close[..]) 
+ .build(); + let mut http = HttpSession::new(Box::new(mock_io)); + http.read_response().await.unwrap(); + let _ = http.read_body_ref().await.unwrap(); + let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 3); + } + + #[tokio::test] + async fn body_bytes_received_chunked() { + init_log(); + let input_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + let input_body = b"3\r\nabc\r\n0\r\n\r\n"; + let mock_io = Builder::new() + .read(&input_header[..]) + .read(&input_body[..]) + .build(); + let mut http = HttpSession::new(Box::new(mock_io)); + http.read_response().await.unwrap(); + // first read returns the payload chunk + let first = http.read_body_ref().await.unwrap(); + assert_eq!(first.unwrap(), b"abc"); + // next read consumes terminating chunk + let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 3); + } + + #[tokio::test] + async fn h1_body_bytes_received_http10_until_close() { + init_log(); + let header = b"HTTP/1.1 200 OK\r\n\r\n"; + let body = b"abc"; + let close = b""; + let mock = Builder::new() + .read(&header[..]) + .read(&body[..]) + .read(&close[..]) + .build(); + let mut http = HttpSession::new(Box::new(mock)); + http.read_response().await.unwrap(); + let _ = http.read_body_ref().await.unwrap(); + let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 3); + } + + #[tokio::test] + async fn h1_body_bytes_received_chunked_multi() { + init_log(); + let header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + let body = b"1\r\na\r\n2\r\nbc\r\n0\r\n\r\n"; // payload abc + let mock = Builder::new().read(&header[..]).read(&body[..]).build(); + let mut http = HttpSession::new(Box::new(mock)); + http.read_response().await.unwrap(); + // first chunk + let s1 = http.read_body_ref().await.unwrap().unwrap(); + assert_eq!(s1, b"a"); + // second chunk + let s2 = http.read_body_ref().await.unwrap().unwrap(); + assert_eq!(s2, b"bc"); + // end 
+ let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 3); + } + + #[tokio::test] + async fn h1_body_bytes_received_preread_in_header_buf() { + init_log(); + // header and a small body arrive together + let combined = b"HTTP/1.1 200 OK\r\n\r\nabc"; + let close = b""; + let mock = Builder::new().read(&combined[..]).read(&close[..]).build(); + let mut http = HttpSession::new(Box::new(mock)); + http.read_response().await.unwrap(); + // first body read should return the preread bytes + let s = http.read_body_ref().await.unwrap().unwrap(); + assert_eq!(s, b"abc"); + // then EOF + let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 3); + } + + #[tokio::test] + async fn h1_body_bytes_received_overread_content_length() { + init_log(); + let header1 = b"HTTP/1.1 200 OK\r\n"; + let header2 = b"Content-Length: 2\r\n\r\n"; + let body = b"abc"; // one extra byte beyond CL + let mock = Builder::new() + .read(&header1[..]) + .read(&header2[..]) + .read(&body[..]) + .build(); + let mut http = HttpSession::new(Box::new(mock)); + http.read_response().await.unwrap(); + let s = http.read_body_ref().await.unwrap().unwrap(); + assert_eq!(s, b"ab"); + // then end + let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 2); + } + + #[tokio::test] + async fn h1_body_bytes_received_after_100_continue() { + init_log(); + let info = b"HTTP/1.1 100 Continue\r\n\r\n"; + let header = b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\n"; + let body = b"x"; + let mock = Builder::new() + .read(&info[..]) + .read(&header[..]) + .read(&body[..]) + .build(); + let mut http = HttpSession::new(Box::new(mock)); + // read informational + match http.read_response_task().await.unwrap() { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob); + } + _ => panic!("expected informational header"), + } + // read final header + match http.read_response_task().await.unwrap() { + HttpTask::Header(h, 
eob) => { + assert_eq!(h.status, 200); + assert!(!eob); + } + _ => panic!("expected final header"), + } + // read body + let s = http.read_body_ref().await.unwrap().unwrap(); + assert_eq!(s, b"x"); + let _ = http.read_body_ref().await.unwrap(); + assert_eq!(http.body_bytes_received(), 1); + } + #[tokio::test] async fn read_response_overread() { init_log(); @@ -917,6 +1198,138 @@ mod tests_stream { assert_eq!(wire.len(), n); } + #[rstest] + #[case::negative("-1")] + #[case::not_a_number("abc")] + #[case::float("1.5")] + #[case::empty("")] + #[case::spaces(" ")] + #[case::mixed("123abc")] + #[tokio::test] + async fn validate_response_rejects_invalid_content_length(#[case] invalid_value: &str) { + init_log(); + let input = format!( + "HTTP/1.1 200 OK\r\nServer: test\r\nContent-Length: {}\r\n\r\n", + invalid_value + ); + let mock_io = Builder::new().read(input.as_bytes()).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + // read_response calls validate_response internally, so it should fail here + let res = http_stream.read_response().await; + assert!(res.is_err()); + assert_eq!(res.unwrap_err().etype(), &ErrorType::InvalidHTTPHeader); + } + + #[tokio::test] + async fn allow_invalid_content_length_close_delimited_when_configured() { + init_log(); + let input_header = b"HTTP/1.1 200 OK\r\nServer: test\r\nContent-Length: abc\r\n\r\n"; + let input_body = b"abc"; + let input_close = b""; + let mock_io = Builder::new() + .read(&input_header[..]) + .read(&input_body[..]) + .read(&input_close[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let mut peer_options = PeerOptions::new(); + peer_options.allow_h1_response_invalid_content_length = true; + http_stream.set_allow_h1_response_invalid_content_length( + peer_options.allow_h1_response_invalid_content_length, + ); + + let res = http_stream.read_response().await; + assert!(res.is_ok()); + let body = http_stream.read_body_ref().await.unwrap().unwrap(); + assert_eq!(body, 
input_body); + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(3) + ); + let body = http_stream.read_body_ref().await.unwrap(); + assert!(body.is_none()); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3)); + } + + #[rstest] + #[case::valid_zero("0")] + #[case::valid_small("123")] + #[case::valid_large("999999")] + #[tokio::test] + async fn validate_response_accepts_valid_content_length(#[case] valid_value: &str) { + init_log(); + let input = format!( + "HTTP/1.1 200 OK\r\nServer: test\r\nContent-Length: {}\r\n\r\n", + valid_value + ); + let mock_io = Builder::new().read(input.as_bytes()).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let res = http_stream.read_response().await; + assert!(res.is_ok()); + } + + #[tokio::test] + async fn validate_response_accepts_no_content_length() { + init_log(); + let input = b"HTTP/1.1 200 OK\r\nServer: test\r\n\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let res = http_stream.read_response().await; + assert!(res.is_ok()); + } + + #[rstest] + #[case(None, None, None)] + #[case(Some("transfer-encoding"), None, None)] + #[case(Some("transfer-encoding"), Some("CONTENT-LENGTH"), Some("4"))] + #[case(Some("TRANSFER-ENCODING"), Some("CONTENT-LENGTH"), Some("4"))] + #[case(Some("TRANSFER-ENCODING"), None, None)] + #[case(None, Some("CONTENT-LENGTH"), Some("4"))] + #[case(Some("TRANSFER-ENCODING"), Some("content-length"), Some("4"))] + #[case(None, Some("content-length"), Some("4"))] + #[case(Some("TRANSFER-ENCODING"), Some("CONTENT-LENGTH"), Some("abc"))] + #[tokio::test] + async fn response_transfer_encoding_and_content_length_handling( + #[case] transfer_encoding_header: Option<&str>, + #[case] content_length_header: Option<&str>, + #[case] content_length_value: Option<&str>, + ) { + init_log(); + let input1 = b"HTTP/1.1 200 OK\r\n"; + let mut input2 = "Server: 
test\r\n".to_owned(); + + if let Some(transfer_encoding) = transfer_encoding_header { + input2 += &format!("{transfer_encoding}: chunked\r\n"); + } + if let Some(content_length) = content_length_header { + let value = content_length_value.unwrap_or("4"); + input2 += &format!("{content_length}: {value}\r\n") + } + + input2 += "\r\n"; + let mock_io = Builder::new() + .read(&input1[..]) + .read(input2.as_bytes()) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let _ = http_stream.read_response().await.unwrap(); + + match (content_length_header, transfer_encoding_header) { + (Some(_) | None, Some(_)) => { + assert!(http_stream.get_header(header::TRANSFER_ENCODING).is_some()); + assert!(http_stream.get_header(header::CONTENT_LENGTH).is_none()); + } + (Some(_), None) => { + assert!(http_stream.get_header(header::TRANSFER_ENCODING).is_none()); + assert!(http_stream.get_header(header::CONTENT_LENGTH).is_some()); + } + _ => { + assert!(http_stream.get_header(header::CONTENT_LENGTH).is_none()); + assert!(http_stream.get_header(header::TRANSFER_ENCODING).is_none()); + } + } + } + #[tokio::test] #[should_panic(expected = "There is still data left to write.")] async fn write_timeout() { @@ -938,7 +1351,8 @@ mod tests_stream { #[tokio::test] #[should_panic(expected = "There is still data left to write.")] async fn write_body_timeout() { - let header = b"POST /test HTTP/1.1\r\n\r\n"; + // Test needs Content-Length header to actually attempt to write body + let header = b"POST /test HTTP/1.1\r\nContent-Length: 3\r\n\r\n"; let body = b"abc"; let mock_io = Builder::new() .write(&header[..]) @@ -948,7 +1362,8 @@ mod tests_stream { let mut http_stream = HttpSession::new(Box::new(mock_io)); http_stream.write_timeout = Some(Duration::from_secs(1)); - let new_request = RequestHeader::build("POST", b"/test", None).unwrap(); + let mut new_request = RequestHeader::build("POST", b"/test", None).unwrap(); + new_request.insert_header("Content-Length", "3").unwrap(); 
http_stream .write_request_header(Box::new(new_request)) .await @@ -1004,33 +1419,395 @@ mod tests_stream { } } + #[tokio::test] + async fn read_informational_combined_with_final() { + init_log(); + let input = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nServer: pingora\r\nContent-Length: 3\r\n\r\n"; + let body = b"abc"; + let mock_io = Builder::new().read(&input[..]).read(&body[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + // read 100 header first + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob); + } + _ => { + panic!("task should be header") + } + } + // read 200 header next + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(!eob); + } + _ => { + panic!("task should be header") + } + } + // read body next + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Body(b, eob) => { + assert_eq!(b.unwrap(), &body[..]); + assert!(eob); + } + _ => { + panic!("task {task:?} should be body") + } + } + } + + #[tokio::test] + async fn read_informational_multiple_combined_with_final() { + init_log(); + let input = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 103 Early Hints\r\n\r\nHTTP/1.1 204 No Content\r\nServer: pingora\r\n\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + // read 100 header first + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob); + } + _ => { + panic!("task should be header") + } + } + + // then read 103 header + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 103); + assert!(!eob); + } + _ => { + panic!("task should be header") + 
} + } + + // finally read 200 header + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 204); + assert!(eob); + } + _ => { + panic!("task should be header") + } + } + } + + #[tokio::test] + async fn read_informational_then_keepalive_response() { + init_log(); + // Test that after reading an informational response (100 Continue), + // keepalive still works properly + let wire = b"GET / HTTP/1.1\r\n\r\n"; + let input1 = b"HTTP/1.1 100 Continue\r\n\r\n"; + let input2 = b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\n"; // Proper Content-Length + let body = b"response body"; + + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input1[..]) + .read(&input2[..]) + .read(&body[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + // Write request + let new_request = RequestHeader::build("GET", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // Read 100 Continue + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 100); + assert!(!eob); + } + _ => { + panic!("task should be informational header") + } + } + + // Read final 200 OK header + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 200); + assert!(!eob); // Should not be end of body yet + } + _ => { + panic!("task should be final header") + } + } + + // Read body + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Body(b, eob) => { + assert_eq!(b.unwrap(), &body[..]); + assert!(eob); // EOF - body is complete + } + _ => { + panic!("task {task:?} should be body") + } + } + + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(13)); + + // Keepalive should be enabled for properly-framed HTTP/1.1 + http_stream.respect_keepalive(); + 
assert!(http_stream.will_keepalive()); + } + #[tokio::test] async fn init_body_for_upgraded_req() { - use crate::protocols::http::v1::body::BodyMode; + let wire = + b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 0\r\n\r\n"; + let input1 = b"HTTP/1.1 101 Switching Protocols\r\n\r\n"; + let input2 = b"PAYLOAD"; + let ws_data = b"data"; + + let mock_io = Builder::new() + .write(wire) + .read(&input1[..]) + .write(&ws_data[..]) + .read(&input2[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let mut new_request = RequestHeader::build("GET", b"/", None).unwrap(); + new_request.insert_header("Connection", "Upgrade").unwrap(); + new_request.insert_header("Upgrade", "WS").unwrap(); + new_request.insert_header("Content-Length", "0").unwrap(); + let _ = http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(0, 0) + ); + assert!(http_stream.body_writer.finished()); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 101); + assert!(!eob); + } + _ => { + panic!("task should be header") + } + } + // changed body mode + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(0) + ); + // request writer will be explicitly initialized in a separate call + assert!(http_stream.body_writer.finished()); + http_stream.maybe_upgrade_body_writer(); + assert!(!http_stream.body_writer.finished()); + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); + + http_stream.write_body(&ws_data[..]).await.unwrap(); + // read WS + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::UpgradedBody(b, eob) => { + assert_eq!(b.unwrap(), &input2[..]); + assert!(!eob); + } + _ => { + panic!("task should be upgraded body") + } + } + } + + #[tokio::test] + async fn 
init_preread_body_for_upgraded_req() { let wire = b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 0\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let input = b"HTTP/1.1 101 Switching Protocols\r\n\r\nPAYLOAD"; + let ws_data = b"data"; + + let mock_io = Builder::new() + .write(wire) + .read(&input[..]) + .write(&ws_data[..]) + .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); let mut new_request = RequestHeader::build("GET", b"/", None).unwrap(); new_request.insert_header("Connection", "Upgrade").unwrap(); new_request.insert_header("Upgrade", "WS").unwrap(); - // CL is ignored when Upgrade presents new_request.insert_header("Content-Length", "0").unwrap(); let _ = http_stream .write_request_header(Box::new(new_request)) .await .unwrap(); - assert_eq!(http_stream.body_writer.body_mode, BodyMode::HTTP1_0(0)); + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(0, 0) + ); + assert!(http_stream.body_writer.finished()); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 101); + assert!(!eob); + } + _ => { + panic!("task should be header") + } + } + // changed body mode + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(0) + ); + // request writer will be explicitly initialized in a separate call + assert!(http_stream.body_writer.finished()); + http_stream.maybe_upgrade_body_writer(); + + assert!(!http_stream.body_writer.finished()); + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); + + http_stream.write_body(&ws_data[..]).await.unwrap(); + // read WS + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::UpgradedBody(b, eob) => { + assert_eq!(b.unwrap(), &b"PAYLOAD"[..]); + assert!(!eob); + } + _ => { + panic!("task should be upgraded body") + } + } + } + + #[tokio::test] + async fn read_body_eos_after_upgrade() { + let 
wire = + b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 10\r\n\r\n"; + let input1 = b"HTTP/1.1 101 Switching Protocols\r\n\r\n"; + let input2 = b"PAYLOAD"; + let body_data = b"0123456789"; + let ws_data = b"data"; + + let mock_io = Builder::new() + .write(wire) + .read(&input1[..]) + .write(&body_data[..]) + .read(&input2[..]) + .write(&ws_data[..]) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let mut new_request = RequestHeader::build("GET", b"/", None).unwrap(); + new_request.insert_header("Connection", "Upgrade").unwrap(); + new_request.insert_header("Upgrade", "WS").unwrap(); + new_request.insert_header("Content-Length", "10").unwrap(); + let _ = http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(10, 0) + ); + assert!(!http_stream.body_writer.finished()); + + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::Header(h, eob) => { + assert_eq!(h.status, 101); + assert!(!eob); + } + _ => { + panic!("task should be header") + } + } + // changed body mode + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(0) + ); + + // write regular request payload + http_stream.write_body(&body_data[..]).await.unwrap(); + http_stream.finish_body().await.unwrap(); + + // we should still be able to read more response body + // read WS + let task = http_stream.read_response_task().await.unwrap(); + match task { + HttpTask::UpgradedBody(b, eob) => { + assert_eq!(b.unwrap(), &input2[..]); + assert!(!eob); + } + t => { + panic!("task {t:?} should be upgraded body") + } + } + + // body IS finished, prior to upgrade on the downstream side + assert!(http_stream.body_writer.finished()); + http_stream.maybe_upgrade_body_writer(); + + assert!(!http_stream.body_writer.finished()); + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); + + 
http_stream.write_body(&ws_data[..]).await.unwrap(); + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(4)); + http_stream.finish_body().await.unwrap(); } #[tokio::test] async fn read_switching_protocol() { init_log(); + + let wire = + b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 0\r\n\r\n"; let input1 = b"HTTP/1.1 101 Continue\r\n\r\n"; let input2 = b"PAYLOAD"; - let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); + + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input1[..]) + .read(&input2[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let mut new_request = RequestHeader::build("GET", b"/", None).unwrap(); + new_request.insert_header("Connection", "Upgrade").unwrap(); + new_request.insert_header("Upgrade", "WS").unwrap(); + new_request.insert_header("Content-Length", "0").unwrap(); + let _ = http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(0, 0) + ); + assert!(http_stream.body_writer.finished()); // read 100 header first let task = http_stream.read_response_task().await.unwrap(); @@ -1046,18 +1823,18 @@ mod tests_stream { // read body let task = http_stream.read_response_task().await.unwrap(); match task { - HttpTask::Body(b, eob) => { + HttpTask::UpgradedBody(b, eob) => { assert_eq!(b.unwrap(), &input2[..]); assert!(!eob); } _ => { - panic!("task should be body") + panic!("task should be upgraded body") } } // read body let task = http_stream.read_response_task().await.unwrap(); match task { - HttpTask::Body(b, eob) => { + HttpTask::UpgradedBody(b, eob) => { assert!(b.is_none()); assert!(eob); } @@ -1115,7 +1892,9 @@ mod tests_stream { init_log(); async fn build_resp_with_keepalive(conn: &str) -> HttpSession { - let input = format!("HTTP/1.1 200 OK\r\nConnection: {conn}\r\n\r\n"); + // Include Content-Length to avoid triggering 
defense-in-depth close-delimited check + let input = + format!("HTTP/1.1 200 OK\r\nConnection: {conn}\r\nContent-Length: 0\r\n\r\n"); let mock_io = Builder::new().read(input.as_bytes()).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); let res = http_stream.read_response().await; @@ -1258,6 +2037,193 @@ mod tests_stream { } /* Note: body tests are covered in server.rs */ + + #[tokio::test] + async fn test_http10_response_with_transfer_encoding_disables_keepalive() { + // Transfer-Encoding in HTTP/1.0 response requires connection close + let input = b"HTTP/1.0 200 OK\r\n\ +Transfer-Encoding: chunked\r\n\ +Connection: keep-alive\r\n\ +\r\n\ +5\r\n\ +hello\r\n\ +0\r\n\ +\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_response().await.unwrap(); + http_stream.respect_keepalive(); + + // Keepalive must be disabled even if Connection: keep-alive header present + assert!(!http_stream.will_keepalive()); + assert_eq!(http_stream.keepalive_timeout, KeepaliveStatus::Off); + } + + #[tokio::test] + async fn test_http11_response_with_transfer_encoding_allows_keepalive() { + // HTTP/1.1 with Transfer-Encoding should allow keepalive (contrast with HTTP/1.0) + let input = b"HTTP/1.1 200 OK\r\n\ +Transfer-Encoding: chunked\r\n\ +\r\n\ +5\r\n\ +hello\r\n\ +0\r\n\ +\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_response().await.unwrap(); + http_stream.respect_keepalive(); + + // HTTP/1.1 should allow keepalive by default + assert!(http_stream.will_keepalive()); + } + + #[tokio::test] + async fn test_response_multiple_transfer_encoding_headers() { + init_log(); + // Multiple TE headers should be treated as comma-separated + let input = b"HTTP/1.1 200 OK\r\n\ +Transfer-Encoding: gzip\r\n\ +Transfer-Encoding: chunked\r\n\ +\r\n\ +5\r\n\ +hello\r\n\ +0\r\n\ +\r\n"; + + let mock_io = 
Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_response().await.unwrap(); + + // Should correctly identify chunked encoding from last header + assert!(http_stream.is_chunked_encoding()); + + // Verify body can be read correctly + let body = http_stream.read_body_bytes().await.unwrap(); + assert_eq!(body.as_ref().unwrap().as_ref(), b"hello"); + http_stream.finish_body().await.unwrap(); + } + + #[tokio::test] + async fn test_response_multiple_te_headers_chunked_not_last() { + init_log(); + // Chunked in first header but not last - should NOT be chunked + let input = b"HTTP/1.1 200 OK\r\n\ +Transfer-Encoding: chunked\r\n\ +Transfer-Encoding: identity\r\n\ +Content-Length: 5\r\n\ +\r\n\ +hello"; + + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_response().await.unwrap(); + + // Should NOT be chunked - identity is final encoding + assert!(!http_stream.is_chunked_encoding()); + } + + #[test] + fn test_is_chunked_encoding_before_response() { + // Test that is_chunked_encoding returns false when no response received yet + let mock_io = Builder::new().build(); + let http_stream = HttpSession::new(Box::new(mock_io)); + + // Should return false when no response header exists yet + assert!(!http_stream.is_chunked_encoding()); + } + + #[tokio::test] + async fn write_request_body_implicit_zero_content_length() { + init_log(); + let header = b"POST /test HTTP/1.1\r\n\r\n"; + let mock_io = Builder::new().write(&header[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let new_request = RequestHeader::build("POST", b"/test", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(0, 0) + ); + } + + #[tokio::test] + async fn write_request_body_with_content_length() { + 
init_log(); + let header = b"POST /test HTTP/1.1\r\nContent-Length: 3\r\n\r\n"; + let body = b"abc"; + let mock_io = Builder::new().write(&header[..]).write(&body[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + let mut new_request = RequestHeader::build("POST", b"/test", None).unwrap(); + new_request.insert_header("Content-Length", "3").unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(3, 0) + ); + + http_stream.write_body(body).await.unwrap(); + assert_eq!( + http_stream.body_writer.body_mode, + BodyMode::ContentLength(3, 3) + ); + } + + #[tokio::test] + async fn close_delimited_response_explicitly_disables_keepalive() { + init_log(); + // Defense-in-depth: if we read a close-delimited response body, + // keepalive should be disabled + let wire = b"GET / HTTP/1.1\r\n\r\n"; + let input_header = b"HTTP/1.1 200 OK\r\n\r\n"; + let input_body = b"abc"; + let input_close = b""; // simulating close + let mock_io = Builder::new() + .write(&wire[..]) + .read(&input_header[..]) + .read(&input_body[..]) + .read(&input_close[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + + // Write request first + let new_request = RequestHeader::build("GET", b"/", None).unwrap(); + http_stream + .write_request_header(Box::new(new_request)) + .await + .unwrap(); + + // Read response + http_stream.read_response().await.unwrap(); + + // Read the body (this will initialize the body reader) + http_stream.read_body_ref().await.unwrap(); + + // Body reader should be in UntilClose mode (close-delimited response) + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(3) + ); + + let res2 = http_stream.read_body_ref().await.unwrap(); + assert!(res2.is_none()); // EOF + + // Body should now be Complete + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3)); + + 
http_stream.respect_keepalive(); + assert!(!http_stream.will_keepalive()); + } } #[cfg(test)] diff --git a/pingora-core/src/protocols/http/v1/common.rs b/pingora-core/src/protocols/http/v1/common.rs index d4b3e6e6..93f4524c 100644 --- a/pingora-core/src/protocols/http/v1/common.rs +++ b/pingora-core/src/protocols/http/v1/common.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ use http::{header, HeaderValue}; use log::warn; -use pingora_error::Result; +use pingora_error::{Error, ErrorType::*, Result}; use pingora_http::{HMap, RequestHeader, ResponseHeader}; use std::str; use std::time::Duration; @@ -121,8 +121,7 @@ fn parse_connection_header(value: &[u8]) -> ConnectionValue { } pub(crate) fn init_body_writer_comm(body_writer: &mut BodyWriter, headers: &HMap) { - let te_value = headers.get(http::header::TRANSFER_ENCODING); - if is_header_value_chunked_encoding(te_value) { + if is_chunked_encoding_from_headers(headers) { // transfer-encoding takes priority over content-length body_writer.init_chunked(); } else { @@ -134,18 +133,46 @@ pub(crate) fn init_body_writer_comm(body_writer: &mut BodyWriter, headers: &HMap None => { /* TODO: 1. connection: keepalive cannot be used, 2. mark connection must be closed */ - body_writer.init_http10(); + body_writer.init_close_delimited(); } } } } +/// Find the last comma-separated token in a Transfer-Encoding header value. +/// Takes the literal last token after the last comma, even if empty. 
#[inline] -pub fn is_header_value_chunked_encoding(header_value: Option<&http::header::HeaderValue>) -> bool { - match header_value { - Some(value) => value.as_bytes().eq_ignore_ascii_case(b"chunked"), - None => false, +fn find_last_te_token(bytes: &[u8]) -> &[u8] { + let last_token = bytes + .iter() + .rposition(|&b| b == b',') + .map(|pos| &bytes[pos + 1..]) + .unwrap_or(bytes); + + last_token.trim_ascii() +} + +/// Check if chunked encoding is the final encoding across all transfer-encoding headers +pub(crate) fn is_chunked_encoding_from_headers(headers: &HMap) -> bool { + // Get the last Transfer-Encoding header value + let last_te = headers + .get_all(http::header::TRANSFER_ENCODING) + .into_iter() + .next_back(); + + let Some(last_header_value) = last_te else { + return false; + }; + + let bytes = last_header_value.as_bytes(); + + // Fast path: exact match for "chunked" + if bytes.eq_ignore_ascii_case(b"chunked") { + return true; } + + // Slow path: parse comma-separated values + find_last_te_token(bytes).eq_ignore_ascii_case(b"chunked") } pub fn is_upgrade_req(req: &RequestHeader) -> bool { @@ -173,13 +200,13 @@ pub fn header_value_content_length( header_value: Option<&http::header::HeaderValue>, ) -> Option { match header_value { - Some(value) => buf_to_content_length(Some(value.as_bytes())), + Some(value) => buf_to_content_length(Some(value.as_bytes())).ok().flatten(), None => None, } } #[inline] -pub(super) fn buf_to_content_length(header_value: Option<&[u8]>) -> Option { +pub(super) fn buf_to_content_length(header_value: Option<&[u8]>) -> Result> { match header_value { Some(buf) => { match str::from_utf8(buf) { @@ -187,24 +214,30 @@ pub(super) fn buf_to_content_length(header_value: Option<&[u8]>) -> Option match str_cl_value.parse::() { Ok(cl_length) => { if cl_length >= 0 { - Some(cl_length as usize) + Ok(Some(cl_length as usize)) } else { warn!("negative content-length header value {cl_length}"); - None + Error::e_explain( + InvalidHTTPHeader, + 
format!("negative Content-Length header value: {cl_length}"), + ) } } Err(_) => { warn!("invalid content-length header value {str_cl_value}"); - None + Error::e_explain( + InvalidHTTPHeader, + format!("invalid Content-Length header value: {str_cl_value}"), + ) } }, Err(_) => { warn!("invalid content-length header encoding"); - None + Error::e_explain(InvalidHTTPHeader, "invalid Content-Length header encoding") } } } - None => None, + None => Ok(None), } } @@ -276,6 +309,7 @@ mod test { header::{CONTENT_LENGTH, TRANSFER_ENCODING}, StatusCode, Version, }; + use rstest::rstest; #[test] fn test_check_dup_content_length() { @@ -312,4 +346,44 @@ mod test { response.set_version(Version::HTTP_11); assert!(!is_upgrade_resp(&response)); } + + #[test] + fn test_is_chunked_encoding_from_headers_empty() { + let empty_headers = HMap::new(); + assert!(!is_chunked_encoding_from_headers(&empty_headers)); + } + + #[rstest] + #[case::single_chunked("chunked", true)] + #[case::comma_separated_final("identity, chunked", true)] + #[case::whitespace_around(" chunked ", true)] + #[case::empty_elements_before(", , , chunked", true)] + #[case::only_identity("identity", false)] + #[case::trailing_comma("chunked, ", false)] + #[case::multiple_trailing_commas("chunked, , ", false)] + #[case::empty_value("", false)] + #[case::whitespace_only(" ", false)] + fn test_is_chunked_encoding_single_header(#[case] value: &str, #[case] expected: bool) { + let mut headers = HMap::new(); + headers.insert(TRANSFER_ENCODING, value.try_into().unwrap()); + assert_eq!(is_chunked_encoding_from_headers(&headers), expected); + } + + #[rstest] + #[case::two_headers_chunked_last(&["identity", "chunked"], true)] + #[case::three_headers_chunked_last(&["gzip", "identity", "chunked"], true)] + #[case::last_has_comma_separated(&["gzip", "identity, chunked"], true)] + #[case::whitespace_in_last(&["gzip", " chunked "], true)] + #[case::two_headers_no_chunked(&["identity", "gzip"], false)] + 
#[case::chunked_not_last(&["chunked", "identity"], false)] + #[case::last_has_chunked_not_final(&["gzip", "chunked, identity"], false)] + #[case::chunked_overridden(&["chunked", "identity, gzip"], false)] + #[case::trailing_comma_in_last(&["gzip", "chunked, "], false)] + fn test_is_chunked_encoding_multiple_headers(#[case] values: &[&str], #[case] expected: bool) { + let mut headers = HMap::new(); + for value in values { + headers.append(TRANSFER_ENCODING, (*value).try_into().unwrap()); + } + assert_eq!(is_chunked_encoding_from_headers(&headers), expected); + } } diff --git a/pingora-core/src/protocols/http/v1/mod.rs b/pingora-core/src/protocols/http/v1/mod.rs index c819ee08..19602491 100644 --- a/pingora-core/src/protocols/http/v1/mod.rs +++ b/pingora-core/src/protocols/http/v1/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 073fee41..b071e6fd 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,12 +14,13 @@ //! 
HTTP/1.x server session +use bstr::ByteSlice; use bytes::Bytes; use bytes::{BufMut, BytesMut}; use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING}; use http::HeaderValue; use http::{header, header::AsHeaderName, Method, Version}; -use log::{debug, warn}; +use log::{debug, trace, warn}; use once_cell::sync::Lazy; use percent_encoding::{percent_encode, AsciiSet, CONTROLS}; use pingora_error::{Error, ErrorType::*, OrErr, Result}; @@ -81,6 +82,10 @@ pub struct HttpSession { ignore_info_resp: bool, /// Disable keepalive if response is sent before downstream body is finished close_on_response_before_downstream_finish: bool, + + /// Number of times the upstream connection associated with this session can be reused + /// after this session ends + keepalive_reuses_remaining: Option, } impl HttpSession { @@ -118,7 +123,9 @@ impl HttpSession { digest, min_send_rate: None, ignore_info_resp: false, - close_on_response_before_downstream_finish: false, + // default on to avoid rejecting requests after body as pipelined + close_on_response_before_downstream_finish: true, + keepalive_reuses_remaining: None, } } @@ -256,7 +263,12 @@ impl HttpSession { // Transfer encoding overrides content length, so when // both are present, we can remove content length. This // is per https://datatracker.ietf.org/doc/html/rfc9112#section-6.3 - if contains_content_length && contains_transfer_encoding { + // + // RFC 9112 Section 6.1 (https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-15) + // also requires us to disable keepalive when both headers are present. 
+ let has_both_te_and_cl = + contains_content_length && contains_transfer_encoding; + if has_both_te_and_cl { request_header.remove_header(&CONTENT_LENGTH); } @@ -266,6 +278,11 @@ impl HttpSession { self.body_reader.reinit(); self.response_written = None; self.respect_keepalive(); + + // Disable keepalive if both Transfer-Encoding and Content-Length were present + if has_both_te_and_cl { + self.set_keepalive(None); + } self.validate_request()?; return Ok(Some(s)); @@ -284,10 +301,7 @@ impl HttpSession { buf.truncate(MAX_ERR_BUF_LEN); return Error::e_because( InvalidHTTPHeader, - format!( - "buf: {}", - String::from_utf8_lossy(&buf).escape_default() - ), + format!("buf: {}", buf.escape_ascii()), e, ); } @@ -297,7 +311,7 @@ impl HttpSession { buf.truncate(MAX_ERR_BUF_LEN); return Error::e_because( InvalidHTTPHeader, - format!("buf: {}", String::from_utf8_lossy(&buf).escape_default()), + format!("buf: {:?}", buf.as_bstr()), e, ); } @@ -317,6 +331,25 @@ impl HttpSession { // ad-hoc checks super::common::check_dup_content_length(&req_header.headers)?; + if req_header.headers.contains_key(TRANSFER_ENCODING) { + // Per [RFC 9112 Section 6.1-16](https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-16), + // HTTP/1.0 requests with Transfer-Encoding MUST be treated as having faulty framing. + // We reject with 400 Bad Request and close the connection. 
+ if req_header.version == http::Version::HTTP_10 { + return Error::e_explain( + InvalidHTTPHeader, + "HTTP/1.0 requests cannot include Transfer-Encoding header", + ); + } + // If chunked is not the final Transfer-Encoding, reject request + // See https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.4.3 + if !self.is_chunked_encoding() { + return Error::e_explain(InvalidHTTPHeader, "non-chunked final Transfer-Encoding"); + } + } + // validate content-length value if present to avoid ambiguous framing + self.get_content_length()?; + Ok(()) } @@ -474,7 +507,10 @@ impl HttpSession { } } - if self.close_on_response_before_downstream_finish && !self.is_body_done() { + // if body unfinished, or request header was not finished reading + if self.close_on_response_before_downstream_finish + && (self.request_header.is_none() || !self.is_body_done()) + { debug!("set connection close before downstream finish"); self.set_keepalive(None); } @@ -514,16 +550,37 @@ impl HttpSession { // a peer discards any further data received. // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 self.upgraded = true; + // Now that the upgrade was successful, we need to change + // how we interpret the rest of the body as pass-through. 
+ if self.body_reader.need_init() { + self.init_body_reader(); + } else { + // already initialized + // immediately start reading the rest of the body as upgraded + // (in practice most upgraded requests shouldn't have any body) + // + // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade + // the most spec-compliant behavior is to switch interpretation + // after sending the former body, + // we immediately switch interpretation to match nginx + self.body_reader.convert_to_close_delimited(); + } } else { + // this was a request that requested Upgrade, + // but upstream did not comply debug!("bad upgrade handshake!"); - // reset request body buf and mark as done - // safe to reset an upgrade because it doesn't have body - self.body_reader.init_content_length(0, b""); + // continue to read body as-is, this is now just a regular request } } self.init_body_writer(&header); } + // Defense-in-depth: if response body is close-delimited, mark session + // as un-reusable + if self.body_writer.is_close_delimited() { + self.set_keepalive(None); + } + // Don't have to flush response with content length because it is less // likely to be real time communication. So do flush when // 1.1xx response: client needs to see it before the rest of response @@ -566,6 +623,14 @@ impl HttpSession { } } + /// Was this request successfully turned into an upgraded connection? + /// + /// Both the request had to have been an `Upgrade` request + /// and the response had to have been a `101 Switching Protocols`. 
+ pub fn was_upgraded(&self) -> bool { + self.upgraded + } + fn set_keepalive(&mut self, seconds: Option) { match seconds { Some(sec) => { @@ -589,11 +654,20 @@ impl HttpSession { } } + pub fn set_keepalive_reuses_remaining(&mut self, remaining: Option) { + self.keepalive_reuses_remaining = remaining; + } + + pub fn get_keepalive_reuses_remaining(&self) -> Option { + self.keepalive_reuses_remaining + } + /// Return whether the session will be keepalived for connection reuse. pub fn will_keepalive(&self) -> bool { - // TODO: check self.body_writer. If it is http1.0 type then keepalive - // cannot be used because the connection close is the signal of end body - !matches!(self.keepalive_timeout, KeepaliveStatus::Off) + !matches!( + (&self.keepalive_timeout, self.keepalive_reuses_remaining), + (KeepaliveStatus::Off, _) | (_, Some(0)) + ) } // `Keep-Alive: timeout=5, max=1000` => 5, 1000 @@ -670,7 +744,7 @@ impl HttpSession { } if self.is_upgrade(header) == Some(true) { - self.body_writer.init_http10(); + self.body_writer.init_close_delimited(); } else { init_body_writer_comm(&mut self.body_writer, &header.headers); } @@ -756,6 +830,10 @@ impl HttpSession { .await .or_err(WriteError, "flushing body")?; + trace!( + "finish body (response body writer), upgraded: {}", + self.upgraded + ); self.maybe_force_close_body_reader(); Ok(res) } @@ -771,10 +849,10 @@ impl HttpSession { } fn is_chunked_encoding(&self) -> bool { - is_header_value_chunked_encoding(self.get_header(header::TRANSFER_ENCODING)) + is_chunked_encoding_from_headers(&self.req_header().headers) } - fn get_content_length(&self) -> Option { + fn get_content_length(&self) -> Result> { buf_to_content_length( self.get_header(header::CONTENT_LENGTH) .map(|v| v.as_bytes()), @@ -788,33 +866,30 @@ impl HttpSession { buffer.clear(); } - /* follow https://tools.ietf.org/html/rfc7230#section-3.3.3 */ + // follow https://datatracker.ietf.org/doc/html/rfc9112#section-6.3 let preread_body = 
self.preread_body.as_ref().unwrap().get(&self.buf[..]); - if self.req_header().version == Version::HTTP_11 && self.is_upgrade_req() { - self.body_reader.init_http10(preread_body); - return; - } - - if self.is_chunked_encoding() { + if self.was_upgraded() { + // if upgraded _post_ 101 (and body was not init yet) + // treat as upgraded body (pass through until closed) + self.body_reader.init_close_delimited(preread_body); + } else if self.is_chunked_encoding() { // if chunked encoding, content-length should be ignored self.body_reader.init_chunked(preread_body); } else { - let cl = self.get_content_length(); + // At this point, validate_request() should have already been called, + // so get_content_length() should not return an error for invalid values + let cl = self.get_content_length().unwrap_or(None); match cl { Some(i) => { self.body_reader.init_content_length(i, preread_body); } None => { - match self.req_header().version { - Version::HTTP_11 => { - // Per RFC assume no body by default in HTTP 1.1 - self.body_reader.init_content_length(0, preread_body); - } - _ => { - self.body_reader.init_http10(preread_body); - } - } + // https://datatracker.ietf.org/doc/html/rfc9112#section-6.3 + // "Request messages are never close-delimited because they are + // always explicitly framed by length or transfer coding, with the absence of + // both implying the request ends immediately after the header section." + self.body_reader.init_content_length(0, preread_body); } } } @@ -864,6 +939,7 @@ impl HttpSession { /// forever, same as [`Self::idle()`]. pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result> { if no_body_expected || self.is_body_done() { + // XXX: account for upgraded body reader change, if the read half split from the write half let read = self.idle().await?; if read == 0 { Error::e_explain( @@ -1015,14 +1091,21 @@ impl HttpSession { /// returned. 
If there was an error while draining any remaining request body that error will /// be returned. pub async fn reuse(mut self) -> Result> { - match self.keepalive_timeout { - KeepaliveStatus::Off => { - debug!("HTTP shutdown connection"); - self.shutdown().await; + if !self.will_keepalive() { + debug!("HTTP shutdown connection"); + self.shutdown().await; + Ok(None) + } else { + self.drain_request_body().await?; + // XXX: currently pipelined requests are not properly read without + // pipelining support, and pingora 400s if pipelined requests are sent + // in the middle of another request. + // We will mark the connection as un-reusable so it may be closed, + // the pipelined request left unread, and the client can attempt to resend + if self.body_reader.has_bytes_overread() { + debug!("bytes overread on request, disallowing reuse"); Ok(None) - } - _ => { - self.drain_request_body().await?; + } else { Ok(Some(self.underlying_stream)) } } @@ -1040,6 +1123,28 @@ impl HttpSession { Ok(()) } + async fn write_non_empty_body(&mut self, data: Option, upgraded: bool) -> Result<()> { + // Both upstream and downstream should agree on upgrade status. + // Upgrade can only occur if both downstream and upstream sessions are H1.1 + // and see a 101 response, which logically MUST have been received + // prior to this task. 
+ if upgraded != self.upgraded { + if upgraded { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session"); + } else { + panic!("Unexpected Body task received on upgraded downstream session"); + } + } + let Some(d) = data else { + return Ok(()); + }; + if d.is_empty() { + return Ok(()); + } + self.write_body(&d).await.map_err(|e| e.into_down())?; + Ok(()) + } + async fn response_duplex(&mut self, task: HttpTask) -> Result { let end_stream = match task { HttpTask::Header(header, end_stream) => { @@ -1048,15 +1153,14 @@ impl HttpSession { .map_err(|e| e.into_down())?; end_stream } - HttpTask::Body(data, end_stream) => match data { - Some(d) => { - if !d.is_empty() { - self.write_body(&d).await.map_err(|e| e.into_down())?; - } - end_stream - } - None => end_stream, - }, + HttpTask::Body(data, end_stream) => { + self.write_non_empty_body(data, false).await?; + end_stream + } + HttpTask::UpgradedBody(data, end_stream) => { + self.write_non_empty_body(data, true).await?; + end_stream + } HttpTask::Trailer(_) => true, // h1 trailer is not supported yet HttpTask::Done => true, HttpTask::Failed(e) => return Err(e), @@ -1068,6 +1172,23 @@ impl HttpSession { Ok(end_stream || self.body_writer.finished()) } + fn buffer_body_data(&mut self, data: Option, upgraded: bool) { + if upgraded != self.upgraded { + if upgraded { + panic!("Unexpected Body task received on upgraded downstream session"); + } else { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session"); + } + } + + let Some(d) = data else { + return; + }; + if !d.is_empty() && !self.body_writer.finished() { + self.body_write_buf.put_slice(&d); + } + } + // TODO: use vectored write to avoid copying pub async fn response_duplex_vec(&mut self, mut tasks: Vec) -> Result { let n_tasks = tasks.len(); @@ -1075,6 +1196,7 @@ impl HttpSession { // fallback to single operation to avoid copy return self.response_duplex(tasks.pop().unwrap()).await; } + let mut end_stream = false; 
for task in tasks.into_iter() { end_stream = match task { @@ -1084,15 +1206,14 @@ impl HttpSession { .map_err(|e| e.into_down())?; end_stream } - HttpTask::Body(data, end_stream) => match data { - Some(d) => { - if !d.is_empty() && !self.body_writer.finished() { - self.body_write_buf.put_slice(&d); - } - end_stream - } - None => end_stream, - }, + HttpTask::Body(data, end_stream) => { + self.buffer_body_data(data, false); + end_stream + } + HttpTask::UpgradedBody(data, end_stream) => { + self.buffer_body_data(data, true); + end_stream + } HttpTask::Trailer(_) => true, // h1 trailer is not supported yet HttpTask::Done => true, HttpTask::Failed(e) => { @@ -1232,6 +1353,7 @@ mod tests_stream { use super::*; use crate::protocols::http::v1::body::{BodyMode, ParseState}; use http::StatusCode; + use pingora_error::ErrorType; use rstest::rstest; use std::str; use tokio_test::io::Builder; @@ -1340,12 +1462,13 @@ mod tests_stream { } #[tokio::test] + #[should_panic(expected = "There is still data left to read.")] async fn read_with_body_http10() { init_log(); let input1 = b"GET / HTTP/1.0\r\n"; let input2 = b"Host: pingora.org\r\n\r\n"; - let input3 = b"a"; - let input4 = b""; // simulating close + let input3 = b"a"; // This should NOT be read as body + let input4 = b""; // simulating close - should also NOT be reached let mock_io = Builder::new() .read(&input1[..]) .read(&input2[..]) @@ -1354,41 +1477,26 @@ mod tests_stream { .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); http_stream.read_request().await.unwrap(); - let res = http_stream.read_body_bytes().await.unwrap().unwrap(); - assert_eq!(res, input3.as_slice()); - assert_eq!(http_stream.body_reader.body_state, ParseState::HTTP1_0(1)); - assert_eq!(http_stream.body_bytes_read(), 1); let res = http_stream.read_body_bytes().await.unwrap(); assert!(res.is_none()); - assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(1)); - assert_eq!(http_stream.body_bytes_read(), 1); + 
assert_eq!(http_stream.body_bytes_read(), 0); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); } #[tokio::test] async fn read_with_body_http10_single_read() { init_log(); + // should have 0 body, even when data follows the headers let input1 = b"GET / HTTP/1.0\r\n"; let input2 = b"Host: pingora.org\r\n\r\na"; - let input3 = b"b"; - let input4 = b""; // simulating close - let mock_io = Builder::new() - .read(&input1[..]) - .read(&input2[..]) - .read(&input3[..]) - .read(&input4[..]) - .build(); + let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); http_stream.read_request().await.unwrap(); - let res = http_stream.read_body_bytes().await.unwrap().unwrap(); - assert_eq!(res, b"a".as_slice()); - assert_eq!(http_stream.body_reader.body_state, ParseState::HTTP1_0(1)); - let res = http_stream.read_body_bytes().await.unwrap().unwrap(); - assert_eq!(res, b"b".as_slice()); - assert_eq!(http_stream.body_reader.body_state, ParseState::HTTP1_0(2)); let res = http_stream.read_body_bytes().await.unwrap(); - assert_eq!(http_stream.body_bytes_read(), 2); assert!(res.is_none()); - assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(2)); + assert_eq!(http_stream.body_bytes_read(), 0); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + assert_eq!(http_stream.body_reader.get_body_overread().unwrap(), b"a"); } #[tokio::test] @@ -1406,7 +1514,26 @@ mod tests_stream { } #[tokio::test] - async fn read_with_body_chunked_0() { + async fn read_http10_with_content_length() { + init_log(); + let input1 = b"POST / HTTP/1.0\r\n"; + let input2 = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\n"; + let input3 = b"abc"; + let mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + let res = 
http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(res, input3.as_slice()); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3)); + assert_eq!(http_stream.body_bytes_read(), 3); + } + + #[tokio::test] + async fn read_with_body_chunked_0_incomplete() { init_log(); let input1 = b"GET / HTTP/1.1\r\n"; let input2 = b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n"; @@ -1419,10 +1546,36 @@ mod tests_stream { let mut http_stream = HttpSession::new(Box::new(mock_io)); http_stream.read_request().await.unwrap(); assert!(http_stream.is_chunked_encoding()); - let res = http_stream.read_body_bytes().await.unwrap(); - assert!(res.is_none()); - assert_eq!(http_stream.body_bytes_read(), 0); - assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0)); + let res = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(res, b"".as_slice()); + let e = http_stream.read_body_bytes().await.unwrap_err(); + assert_eq!(*e.etype(), ErrorType::ConnectionClosed); + assert_eq!(http_stream.body_reader.body_state, ParseState::Done(0)); + } + + #[tokio::test] + async fn read_with_body_chunked_0_extra() { + init_log(); + let input1 = b"GET / HTTP/1.1\r\n"; + let input2 = b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n"; + let input3 = b"0\r\n"; + let input4 = b"abc"; + let mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .read(&input4[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_chunked_encoding()); + let res = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(res, b"".as_slice()); + let res = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(res, b"".as_slice()); + let e = http_stream.read_body_bytes().await.unwrap_err(); + assert_eq!(*e.etype(), ErrorType::ConnectionClosed); + assert_eq!(http_stream.body_reader.body_state, 
ParseState::Done(0)); } #[tokio::test] @@ -1451,6 +1604,33 @@ mod tests_stream { assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(1)); } + #[tokio::test] + async fn read_with_body_chunked_single_read_extra() { + init_log(); + let input1 = b"GET / HTTP/1.1\r\n"; + let input2 = b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n1\r\na\r\n"; + let input3 = b"0\r\n\r\nabc"; + let mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_chunked_encoding()); + let res = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(res, b"a".as_slice()); + assert_eq!( + http_stream.body_reader.body_state, + ParseState::Chunked(1, 0, 0, 0) + ); + let res = http_stream.read_body_bytes().await.unwrap(); + assert!(res.is_none()); + assert_eq!(http_stream.body_bytes_read(), 1); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(1)); + assert_eq!(http_stream.body_reader.get_body_overread().unwrap(), b"abc"); + } + #[rstest] #[case(None, None)] #[case(Some("transfer-encoding"), None)] @@ -1500,6 +1680,55 @@ mod tests_stream { } } + #[rstest] + #[case::negative("-1")] + #[case::not_a_number("abc")] + #[case::float("1.5")] + #[case::empty("")] + #[case::spaces(" ")] + #[case::mixed("123abc")] + #[tokio::test] + async fn validate_request_rejects_invalid_content_length(#[case] invalid_value: &str) { + init_log(); + let input = format!( + "POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {}\r\n\r\n", + invalid_value + ); + let mock_io = Builder::new().read(input.as_bytes()).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + // read_request calls validate_request internally, so it should fail here + let res = http_stream.read_request().await; + assert!(res.is_err()); + assert_eq!(res.unwrap_err().etype(), &InvalidHTTPHeader); + } + + 
#[rstest] + #[case::valid_zero("0")] + #[case::valid_small("123")] + #[case::valid_large("999999")] + #[tokio::test] + async fn validate_request_accepts_valid_content_length(#[case] valid_value: &str) { + init_log(); + let input = format!( + "POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {}\r\n\r\n", + valid_value + ); + let mock_io = Builder::new().read(input.as_bytes()).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let res = http_stream.read_request().await; + assert!(res.is_ok()); + } + + #[tokio::test] + async fn validate_request_accepts_no_content_length() { + init_log(); + let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let res = http_stream.read_request().await; + assert!(res.is_ok()); + } + #[tokio::test] #[should_panic(expected = "There is still data left to read.")] async fn read_invalid() { @@ -1511,7 +1740,16 @@ mod tests_stream { assert_eq!(&InvalidHTTPHeader, res.unwrap_err().etype()); } - async fn build_req(upgrade: &str, conn: &str) -> HttpSession { + #[tokio::test] + async fn read_invalid_header_end() { + let input = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\r\nConnection: keep-alive\r\n\r\nabc"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let res = http_stream.read_request().await; + assert_eq!(&InvalidHTTPHeader, res.unwrap_err().etype()); + } + + async fn build_upgrade_req(upgrade: &str, conn: &str) -> HttpSession { let input = format!("GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: {upgrade}\r\nConnection: {conn}\r\n\r\n"); let mock_io = Builder::new().read(input.as_bytes()).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); @@ -1549,10 +1787,210 @@ mod tests_stream { http_stream.read_request().await.unwrap(); assert!(http_stream.is_upgrade_req()); - 
assert!(build_req("websocket", "Upgrade").await.is_upgrade_req()); + assert!(build_upgrade_req("websocket", "Upgrade") + .await + .is_upgrade_req()); // mixed case - assert!(build_req("WebSocket", "Upgrade").await.is_upgrade_req()); + assert!(build_upgrade_req("WebSocket", "Upgrade") + .await + .is_upgrade_req()); + } + + const POST_CL_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\nContent-Length: 10\r\n\r\n"; + const POST_BODY_DATA: &[u8] = b"abcdefghij"; + const POST_CHUNKED_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\nTransfer-Encoding: chunked\r\n\r\n"; + const POST_BODY_DATA_CHUNKED: &[u8] = b"3\r\nabc\r\n7\r\ndefghij\r\n0\r\n\r\n"; + + #[rstest] + #[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA)] + #[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA_CHUNKED)] + #[tokio::test] + async fn read_upgrade_req_with_body( + #[case] header: &[u8], + #[case] body: &[u8], + #[case] body_wire: &[u8], + ) { + let ws_data = b"data"; + let mock_io = Builder::new() + .read(header) + .read(body_wire) + .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n") + .read(&ws_data[..]) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_upgrade_req()); + // request has body + assert!(!http_stream.is_body_done()); + + let mut buf = vec![]; + while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + buf.put_slice(&b); + } + assert_eq!(buf, body); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10)); + assert_eq!(http_stream.body_bytes_read(), 10); + + assert!(http_stream.is_body_done()); + + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + 
.unwrap(); + // body reader type switches + assert!(!http_stream.is_body_done()); + + // now the ws data + let buf = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(buf, ws_data.as_slice()); + assert!(!http_stream.is_body_done()); + + // EOF ends body + assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + assert!(http_stream.is_body_done()); + } + + #[rstest] + #[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA)] + #[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA_CHUNKED)] + #[tokio::test] + async fn read_upgrade_req_with_body_extra( + #[case] header: &[u8], + #[case] body: &[u8], + #[case] body_wire: &[u8], + ) { + let ws_data = b"data"; + let data_wire = [body_wire, ws_data.as_slice()].concat(); + let mock_io = Builder::new() + .read(header) + .read(&data_wire[..]) + .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n") + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_upgrade_req()); + // request has body + assert!(!http_stream.is_body_done()); + + let mut buf = vec![]; + while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + buf.put_slice(&b); + } + assert_eq!(buf, body); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10)); + assert_eq!(http_stream.body_bytes_read(), 10); + + assert!(http_stream.is_body_done()); + + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + .unwrap(); + // body reader type switches + assert!(!http_stream.is_body_done()); + + // now the ws data + let buf = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(buf, ws_data.as_slice()); + assert!(!http_stream.is_body_done()); + + // EOF ends body + assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + 
assert!(http_stream.is_body_done()); + } + + #[rstest] + #[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA)] + #[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA_CHUNKED)] + #[tokio::test] + async fn read_upgrade_req_with_preread_body( + #[case] header: &[u8], + #[case] body: &[u8], + #[case] body_wire: &[u8], + ) { + let ws_data = b"data"; + let data_wire = [header, body_wire, ws_data.as_slice()].concat(); + let mock_io = Builder::new() + .read(&data_wire[..]) + .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n") + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_upgrade_req()); + // request has body + assert!(!http_stream.is_body_done()); + + let mut buf = vec![]; + while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + buf.put_slice(&b); + } + assert_eq!(buf, body); + assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10)); + assert_eq!(http_stream.body_bytes_read(), 10); + + assert!(http_stream.is_body_done()); + + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + .unwrap(); + // body reader type switches + assert!(!http_stream.is_body_done()); + + // now the ws data + let buf = http_stream.read_body_bytes().await.unwrap().unwrap(); + assert_eq!(buf, ws_data.as_slice()); + assert!(!http_stream.is_body_done()); + + // EOF ends body + assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + assert!(http_stream.is_body_done()); + } + + #[rstest] + #[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA)] + #[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA_CHUNKED)] + #[tokio::test] + async fn read_upgrade_req_with_preread_body_after_101( + #[case] header: &[u8], + #[case] body_wire: &[u8], + ) { + let ws_data = b"data"; + 
let data_wire = [header, body_wire, ws_data.as_slice()].concat(); + let mock_io = Builder::new() + .read(&data_wire[..]) + .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n") + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_upgrade_req()); + // request has body + assert!(!http_stream.is_body_done()); + + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + .unwrap(); + // body reader type switches to http10 + assert!(!http_stream.is_body_done()); + + let mut buf = vec![]; + while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + buf.put_slice(&b); + } + let expected_body = [body_wire, ws_data.as_slice()].concat(); + assert_eq!(buf, expected_body.as_bytes()); + assert_eq!(http_stream.body_bytes_read(), expected_body.len()); + assert!(http_stream.is_body_done()); } #[tokio::test] @@ -1561,6 +1999,7 @@ mod tests_stream { let mock_io = Builder::new() .read(&input[..]) .write(b"HTTP/1.1 100 Continue\r\n\r\n") + .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n") .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); http_stream.read_request().await.unwrap(); @@ -1572,7 +2011,60 @@ mod tests_stream { .await .unwrap(); // 100 won't affect body state + // current GET request is done + assert!(http_stream.is_body_done()); + + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + .unwrap(); + // body reader type switches assert!(!http_stream.is_body_done()); + // EOF ends body + assert!(http_stream.read_body_bytes().await.unwrap().is_none()); + assert!(http_stream.is_body_done()); + } + + #[tokio::test] + async fn 
test_upgrade_without_content_length_with_ws_data() { + let request = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n"; + let ws_data = b"websocket data"; + + let mock_io = Builder::new() + .read(request) + .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n") + .read(ws_data) // websocket data sent after 101 + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_upgrade_req()); + + // When enabled (default), is_body_done() is called before the upgrade + http_stream.set_close_on_response_before_downstream_finish(false); + + // Send 101 response - this is where the bug occurs + let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); + response.set_version(http::Version::HTTP_11); + http_stream + .write_response_header(Box::new(response)) + .await + .unwrap(); + + assert_eq!( + http_stream.body_reader.body_state, + ParseState::UntilClose(0), + "Body reader should be in UntilClose mode after 101 for upgraded connections" + ); + + // Try to read websocket data + let mut buf = vec![]; + while let Some(b) = http_stream.read_body_bytes().await.unwrap() { + buf.put_slice(&b); + } + assert_eq!(buf, ws_data, "Expected to read websocket data after 101"); } #[tokio::test] @@ -1619,9 +2111,11 @@ mod tests_stream { #[tokio::test] async fn write() { - let wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let write_expected = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + let mock_io = Builder::new().read(read_wire).write(write_expected).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); new_response.append_header("Foo", "Bar").unwrap(); http_stream.update_resp_headers = false; @@ -1633,9 
+2127,11 @@ mod tests_stream { #[tokio::test] async fn write_custom_reason() { - let wire = b"HTTP/1.1 200 Just Fine\r\nFoo: Bar\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let write_expected = b"HTTP/1.1 200 Just Fine\r\nFoo: Bar\r\n\r\n"; + let mock_io = Builder::new().read(read_wire).write(write_expected).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); new_response.set_reason_phrase(Some("Just Fine")).unwrap(); new_response.append_header("Foo", "Bar").unwrap(); @@ -1648,9 +2144,11 @@ mod tests_stream { #[tokio::test] async fn write_informational() { - let wire = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let write_expected = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + let mock_io = Builder::new().read(read_wire).write(write_expected).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); http_stream .write_response_header_ref(&response_100) @@ -1667,11 +2165,13 @@ mod tests_stream { #[tokio::test] async fn write_informational_ignored() { - let wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let write_expected = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + let mock_io = Builder::new().read(read_wire).write(write_expected).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); // ignore the 100 Continue http_stream.ignore_info_resp = true; + http_stream.read_request().await.unwrap(); let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); 
http_stream .write_response_header_ref(&response_100) @@ -1736,10 +2236,16 @@ mod tests_stream { #[tokio::test] async fn write_101_switching_protocol() { + let read_wire = b"GET / HTTP/1.1\r\nUpgrade: websocket\r\n\r\n"; let wire = b"HTTP/1.1 101 Switching Protocols\r\nFoo: Bar\r\n\r\n"; let wire_body = b"nPAYLOAD"; - let mock_io = Builder::new().write(wire).write(wire_body).build(); + let mock_io = Builder::new() + .read(read_wire) + .write(wire) + .write(wire_body) + .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let mut response_101 = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap(); response_101.append_header("Foo", "Bar").unwrap(); @@ -1747,10 +2253,12 @@ mod tests_stream { .write_response_header_ref(&response_101) .await .unwrap(); + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); + let n = http_stream.write_body(wire_body).await.unwrap().unwrap(); assert_eq!(wire_body.len(), n); - // simulate upgrade - http_stream.upgraded = true; + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(n)); + // this write should be ignored let response_502 = ResponseHeader::build(StatusCode::BAD_GATEWAY, None).unwrap(); http_stream @@ -1761,10 +2269,16 @@ mod tests_stream { #[tokio::test] async fn write_body_cl() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; let wire_header = b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\n"; let wire_body = b"a"; - let mock_io = Builder::new().write(wire_header).write(wire_body).build(); + let mock_io = Builder::new() + .read(read_wire) + .write(wire_header) + .write(wire_body) + .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); new_response.append_header("Content-Length", "1").unwrap(); http_stream.update_resp_headers = false; @@ -1784,17 +2298,23 @@ mod tests_stream { 
#[tokio::test] async fn write_body_http10() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; let wire_header = b"HTTP/1.1 200 OK\r\n\r\n"; let wire_body = b"a"; - let mock_io = Builder::new().write(wire_header).write(wire_body).build(); + let mock_io = Builder::new() + .read(read_wire) + .write(wire_header) + .write(wire_body) + .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); http_stream.update_resp_headers = false; http_stream .write_response_header_ref(&new_response) .await .unwrap(); - assert_eq!(http_stream.body_writer.body_mode, BodyMode::HTTP1_0(0)); + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); let n = http_stream.write_body(wire_body).await.unwrap().unwrap(); assert_eq!(wire_body.len(), n); let n = http_stream.finish_body().await.unwrap().unwrap(); @@ -1803,15 +2323,18 @@ mod tests_stream { #[tokio::test] async fn write_body_chunk() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; let wire_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; let wire_body = b"1\r\na\r\n"; let wire_end = b"0\r\n\r\n"; let mock_io = Builder::new() + .read(read_wire) .write(wire_header) .write(wire_body) .write(wire_end) .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); new_response .append_header("Transfer-Encoding", "chunked") @@ -1882,9 +2405,11 @@ mod tests_stream { #[tokio::test] async fn test_write_body_buf() { - let wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let write_expected = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + let mock_io = Builder::new().read(read_wire).write(write_expected).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + 
http_stream.read_request().await.unwrap(); let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); new_response.append_header("Foo", "Bar").unwrap(); http_stream.update_resp_headers = false; @@ -1899,14 +2424,17 @@ mod tests_stream { #[tokio::test] #[should_panic(expected = "There is still data left to write.")] async fn test_write_body_buf_write_timeout() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; let wire1 = b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n"; let wire2 = b"abc"; let mock_io = Builder::new() + .read(read_wire) .write(wire1) .wait(Duration::from_millis(500)) .write(wire2) .build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); http_stream.write_timeout = Some(Duration::from_millis(100)); let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); new_response.append_header("Content-Length", "3").unwrap(); @@ -1922,9 +2450,11 @@ mod tests_stream { #[tokio::test] async fn test_write_continue_resp() { - let wire = b"HTTP/1.1 100 Continue\r\n\r\n"; - let mock_io = Builder::new().write(wire).build(); + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let write_expected = b"HTTP/1.1 100 Continue\r\n\r\n"; + let mock_io = Builder::new().read(read_wire).write(write_expected).build(); let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); http_stream.write_continue_response().await.unwrap(); } @@ -1973,6 +2503,203 @@ mod tests_stream { http_stream.set_min_send_rate(Some(1)); assert_eq!(Some(expected), http_stream.write_timeout(0)); } + + #[tokio::test] + async fn test_te_and_cl_disables_keepalive() { + // When both Transfer-Encoding and Content-Length are present, + // we must disable keepalive per RFC 9112 Section 6.1 + // https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-15 + let input = b"POST / HTTP/1.1\r\n\ +Host: pingora.org\r\n\ +Transfer-Encoding: chunked\r\n\ +Content-Length: 10\r\n\ +\r\n\ +5\r\n\ 
+hello\r\n\ +0\r\n\ +\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + + // Keepalive should be disabled + assert_eq!(http_stream.keepalive_timeout, KeepaliveStatus::Off); + + // Content-Length header should have been removed + assert!(!http_stream + .req_header() + .headers + .contains_key(CONTENT_LENGTH)); + + // Transfer-Encoding should still be present + assert!(http_stream + .req_header() + .headers + .contains_key(TRANSFER_ENCODING)); + } + + #[tokio::test] + async fn test_http10_request_with_transfer_encoding_rejected() { + // HTTP/1.0 requests MUST NOT contain Transfer-Encoding + let input = b"POST / HTTP/1.0\r\n\ +Host: pingora.org\r\n\ +Transfer-Encoding: chunked\r\n\ +\r\n\ +5\r\n\ +hello\r\n\ +0\r\n\ +\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let result = http_stream.read_request().await; + + // Should be rejected with InvalidHTTPHeader error + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.etype(), &InvalidHTTPHeader); + assert!(err.to_string().contains("Transfer-Encoding")); + } + + #[tokio::test] + async fn test_http10_request_without_transfer_encoding_accepted() { + // HTTP/1.0 requests without Transfer-Encoding should be accepted + let input = b"POST / HTTP/1.0\r\n\ +Host: pingora.org\r\n\ +Content-Length: 5\r\n\ +\r\n\ +hello"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let result = http_stream.read_request().await; + + // Should succeed + assert!(result.is_ok()); + assert_eq!(http_stream.req_header().version, http::Version::HTTP_10); + } + + #[tokio::test] + async fn test_http11_request_with_transfer_encoding_accepted() { + // HTTP/1.1 with Transfer-Encoding should be accepted (contrast with HTTP/1.0) + let input = b"POST / HTTP/1.1\r\n\ 
+Host: pingora.org\r\n\ +Transfer-Encoding: chunked\r\n\ +\r\n\ +5\r\n\ +hello\r\n\ +0\r\n\ +\r\n"; + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + let result = http_stream.read_request().await; + + // Should succeed + assert!(result.is_ok()); + assert_eq!(http_stream.req_header().version, http::Version::HTTP_11); + } + + #[tokio::test] + async fn test_request_multiple_transfer_encoding_headers() { + init_log(); + // Multiple TE headers should be treated as comma-separated + let input = b"POST / HTTP/1.1\r\n\ +Host: pingora.org\r\n\ +Transfer-Encoding: gzip\r\n\ +Transfer-Encoding: chunked\r\n\ +\r\n\ +5\r\n\ +hello\r\n\ +0\r\n\ +\r\n"; + + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + + // Should correctly identify chunked encoding from last header + assert!(http_stream.is_chunked_encoding()); + + // Verify body can be read correctly + let body = http_stream.read_body_bytes().await.unwrap(); + assert_eq!(body.unwrap().as_ref(), b"hello"); + } + + #[tokio::test] + async fn test_request_multiple_te_headers_chunked_not_last() { + init_log(); + // Chunked in first header but not last - should NOT be chunked + // Only the final Transfer-Encoding determines if body is chunked + let input = b"POST / HTTP/1.1\r\n\ +Host: pingora.org\r\n\ +Transfer-Encoding: chunked\r\n\ +Transfer-Encoding: identity\r\n\ +Content-Length: 5\r\n\ +\r\n"; + + let mock_io = Builder::new().read(&input[..]).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + // should fail validation + http_stream.read_request().await.unwrap_err(); + } + + #[tokio::test] + async fn test_no_more_reuses_explicitly_disables_reuse() { + init_log(); + let wire_req = b"GET /test HTTP/1.1\r\n\r\n"; + let wire_header = b"HTTP/1.1 200 OK\r\n\r\n"; + let mock_io = Builder::new() + .read(&wire_req[..]) + 
.write(wire_header) + .build(); + let mut http_session = HttpSession::new(Box::new(mock_io)); + + // Setting the number of keepalive reuses here overrides the keepalive + // setting below + http_session.set_keepalive_reuses_remaining(Some(0)); + + http_session.read_request().await.unwrap(); + + let new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + http_session.update_resp_headers = false; + http_session + .write_response_header(Box::new(new_response)) + .await + .unwrap(); + + assert_eq!(http_session.body_writer.body_mode, BodyMode::UntilClose(0)); + + http_session.finish_body().await.unwrap().unwrap(); + + http_session.set_keepalive(Some(100)); + let reused = http_session.reuse().await.unwrap(); + assert!(reused.is_none()); + } + + #[tokio::test] + async fn test_close_delimited_response_explicitly_disables_reuse() { + init_log(); + let wire_req = b"GET /test HTTP/1.1\r\n\r\n"; + let wire_header = b"HTTP/1.1 200 OK\r\n\r\n"; + let mock_io = Builder::new() + .read(&wire_req[..]) + .write(wire_header) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + + let new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header(Box::new(new_response)) + .await + .unwrap(); + + assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0)); + + http_stream.finish_body().await.unwrap().unwrap(); + + let reused = http_stream.reuse().await.unwrap(); + assert!(reused.is_none()); + } } #[cfg(test)] @@ -2090,3 +2817,111 @@ mod test_timeouts { assert_eq!(res.unwrap().unwrap_err().etype(), &ReadTimedout); } } + +#[cfg(test)] +mod test_overread { + use super::*; + use rstest::rstest; + use tokio_test::io::Builder; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + /// Test session reuse with preread body (all data in single read). 
+ /// When extra bytes are read beyond the request body, the session should NOT be reused. + /// Test matrix includes whether reading body bytes is polled. + #[rstest] + #[case(0, None, true, true)] // CL:0, no extra, read body -> should reuse + #[case(0, None, false, true)] // CL:0, no extra, no read -> should reuse + #[case(0, Some(&b"extra_data_here"[..]), true, false)] // CL:0, extra, read body -> should NOT reuse + #[case(0, Some(&b"extra_data_here"[..]), false, false)] // CL:0, extra, no read -> should NOT reuse + #[case(5, None, true, true)] // CL:5, no extra, read body -> should reuse + #[case(5, None, false, true)] // CL:5, no extra, no read -> should reuse + #[case(5, Some(&b"extra"[..]), true, false)] // CL:5, extra, read body -> should NOT reuse + #[case(5, Some(&b"extra"[..]), false, false)] // CL:5, extra, no read -> should NOT reuse + #[tokio::test] + async fn test_reuse_with_preread_body_overread( + #[case] content_length: usize, + #[case] extra_bytes: Option<&[u8]>, + #[case] read_body: bool, + #[case] expect_reuse: bool, + ) { + init_log(); + + let body = b"hello"; + + // Build the complete HTTP request in a single buffer + // (all body is preread with header) + let mut request_data = Vec::new(); + request_data.extend_from_slice(b"GET / HTTP/1.1\r\n"); + request_data.extend_from_slice( + format!("Host: pingora.org\r\nContent-Length: {content_length}\r\n\r\n",).as_bytes(), + ); + + if content_length > 0 { + request_data.extend_from_slice(&body[..content_length]); + } + + if let Some(extra) = extra_bytes { + request_data.extend_from_slice(extra); + } + + let mock_io = Builder::new().read(&request_data).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + + // Conditionally read the body + if read_body { + let result = http_stream.read_body_bytes().await.unwrap(); + + if content_length == 0 { + assert!( + result.is_none(), + "Body should be empty for Content-Length: 0" + ); + } else { + 
let body_result = result.unwrap(); + assert_eq!(body_result.as_ref(), &body[..content_length]); + } + assert_eq!(http_stream.body_bytes_read(), content_length); + } + + let reused = http_stream.reuse().await.unwrap(); + assert_eq!(reused.is_some(), expect_reuse); + } + + /// Test session reuse with chunked encoding and separate reads. + /// When extra bytes are read beyond the request body, the session should NOT be reused. + /// Test matrix includes whether reading body bytes is polled. + #[rstest] + #[case(true)] + #[case(false)] + #[tokio::test] + async fn test_reuse_with_chunked_body_overread(#[case] read_body: bool) { + init_log(); + + let headers = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n"; + let body_and_extra = b"5\r\nhello\r\n0\r\n\r\nextra"; + + let mock_io = Builder::new().read(headers).read(body_and_extra).build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + assert!(http_stream.is_chunked_encoding()); + + if read_body { + let result = http_stream.read_body_bytes().await.unwrap(); + assert_eq!(result.unwrap().as_ref(), b"hello"); + + // Read terminating chunk (returns None) + let result = http_stream.read_body_bytes().await.unwrap(); + assert!(result.is_none()); + + assert_eq!(http_stream.body_bytes_read(), 5); + } + + let reused = http_stream.reuse().await.unwrap(); + assert!(reused.is_none()); + } +} diff --git a/pingora-core/src/protocols/http/v2/client.rs b/pingora-core/src/protocols/http/v2/client.rs index 51b2ea75..dd3a14d4 100644 --- a/pingora-core/src/protocols/http/v2/client.rs +++ b/pingora-core/src/protocols/http/v2/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -27,6 +27,7 @@ use pingora_timeout::timeout; use std::io::ErrorKind; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::watch; @@ -51,9 +52,11 @@ pub struct Http2Session { /// The timeout is reset on every write. This is not a timeout on the overall duration of the /// request. pub write_timeout: Option, - pub(crate) conn: ConnectionRef, + pub conn: ConnectionRef, // Indicate that whether a END_STREAM is already sent ended: bool, + // Total DATA payload bytes received from upstream response + body_recv: usize, } impl Drop for Http2Session { @@ -75,6 +78,7 @@ impl Http2Session { write_timeout: None, conn, ended: false, + body_recv: 0, } } @@ -176,7 +180,7 @@ impl Http2Session { } let Some(resp_fut) = self.resp_fut.take() else { - panic!("Try to response header is already read") + panic!("Try to take response header, but it is already taken") }; let res = match self.read_timeout { @@ -193,6 +197,35 @@ impl Http2Session { Ok(()) } + #[doc(hidden)] + pub fn poll_read_response_header( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + if self.response_header.is_some() { + panic!("H2 response header is already read") + } + + let Some(mut resp_fut) = self.resp_fut.take() else { + panic!("Try to take response header, but it is already taken") + }; + + let res = match resp_fut.poll_unpin(cx) { + Poll::Ready(Ok(res)) => res, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + self.resp_fut = Some(resp_fut); + return Poll::Pending; + } + }; + + let (resp, body_reader) = res.into_parts(); + self.response_header = Some(resp.into()); + self.response_body_reader = Some(body_reader); + + Poll::Ready(Ok(())) + } + /// Read the response body /// /// `None` means, no more body to read @@ -226,11 +259,36 @@ impl Http2Session { .flow_control() .release_capacity(data.len()) .or_err(ReadError, "while releasing h2 
response body capacity")?; + self.body_recv = self.body_recv.saturating_add(data.len()); } Ok(body) } + #[doc(hidden)] + pub fn poll_read_response_body( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + let Some(body_reader) = self.response_body_reader.as_mut() else { + // req is not sent or response is already read + // TODO: warn + return Poll::Ready(None); + }; + + let data = match ready!(body_reader.poll_data(cx)).transpose() { + Ok(data) => data, + Err(err) => return Poll::Ready(Some(Err(err))), + }; + + if let Some(data) = data { + body_reader.flow_control().release_capacity(data.len())?; + return Poll::Ready(Some(Ok(data))); + } + + Poll::Ready(None) + } + /// Whether the response has ended pub fn response_finished(&self) -> bool { // if response_body_reader doesn't exist, the response is not even read yet @@ -388,6 +446,11 @@ impl Http2Session { self.conn.id() } + /// Upstream response body bytes received (HTTP/2 DATA payload; excludes headers/framing). + pub fn body_bytes_received(&self) -> usize { + self.body_recv + } + /// take the body sender to another task to perform duplex read and write pub fn take_request_body_writer(&mut self) -> Option> { self.send_body.take() @@ -562,3 +625,67 @@ async fn do_ping_pong( } } } + +#[cfg(test)] +mod tests_h2 { + use super::*; + use bytes::Bytes; + use http::{Response, StatusCode}; + use tokio::io::duplex; + + #[tokio::test] + async fn h2_body_bytes_received_multi_frames() { + let (client_io, server_io) = duplex(65536); + + // Server: respond with two DATA frames "a" and "bc" + tokio::spawn(async move { + let mut conn = h2::server::handshake(server_io).await.unwrap(); + if let Some(result) = conn.accept().await { + let (req, mut send_resp) = result.unwrap(); + assert_eq!(req.method(), http::Method::GET); + let resp = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut send_stream = send_resp.send_response(resp, false).unwrap(); + send_stream.send_data(Bytes::from("a"), false).unwrap(); + 
send_stream.send_data(Bytes::from("bc"), true).unwrap(); + // Signal graceful shutdown so the accept loop can exit after the client finishes + conn.graceful_shutdown(); + } + // Drive the server connection until the client closes + while let Some(_res) = conn.accept().await {} + }); + + // Client: build Http2Session and read response + let (send_req, connection) = h2::client::handshake(client_io).await.unwrap(); + let (closed_tx, closed_rx) = tokio::sync::watch::channel(false); + let ping_timeout = Arc::new(AtomicBool::new(false)); + tokio::spawn(async move { + let _ = connection.await; + let _ = closed_tx.send(true); + }); + + let digest = Digest::default(); + let conn_ref = crate::connectors::http::v2::ConnectionRef::new( + send_req.clone(), + closed_rx, + ping_timeout, + 0, + 1, + digest, + ); + let mut h2s = Http2Session::new(send_req, conn_ref); + + // minimal request + let mut req = RequestHeader::build("GET", b"/", None).unwrap(); + req.insert_header(http::header::HOST, "example.com") + .unwrap(); + h2s.write_request_header(Box::new(req), true).unwrap(); + h2s.read_response_header().await.unwrap(); + + let mut total = 0; + while let Some(chunk) = h2s.read_response_body().await.unwrap() { + total += chunk.len(); + } + assert_eq!(total, 3); + assert_eq!(h2s.body_bytes_received(), 3); + } +} diff --git a/pingora-core/src/protocols/http/v2/mod.rs b/pingora-core/src/protocols/http/v2/mod.rs index a588f4bd..01711807 100644 --- a/pingora-core/src/protocols/http/v2/mod.rs +++ b/pingora-core/src/protocols/http/v2/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/protocols/http/v2/server.rs b/pingora-core/src/protocols/http/v2/server.rs index 883fa22f..363b7357 100644 --- a/pingora-core/src/protocols/http/v2/server.rs +++ b/pingora-core/src/protocols/http/v2/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ use log::{debug, warn}; use pingora_http::{RequestHeader, ResponseHeader}; use pingora_timeout::timeout; use std::sync::Arc; +use std::task::ready; use std::time::Duration; use crate::protocols::http::body_buffer::FixedBuffer; @@ -48,6 +49,7 @@ pub use h2::server::Builder as H2Options; pub async fn handshake(io: Stream, options: Option) -> Result> { let options = options.unwrap_or_default(); let res = options.handshake(io).await; + match res { Ok(connection) => { debug!("H2 handshake done."); @@ -188,6 +190,27 @@ impl HttpSession { Ok(data) } + #[doc(hidden)] + pub fn poll_read_body_bytes( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + let data = match ready!(self.request_body_reader.poll_data(cx)).transpose() { + Ok(data) => data, + Err(err) => return Poll::Ready(Some(Err(err))), + }; + + if let Some(data) = data { + self.body_read += data.len(); + self.request_body_reader + .flow_control() + .release_capacity(data.len())?; + return Poll::Ready(Some(Ok(data))); + } + + Poll::Ready(None) + } + async fn do_drain_request_body(&mut self) -> Result<()> { loop { match self.read_body_bytes().await { @@ -397,6 +420,18 @@ impl HttpSession { } None => end, }, + HttpTask::UpgradedBody(..) => { + // Seeing an Upgraded body means that the upstream session + // was H1.1 that upgraded. 
+ // + // While the downstream H2 session may encapsulate the opaque body bytes, + // this represents an undefined discrepancy and change between how + // the upstream and downstream sessions began intepreting the response body. + return Error::e_explain( + ErrorType::InternalError, + "upgraded body on h2 server session", + ); + } HttpTask::Trailer(Some(trailers)) => { self.write_trailers(*trailers)?; true @@ -449,6 +484,11 @@ impl HttpSession { } } + #[doc(hidden)] + pub fn take_response_body_writer(&mut self) -> Option> { + self.send_response_body.take() + } + // This is a hack for pingora-proxy to create subrequests from h2 server session // TODO: be able to convert from h2 to h1 subrequest pub fn pseudo_raw_h1_request_header(&self) -> Bytes { @@ -500,7 +540,7 @@ impl HttpSession { /// This async fn will be pending forever until the client closes the stream/connection /// This function is used for watching client status so that the server is able to cancel /// its internal tasks as the client waiting for the tasks goes away - pub fn idle(&mut self) -> Idle { + pub fn idle(&mut self) -> Idle<'_> { Idle(self) } diff --git a/pingora-core/src/protocols/l4/ext.rs b/pingora-core/src/protocols/l4/ext.rs index 9f7d8ec4..9a632e96 100644 --- a/pingora-core/src/protocols/l4/ext.rs +++ b/pingora-core/src/protocols/l4/ext.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -154,10 +154,7 @@ fn get_opt_sized(sock: c_int, opt: c_int, val: c_int) -> io::Result { get_opt(sock, opt, val, &mut payload, &mut size)?; if size != expected_size { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "get_opt size mismatch", - )); + return Err(std::io::Error::other("get_opt size mismatch")); } // Assume getsockopt() will set the value properly let payload = unsafe { payload.assume_init() }; diff --git a/pingora-core/src/protocols/l4/listener.rs b/pingora-core/src/protocols/l4/listener.rs index 88f5fe85..7d00005e 100644 --- a/pingora-core/src/protocols/l4/listener.rs +++ b/pingora-core/src/protocols/l4/listener.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/l4/mod.rs b/pingora-core/src/protocols/l4/mod.rs index bda24121..7e24cd88 100644 --- a/pingora-core/src/protocols/l4/mod.rs +++ b/pingora-core/src/protocols/l4/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,3 +18,4 @@ pub mod ext; pub mod listener; pub mod socket; pub mod stream; +pub mod virt; diff --git a/pingora-core/src/protocols/l4/socket.rs b/pingora-core/src/protocols/l4/socket.rs index 3a764920..46decd2f 100644 --- a/pingora-core/src/protocols/l4/socket.rs +++ b/pingora-core/src/protocols/l4/socket.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -238,8 +238,7 @@ impl std::net::ToSocketAddrs for SocketAddr { if let Some(inet) = self.as_inet() { Ok(std::iter::once(*inet)) } else { - Err(std::io::Error::new( - std::io::ErrorKind::Other, + Err(std::io::Error::other( "UDS socket cannot be used as inet socket", )) } diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index fd50d77f..4aa70f70 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ use tokio::net::TcpStream; use tokio::net::UnixStream; use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive}; +use crate::protocols::l4::virt; use crate::protocols::raw_connect::ProxyDigest; use crate::protocols::{ GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl, @@ -49,6 +50,7 @@ enum RawStream { Tcp(TcpStream), #[cfg(unix)] Unix(UnixStream), + Virtual(virt::VirtualSocketStream), } impl AsyncRead for RawStream { @@ -63,6 +65,7 @@ impl AsyncRead for RawStream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf), } } } @@ -76,6 +79,7 @@ impl AsyncWrite for RawStream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf), } } } @@ -87,6 +91,7 @@ impl AsyncWrite for RawStream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx), } } } @@ -98,6 
+103,7 @@ impl AsyncWrite for RawStream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx), } } } @@ -113,6 +119,7 @@ impl AsyncWrite for RawStream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), } } } @@ -122,6 +129,7 @@ impl AsyncWrite for RawStream { RawStream::Tcp(s) => s.is_write_vectored(), #[cfg(unix)] RawStream::Unix(s) => s.is_write_vectored(), + RawStream::Virtual(s) => s.is_write_vectored(), } } } @@ -132,6 +140,7 @@ impl AsRawFd for RawStream { match self { RawStream::Tcp(s) => s.as_raw_fd(), RawStream::Unix(s) => s.as_raw_fd(), + RawStream::Virtual(_) => -1, // Virtual stream does not have a real fd } } } @@ -141,6 +150,8 @@ impl AsRawSocket for RawStream { fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { match self { RawStream::Tcp(s) => s.as_raw_socket(), + // Virtual stream does not have a real socket, return INVALID_SOCKET (!0) + RawStream::Virtual(_) => !0, } } } @@ -192,6 +203,7 @@ impl AsyncRead for RawStreamWrapper { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf), } } } @@ -213,6 +225,7 @@ impl AsyncRead for RawStreamWrapper { match &mut rs_wrapper.stream { RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf), RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf), + RawStream::Virtual(s) => return Pin::new_unchecked(s).poll_read(cx, buf), } } } @@ -264,6 +277,7 @@ impl AsyncRead for RawStreamWrapper { } // Unix RX timestamp only works with datagram for now, so we do not care about it 
RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) }, + RawStream::Virtual(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) }, } } } @@ -276,6 +290,7 @@ impl AsyncWrite for RawStreamWrapper { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf), } } } @@ -287,6 +302,7 @@ impl AsyncWrite for RawStreamWrapper { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx), } } } @@ -298,6 +314,7 @@ impl AsyncWrite for RawStreamWrapper { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx), } } } @@ -313,6 +330,7 @@ impl AsyncWrite for RawStreamWrapper { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), + RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), } } } @@ -379,18 +397,32 @@ impl Stream { /// set TCP nodelay for this connection if `self` is TCP pub fn set_nodelay(&mut self) -> Result<()> { - if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream { - s.set_nodelay(true) - .or_err(ConnectError, "failed to set_nodelay")?; + match &self.stream_mut().get_mut().stream { + RawStream::Tcp(s) => { + s.set_nodelay(true) + .or_err(ConnectError, "failed to set_nodelay")?; + } + RawStream::Virtual(s) => { + s.set_socket_option(virt::VirtualSockOpt::NoDelay) + .or_err(ConnectError, "failed to set_nodelay on virtual socket")?; + } + _ => (), } Ok(()) } /// set TCP keepalive settings for this connection if `self` is TCP pub fn set_keepalive(&mut self, ka: 
&TcpKeepalive) -> Result<()> { - if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream { - debug!("Setting tcp keepalive"); - set_tcp_keepalive(s, ka)?; + match &self.stream_mut().get_mut().stream { + RawStream::Tcp(s) => { + debug!("Setting tcp keepalive"); + set_tcp_keepalive(s, ka)?; + } + RawStream::Virtual(s) => { + s.set_socket_option(virt::VirtualSockOpt::KeepAlive(ka.clone())) + .or_err(ConnectError, "failed to set_keepalive on virtual socket")?; + } + _ => (), } Ok(()) } @@ -456,6 +488,27 @@ impl From for Stream { } } +impl From for Stream { + fn from(s: virt::VirtualSocketStream) -> Self { + Stream { + stream: Some(BufStream::with_capacity( + 0, + 0, + RawStreamWrapper::new(RawStream::Virtual(s)), + )), + rewind_read_buf: Vec::new(), + buffer_write: true, + established_ts: SystemTime::now(), + proxy_digest: None, + socket_digest: None, + tracer: None, + read_pending_time: AccumulatedDuration::new(), + write_pending_time: AccumulatedDuration::new(), + rx_ts: None, + } + } +} + #[cfg(unix)] impl From for Stream { fn from(s: UnixStream) -> Self { @@ -576,6 +629,10 @@ impl Drop for Stream { RawStream::Tcp(s) => s.nodelay().err(), #[cfg(unix)] RawStream::Unix(s) => s.local_addr().err(), + RawStream::Virtual(_) => { + // TODO: should this do something? + None + } }; if let Some(e) = ret { match e.kind() { diff --git a/pingora-core/src/protocols/l4/virt.rs b/pingora-core/src/protocols/l4/virt.rs new file mode 100644 index 00000000..5148e417 --- /dev/null +++ b/pingora-core/src/protocols/l4/virt.rs @@ -0,0 +1,161 @@ +//! Provides [`VirtualSocketStream`]. + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::ext::TcpKeepalive; + +/// A limited set of socket options that can be set on a [`VirtualSocket`]. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub enum VirtualSockOpt { + NoDelay, + KeepAlive(TcpKeepalive), +} + +/// A "virtual" socket that supports async read and write operations. 
+pub trait VirtualSocket: AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug { + /// Set a socket option. + fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()>; +} + +/// Wrapper around any type implementing [`VirtualSocket`]. +#[derive(Debug)] +pub struct VirtualSocketStream { + pub(crate) socket: Box, +} + +impl VirtualSocketStream { + pub fn new(socket: Box) -> Self { + Self { socket } + } + + #[inline] + pub fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()> { + self.socket.set_socket_option(opt) + } +} + +impl AsyncRead for VirtualSocketStream { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut *self.get_mut().socket).poll_read(cx, buf) + } +} + +impl AsyncWrite for VirtualSocketStream { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut *self.get_mut().socket).poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut *self.get_mut().socket).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut *self.get_mut().socket).poll_shutdown(cx) + } +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use tokio::io::{AsyncReadExt, AsyncWriteExt as _}; + + use crate::protocols::l4::stream::Stream; + + use super::*; + + #[derive(Debug)] + struct StaticVirtualSocket { + content: Vec, + read_pos: usize, + write_buf: Arc>>, + } + + impl AsyncRead for StaticVirtualSocket { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + debug_assert!(self.read_pos <= self.content.len()); + + let remaining = self.content.len() - self.read_pos; + if remaining == 0 { + return Poll::Ready(Ok(())); + } + + let to_read = std::cmp::min(remaining, buf.remaining()); + 
buf.put_slice(&self.content[self.read_pos..self.read_pos + to_read]); + self.read_pos += to_read; + + Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for StaticVirtualSocket { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // write to internal buffer + let this = self.get_mut(); + this.write_buf.lock().unwrap().extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl VirtualSocket for StaticVirtualSocket { + fn set_socket_option(&self, _opt: VirtualSockOpt) -> std::io::Result<()> { + Ok(()) + } + } + + /// Basic test that ensures reading and writing works with a virtual socket. + // + /// Mostly just ensures that construction works and the plumbing is correct. + #[tokio::test] + async fn test_stream_virtual() { + let content = b"hello virtual world"; + let write_buf = Arc::new(Mutex::new(Vec::new())); + let mut stream = Stream::from(VirtualSocketStream::new(Box::new(StaticVirtualSocket { + content: content.to_vec(), + read_pos: 0, + write_buf: write_buf.clone(), + }))); + + let mut buf = Vec::new(); + let out = stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(out, content.len()); + assert_eq!(buf, content); + + stream.write_all(content).await.unwrap(); + assert_eq!(write_buf.lock().unwrap().as_slice(), content); + } +} diff --git a/pingora-core/src/protocols/mod.rs b/pingora-core/src/protocols/mod.rs index 1bad6b28..904ed09b 100644 --- a/pingora-core/src/protocols/mod.rs +++ b/pingora-core/src/protocols/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -232,6 +232,46 @@ mod ext_io_impl { impl Peek for DuplexStream {} } +#[cfg(unix)] +pub mod ext_test { + use std::sync::Arc; + + use async_trait::async_trait; + + use super::{ + raw_connect, GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, + SocketDigest, Ssl, TimingDigest, UniqueID, UniqueIDType, + }; + + #[async_trait] + impl Shutdown for tokio::net::UnixStream { + async fn shutdown(&mut self) -> () {} + } + impl UniqueID for tokio::net::UnixStream { + fn id(&self) -> UniqueIDType { + 0 + } + } + impl Ssl for tokio::net::UnixStream {} + impl GetTimingDigest for tokio::net::UnixStream { + fn get_timing_digest(&self) -> Vec> { + vec![] + } + } + impl GetProxyDigest for tokio::net::UnixStream { + fn get_proxy_digest(&self) -> Option> { + None + } + } + impl GetSocketDigest for tokio::net::UnixStream { + fn get_socket_digest(&self) -> Option> { + None + } + } + + impl Peek for tokio::net::UnixStream {} +} + #[cfg(unix)] pub(crate) trait ConnFdReusable { fn check_fd_match(&self, fd: V) -> bool; diff --git a/pingora-core/src/protocols/proxy_protocol.rs b/pingora-core/src/protocols/proxy_protocol.rs index 36ec1e91..a6700993 100644 --- a/pingora-core/src/protocols/proxy_protocol.rs +++ b/pingora-core/src/protocols/proxy_protocol.rs @@ -60,14 +60,9 @@ pub fn source_addr_from_header(header: &ProxyHeader) -> Option { _ => None, }, ProxyHeader::Version2 { - command, - addresses, - .. + command, addresses, .. } => { - if matches!( - command, - proxy_protocol::version2::ProxyCommand::Local - ) { + if matches!(command, proxy_protocol::version2::ProxyCommand::Local) { return None; } match addresses { @@ -92,11 +87,10 @@ pub fn header_has_source_addr(header: &ProxyHeader) -> bool { proxy_protocol::version1::ProxyAddresses::Ipv4 { .. } | proxy_protocol::version1::ProxyAddresses::Ipv6 { .. } ), - ProxyHeader::Version2 { command, addresses, .. 
} => { - if matches!( - command, - proxy_protocol::version2::ProxyCommand::Local - ) { + ProxyHeader::Version2 { + command, addresses, .. + } => { + if matches!(command, proxy_protocol::version2::ProxyCommand::Local) { return false; } matches!( @@ -233,10 +227,15 @@ pub async fn consume_proxy_header(stream: &mut Stream) -> Result { + ProxyDetection::NeedsMore + | ProxyDetection::Invalid + | ProxyDetection::HeaderLength(_) => { // Buffer looks like it could be a PROXY header but connection closed debug!("Stream closed while reading PROXY header (buffer looks like PROXY header)"); - return Error::e_explain(PROXY_PROTOCOL_ERROR, "Incomplete PROXY protocol header"); + return Error::e_explain( + PROXY_PROTOCOL_ERROR, + "Incomplete PROXY protocol header", + ); } } } diff --git a/pingora-core/src/protocols/raw_connect.rs b/pingora-core/src/protocols/raw_connect.rs index 94b3130e..80158edc 100644 --- a/pingora-core/src/protocols/raw_connect.rs +++ b/pingora-core/src/protocols/raw_connect.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ //! so that the protocol encapsulated can use the stream directly. //! This idea only works for CONNECT over HTTP 1.1 and localhost (or where the server is close by). +use std::any::Any; + use super::http::v1::client::HttpSession; use super::http::v1::common::*; use super::Stream; @@ -35,7 +37,14 @@ use tokio::io::AsyncWriteExt; /// `request_header` should include the necessary request headers for the CONNECT protocol. /// /// When successful, a [`Stream`] will be returned which is the established CONNECT proxy connection. -pub async fn connect(stream: Stream, request_header: &ReqHeader) -> Result<(Stream, ProxyDigest)> { +pub async fn connect

( + stream: Stream, + request_header: &ReqHeader, + peer: &P, +) -> Result<(Stream, ProxyDigest)> +where + P: crate::upstreams::peer::Peer, +{ let mut http = HttpSession::new(stream); // We write to stream directly because HttpSession doesn't write req header in auth form @@ -53,7 +62,7 @@ pub async fn connect(stream: Stream, request_header: &ReqHeader) -> Result<(Stre let resp_header = http.read_resp_header_parts().await?; Ok(( http.underlying_stream, - validate_connect_response(resp_header)?, + validate_connect_response(resp_header, peer, request_header)?, )) } @@ -104,11 +113,19 @@ where pub struct ProxyDigest { /// The response header the proxy returns pub response: Box, + /// Optional arbitrary data. + pub user_data: Option>, } impl ProxyDigest { - pub fn new(response: Box) -> Self { - ProxyDigest { response } + pub fn new( + response: Box, + user_data: Option>, + ) -> Self { + ProxyDigest { + response, + user_data, + } } } @@ -182,7 +199,14 @@ fn http_req_header_to_wire_auth_form(req: &ReqHeader) -> BytesMut { } #[inline] -fn validate_connect_response(resp: Box) -> Result { +fn validate_connect_response

( + resp: Box, + peer: &P, + req: &ReqHeader, +) -> Result +where + P: crate::upstreams::peer::Peer, +{ if !resp.status.is_success() { return Error::e_because( ConnectProxyFailure, @@ -201,7 +225,11 @@ fn validate_connect_response(resp: Box) -> Result { ConnectProxyError::boxed_new(resp), ); } - Ok(ProxyDigest::new(resp)) + + let user_data = peer + .proxy_digest_user_data_hook() + .and_then(|hook| hook(req, &resp)); + Ok(ProxyDigest::new(resp, user_data)) } #[cfg(test)] @@ -252,37 +280,80 @@ mod test_sync { #[test] fn test_validate_connect_response() { + use crate::upstreams::peer::BasicPeer; + + struct DummyUserData { + some_num: i32, + some_string: String, + } + + let peer_no_data = BasicPeer::new("127.0.0.1:80"); + let mut peer_with_data = peer_no_data.clone(); + peer_with_data.options.proxy_digest_user_data_hook = Some(std::sync::Arc::new( + |_req: &http::request::Parts, _resp: &pingora_http::ResponseHeader| { + Some(Box::new(DummyUserData { + some_num: 42, + some_string: "test".to_string(), + }) as Box) + }, + )); + + let request = http::Request::builder() + .method("CONNECT") + .uri("https://example.com:443/") + .body(()) + .unwrap(); + let (req_header, _) = request.into_parts(); + let resp = ResponseHeader::build(200, None).unwrap(); - validate_connect_response(Box::new(resp)).unwrap(); + let proxy_digest = + validate_connect_response(Box::new(resp), &peer_with_data, &req_header).unwrap(); + assert!(proxy_digest.user_data.is_some()); + let user_data = proxy_digest + .user_data + .as_ref() + .unwrap() + .downcast_ref::() + .unwrap(); + assert_eq!(user_data.some_num, 42); + assert_eq!(user_data.some_string, "test"); + + let resp = ResponseHeader::build(200, None).unwrap(); + let proxy_digest = + validate_connect_response(Box::new(resp), &peer_no_data, &req_header).unwrap(); + assert!(proxy_digest.user_data.is_none()); let resp = ResponseHeader::build(404, None).unwrap(); - assert!(validate_connect_response(Box::new(resp)).is_err()); + 
assert!(validate_connect_response(Box::new(resp), &peer_with_data, &req_header).is_err()); let mut resp = ResponseHeader::build(200, None).unwrap(); resp.append_header("content-length", 0).unwrap(); - assert!(validate_connect_response(Box::new(resp)).is_ok()); + assert!(validate_connect_response(Box::new(resp), &peer_no_data, &req_header).is_ok()); let mut resp = ResponseHeader::build(200, None).unwrap(); resp.append_header("transfer-encoding", 0).unwrap(); - assert!(validate_connect_response(Box::new(resp)).is_err()); + assert!(validate_connect_response(Box::new(resp), &peer_no_data, &req_header).is_err()); } #[tokio::test] async fn test_connect_write_request() { + use crate::upstreams::peer::BasicPeer; + let wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n"; let mock_io = Box::new(Builder::new().write(wire).build()); let headers: BTreeMap> = BTreeMap::new(); let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap(); + let peer = BasicPeer::new("127.0.0.1:123"); // ConnectionClosed - assert!(connect(mock_io, &req).await.is_err()); + assert!(connect(mock_io, &req, &peer).await.is_err()); let to_wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n"; let from_wire = b"HTTP/1.1 200 OK\r\n\r\n"; let mock_io = Box::new(Builder::new().write(to_wire).read(from_wire).build()); let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap(); - let result = connect(mock_io, &req).await; + let result = connect(mock_io, &req, &peer).await; assert!(result.is_ok()); } } diff --git a/pingora-core/src/protocols/tls/boringssl_openssl/client.rs b/pingora-core/src/protocols/tls/boringssl_openssl/client.rs index 161040e9..4e5bded4 100644 --- a/pingora-core/src/protocols/tls/boringssl_openssl/client.rs +++ b/pingora-core/src/protocols/tls/boringssl_openssl/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,9 +19,10 @@ use crate::protocols::tls::SslStream; use crate::protocols::{ GetProxyDigest, GetSocketDigest, GetTimingDigest, SocketDigest, TimingDigest, IO, }; -use crate::tls::{ssl, ssl::ConnectConfiguration, ssl_sys::X509_V_ERR_INVALID_CALL}; +use crate::tls::{ssl, ssl::ConnectConfiguration, ssl::SslRef, ssl_sys::X509_V_ERR_INVALID_CALL}; use pingora_error::{Error, ErrorType::*, OrErr, Result}; +use std::any::Any; use std::sync::Arc; use std::time::Duration; @@ -30,6 +31,7 @@ pub async fn handshake( conn_config: ConnectConfiguration, domain: &str, io: S, + complete_hook: Option Option> + Send + Sync>>, ) -> Result> { let ssl = conn_config .into_ssl(domain) @@ -38,7 +40,16 @@ pub async fn handshake( .explain_err(TLSHandshakeFailure, |e| format!("ssl stream error: {e}"))?; let handshake_result = stream.connect().await; match handshake_result { - Ok(()) => Ok(stream), + Ok(()) => { + if let Some(hook) = complete_hook { + if let Some(extension) = hook(stream.ssl()) { + if let Some(digest_mut) = stream.ssl_digest_mut() { + digest_mut.extension.set(extension); + } + } + } + Ok(stream) + } Err(e) => { let context = format!("TLS connect() failed: {e}, SNI: {domain}"); match e.code() { diff --git a/pingora-core/src/protocols/tls/boringssl_openssl/mod.rs b/pingora-core/src/protocols/tls/boringssl_openssl/mod.rs index cb6876c3..7d2c1e2b 100644 --- a/pingora-core/src/protocols/tls/boringssl_openssl/mod.rs +++ b/pingora-core/src/protocols/tls/boringssl_openssl/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/protocols/tls/boringssl_openssl/server.rs b/pingora-core/src/protocols/tls/boringssl_openssl/server.rs index 5795f775..bd14ea70 100644 --- a/pingora-core/src/protocols/tls/boringssl_openssl/server.rs +++ b/pingora-core/src/protocols/tls/boringssl_openssl/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -64,10 +64,16 @@ pub async fn handshake_with_callback( .resume_accept() .await .explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?; - Ok(tls_stream) - } else { - Ok(tls_stream) } + { + let ssl = tls_stream.ssl(); + if let Some(extension) = callbacks.handshake_complete_callback(ssl).await { + if let Some(digest_mut) = tls_stream.ssl_digest_mut() { + digest_mut.extension.set(extension); + } + } + } + Ok(tls_stream) } #[async_trait] @@ -130,43 +136,23 @@ impl ResumableAccept for SslStream } } -#[tokio::test] -#[cfg(feature = "any_tls")] -async fn test_async_cert() { - use crate::protocols::tls::TlsRef; - use tokio::io::AsyncReadExt; +#[cfg(test)] +mod tests { + use super::handshake_with_callback; use crate::listeners::{TlsAccept, TlsAcceptCallbacks}; - let acceptor = ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls()) - .unwrap() - .build(); - - struct Callback; - #[async_trait] - impl TlsAccept for Callback { - async fn certificate_callback(&self, ssl: &mut TlsRef) -> () { - assert_eq!( - ssl.servername(ssl::NameType::HOST_NAME).unwrap(), - "pingora.org" - ); - let cert = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); - let key = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR")); - - let cert_bytes = std::fs::read(cert).unwrap(); - let cert = crate::tls::x509::X509::from_pem(&cert_bytes).unwrap(); - - let key_bytes = std::fs::read(key).unwrap(); - let key = 
crate::tls::pkey::PKey::private_key_from_pem(&key_bytes).unwrap(); - ext::ssl_use_certificate(ssl, &cert).unwrap(); - ext::ssl_use_private_key(ssl, &key).unwrap(); - } - } - - let cb: TlsAcceptCallbacks = Box::new(Callback); + use crate::protocols::tls::SslStream; + use crate::protocols::tls::TlsRef; + use crate::tls::ext; + use crate::tls::ssl; - let (client, server) = tokio::io::duplex(1024); + use async_trait::async_trait; + use std::pin::Pin; + use std::sync::Arc; + use tokio::io::DuplexStream; - tokio::spawn(async move { + async fn client_task(client: DuplexStream) { + use tokio::io::AsyncReadExt; let ssl_context = ssl::SslContext::builder(ssl::SslMethod::tls()) .unwrap() .build(); @@ -177,9 +163,87 @@ async fn test_async_cert() { Pin::new(&mut stream).connect().await.unwrap(); let mut buf = [0; 1]; let _ = stream.read(&mut buf).await; - }); + } - handshake_with_callback(&acceptor, server, &cb) - .await - .unwrap(); + #[tokio::test] + #[cfg(feature = "any_tls")] + async fn test_async_cert() { + let acceptor = ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls()) + .unwrap() + .build(); + + struct Callback; + #[async_trait] + impl TlsAccept for Callback { + async fn certificate_callback(&self, ssl: &mut TlsRef) -> () { + assert_eq!( + ssl.servername(ssl::NameType::HOST_NAME).unwrap(), + "pingora.org" + ); + let cert = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); + let key = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR")); + + let cert_bytes = std::fs::read(cert).unwrap(); + let cert = crate::tls::x509::X509::from_pem(&cert_bytes).unwrap(); + + let key_bytes = std::fs::read(key).unwrap(); + let key = crate::tls::pkey::PKey::private_key_from_pem(&key_bytes).unwrap(); + ext::ssl_use_certificate(ssl, &cert).unwrap(); + ext::ssl_use_private_key(ssl, &key).unwrap(); + } + } + + let cb: TlsAcceptCallbacks = Box::new(Callback); + + let (client, server) = tokio::io::duplex(1024); + + tokio::spawn(client_task(client)); + + 
handshake_with_callback(&acceptor, server, &cb) + .await + .unwrap(); + } + + #[tokio::test] + #[cfg(feature = "openssl_derived")] + async fn test_handshake_complete_callback() { + use crate::tls::ssl::SslFiletype; + + let cert = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); + let key = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR")); + + let acceptor = { + let mut builder = + ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls()).unwrap(); + builder.set_certificate_chain_file(cert).unwrap(); + builder.set_private_key_file(key, SslFiletype::PEM).unwrap(); + builder.build() + }; + + struct Sni(String); + struct Callback; + #[async_trait] + impl TlsAccept for Callback { + async fn handshake_complete_callback( + &self, + ssl: &TlsRef, + ) -> Option> { + let sni = ssl.servername(ssl::NameType::HOST_NAME)?.to_string(); + Some(Arc::new(Sni(sni))) + } + } + + let cb: TlsAcceptCallbacks = Box::new(Callback); + + let (client, server) = tokio::io::duplex(1024); + + tokio::spawn(client_task(client)); + + let stream = handshake_with_callback(&acceptor, server, &cb) + .await + .unwrap(); + let ssl_digest = stream.ssl_digest().unwrap(); + let sni = ssl_digest.extension.get::().unwrap(); + assert_eq!(sni.0, "pingora.org"); + } } diff --git a/pingora-core/src/protocols/tls/boringssl_openssl/stream.rs b/pingora-core/src/protocols/tls/boringssl_openssl/stream.rs index 25dab254..894244c0 100644 --- a/pingora-core/src/protocols/tls/boringssl_openssl/stream.rs +++ b/pingora-core/src/protocols/tls/boringssl_openssl/stream.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -91,6 +91,12 @@ impl SslStream { pub fn ssl_digest(&self) -> Option> { self.digest.clone() } + + /// Attempts to obtain a mutable reference to the SslDigest. 
+ /// This method returns `None` if the SslDigest is currently held by other references. + pub(crate) fn ssl_digest_mut(&mut self) -> Option<&mut SslDigest> { + Arc::get_mut(self.digest.as_mut()?) + } } use std::ops::{Deref, DerefMut}; diff --git a/pingora-core/src/protocols/tls/client_hello.rs b/pingora-core/src/protocols/tls/client_hello.rs index 157a1438..14c8f7d6 100644 --- a/pingora-core/src/protocols/tls/client_hello.rs +++ b/pingora-core/src/protocols/tls/client_hello.rs @@ -560,7 +560,9 @@ pub fn peek_client_hello(stream: &S) -> io::Result { // Invalid PROXY header - try parsing as ClientHello - debug!("Invalid PROXY protocol header detected, trying ClientHello parse"); + debug!( + "Invalid PROXY protocol header detected, trying ClientHello parse" + ); 0 } } @@ -570,7 +572,11 @@ pub fn peek_client_hello(stream: &S) -> io::Result= data.len() { - debug!("PROXY header offset {} exceeds data length {}, no ClientHello data", proxy_offset, data.len()); + debug!( + "PROXY header offset {} exceeds data length {}, no ClientHello data", + proxy_offset, + data.len() + ); Ok(None) } else { Ok(ClientHello::parse(&data[proxy_offset..])) @@ -621,11 +627,9 @@ mod tests { 0x00, 0x00, 0x41, // Handshake Length (65 bytes) 0x03, 0x03, // Client Version: TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x04, // Cipher Suites Length (4 bytes = 2 cipher suites) 0x00, 0x2f, 0x00, 0x35, // Cipher suites 0x01, // Compression Methods Length @@ -658,11 +662,9 @@ mod tests { 0x00, 0x00, 0x3d, // Handshake Length (61 bytes) 0x03, 0x03, // Client Version: 
TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x02, // Cipher Suites Length (2 bytes = 1 cipher suite) 0x00, 0x2f, // Cipher suite 0x01, // Compression Methods Length @@ -673,7 +675,8 @@ mod tests { 0x00, 0x0e, // Extension Length (14 bytes) 0x00, 0x0c, // ALPN Extension Length (12 bytes) 0x02, 0x68, 0x32, // Length prefix (2) + "h2" - 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, // Length prefix (8) + "http/1.1" + 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, + 0x31, // Length prefix (8) + "http/1.1" ]; let hello = ClientHello::parse(&data).expect("Failed to parse ClientHello"); @@ -690,8 +693,7 @@ mod tests { fn test_parse_non_handshake() { let data = vec![ 0x17, // Content Type: Application Data (not handshake) - 0x03, 0x03, - 0x00, 0x10, + 0x03, 0x03, 0x00, 0x10, ]; assert!(ClientHello::parse(&data).is_none()); } @@ -716,11 +718,9 @@ mod tests { 0x00, 0x00, 0x3b, // Handshake Length (59 bytes = body with padding) 0x03, 0x03, // Client Version: TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x02, // Cipher Suites Length (2 bytes) 0x00, 0x2f, // Cipher suite 0x01, // Compression 
Methods Length @@ -764,11 +764,9 @@ mod tests { 0x00, 0x00, 0x3b, // Handshake Length (59 bytes = body with padding) 0x03, 0x03, // Client Version: TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x02, // Cipher Suites Length (2 bytes) 0x00, 0x2f, // Cipher suite 0x01, // Compression Methods Length @@ -811,11 +809,9 @@ mod tests { 0x00, 0x00, 0x38, // Handshake Length (56 bytes = body with padding) 0x03, 0x03, // Client Version: TLS 1.2 (legacy) // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x02, // Cipher Suites Length (2 bytes) 0x13, 0x01, // TLS 1.3 cipher suite 0x01, // Compression Methods Length @@ -860,11 +856,9 @@ mod tests { 0x00, 0x00, 0x48, // Handshake Length (72 bytes = body with padding) 0x03, 0x03, // Client Version: TLS 1.2 (legacy) // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 
0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x02, // Cipher Suites Length (2 bytes) 0x13, 0x01, // TLS 1.3 cipher suite 0x01, // Compression Methods Length @@ -877,9 +871,8 @@ mod tests { 0x00, 0x1d, // Group: x25519 0x00, 0x10, // Key Exchange Length (16 bytes) // Key exchange data (16 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x00, 0x00, // Padding (2 bytes) + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x00, 0x00, // Padding (2 bytes) 0x00, 0x00, // Additional padding (2 bytes) ]; @@ -909,11 +902,9 @@ mod tests { 0x00, 0x00, 0x39, // Handshake Length (57 bytes = body with padding) 0x03, 0x03, // Client Version: TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x02, // Cipher Suites Length (2 bytes) 0x00, 0x2f, // Cipher suite 0x01, // Compression Methods Length @@ -947,11 +938,9 @@ mod tests { 0x00, 0x00, 0x41, // Handshake Length (65 bytes) 0x03, 0x03, // Client Version: TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID 
Length 0x00, 0x04, // Cipher Suites Length (4 bytes = 2 cipher suites) 0x00, 0x2f, 0x00, 0x35, // Cipher suites 0x01, // Compression Methods Length @@ -975,7 +964,8 @@ mod tests { crate::protocols::proxy_protocol::set_proxy_protocol_enabled(true); // Parse should skip PROXY header and parse ClientHello - let hello = ClientHello::parse(&data[proxy_header.len()..]).expect("Failed to parse ClientHello"); + let hello = + ClientHello::parse(&data[proxy_header.len()..]).expect("Failed to parse ClientHello"); assert_eq!(hello.sni, Some("example.com".to_string())); assert_eq!(hello.tls_version, Some(0x0301)); @@ -994,11 +984,9 @@ mod tests { 0x00, 0x00, 0x41, // Handshake Length (65 bytes) 0x03, 0x03, // Client Version: TLS 1.2 // Random (32 bytes) - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, - 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - 0x00, // Session ID Length + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x00, // Session ID Length 0x00, 0x04, // Cipher Suites Length (4 bytes = 2 cipher suites) 0x00, 0x2f, 0x00, 0x35, // Cipher suites 0x01, // Compression Methods Length @@ -1023,4 +1011,3 @@ mod tests { assert_eq!(hello.tls_version, Some(0x0301)); } } - diff --git a/pingora-core/src/protocols/tls/digest.rs b/pingora-core/src/protocols/tls/digest.rs index a6b95e62..58ecf3b6 100644 --- a/pingora-core/src/protocols/tls/digest.rs +++ b/pingora-core/src/protocols/tls/digest.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +14,9 @@ //! 
TLS information from the TLS connection +use std::any::Any; use std::borrow::Cow; +use std::sync::Arc; /// The TLS connection information #[derive(Clone, Debug)] @@ -29,6 +31,8 @@ pub struct SslDigest { pub serial_number: Option, /// The digest of the peer's certificate pub cert_digest: Vec, + /// The user-defined TLS data + pub extension: SslDigestExtension, } impl SslDigest { @@ -49,6 +53,30 @@ impl SslDigest { organization, serial_number, cert_digest, + extension: SslDigestExtension::default(), } } } + +/// The user-defined TLS data +#[derive(Clone, Debug, Default)] +pub struct SslDigestExtension { + value: Option>, +} + +impl SslDigestExtension { + /// Retrieves a reference to the user-defined TLS data if it matches the specified type. + /// + /// Returns `None` if no data has been set or if the data is not of type `T`. + pub fn get(&self) -> Option<&T> + where + T: Send + Sync + 'static, + { + self.value.as_ref().and_then(|v| v.downcast_ref::()) + } + + #[allow(dead_code)] + pub(crate) fn set(&mut self, value: Arc) { + self.value = Some(value); + } +} diff --git a/pingora-core/src/protocols/tls/mod.rs b/pingora-core/src/protocols/tls/mod.rs index ba697412..51a35137 100644 --- a/pingora-core/src/protocols/tls/mod.rs +++ b/pingora-core/src/protocols/tls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -42,8 +42,14 @@ pub mod noop_tls; #[cfg(not(feature = "any_tls"))] pub use noop_tls::*; +/// Containing type for a user callback to generate extensions for the `SslDigest` upon handshake +/// completion. 
+pub type HandshakeCompleteHook = std::sync::Arc< + dyn Fn(&TlsRef) -> Option> + Send + Sync, +>; + /// The protocol for Application-Layer Protocol Negotiation -#[derive(Hash, Clone, Debug)] +#[derive(Hash, Clone, Debug, PartialEq, PartialOrd)] pub enum ALPN { /// Prefer HTTP/1.1 only H1, @@ -51,6 +57,54 @@ pub enum ALPN { H2, /// Prefer HTTP/2 over HTTP/1.1 H2H1, + /// Custom Protocol is stored in wire format (length-prefixed) + /// Wire format is precomputed at creation to avoid dangling references + Custom(CustomALPN), +} + +/// Represents a Custom ALPN Protocol with a precomputed wire format and header offset. +#[derive(Hash, Clone, Debug, PartialEq, PartialOrd)] +pub struct CustomALPN { + wire: Vec, + header: usize, +} + +impl CustomALPN { + /// Create a new CustomALPN from a protocol byte vector + pub fn new(proto: Vec) -> Self { + // Validate before setting + assert!(!proto.is_empty(), "Custom ALPN protocol must not be empty"); + // RFC-7301 + assert!( + proto.len() <= 255, + "ALPN protocol name must be 255 bytes or fewer" + ); + + match proto.as_slice() { + b"http/1.1" | b"h2" => { + panic!("Custom ALPN cannot be a reserved protocol (http/1.1 or h2)") + } + _ => {} + } + let mut wire = Vec::with_capacity(1 + proto.len()); + wire.push(proto.len() as u8); + wire.extend_from_slice(&proto); + + Self { + wire, + header: 1, // Header is always at index 1 since we prefix one length byte + } + } + + /// Get the custom protocol name as a slice + pub fn protocol(&self) -> &[u8] { + &self.wire[self.header..] 
+ + /// Get the wire format used for ALPN negotiation + pub fn as_wire(&self) -> &[u8] { + &self.wire + } } impl std::fmt::Display for ALPN { @@ -59,6 +113,13 @@ ALPN::H1 => write!(f, "H1"), ALPN::H2 => write!(f, "H2"), ALPN::H2H1 => write!(f, "H2H1"), + ALPN::Custom(custom) => { + // extract protocol name, print as UTF-8 if possible, else just its raw bytes + match std::str::from_utf8(custom.protocol()) { + Ok(s) => write!(f, "Custom({})", s), + Err(_) => write!(f, "Custom({:?})", custom.protocol()), + } + } } } } @@ -79,15 +140,17 @@ impl ALPN { pub fn get_max_http_version(&self) -> u8 { match self { ALPN::H1 => 1, - _ => 2, + ALPN::H2 | ALPN::H2H1 => 2, + ALPN::Custom(_) => 0, } } /// Return the min http version this [`ALPN`] allows pub fn get_min_http_version(&self) -> u8 { match self { + ALPN::H1 | ALPN::H2H1 => 1, ALPN::H2 => 2, - _ => 1, + ALPN::Custom(_) => 0, } } @@ -99,6 +162,7 @@ Self::H1 => b"\x08http/1.1", Self::H2 => b"\x02h2", Self::H2H1 => b"\x02h2\x08http/1.1", + Self::Custom(custom) => custom.as_wire(), } } @@ -107,7 +171,7 @@ match raw { b"http/1.1" => Some(Self::H1), b"h2" => Some(Self::H2), - _ => None, + _ => Some(Self::Custom(CustomALPN::new(raw.to_vec()))), } } @@ -117,6 +181,7 @@ ALPN::H1 => vec![b"http/1.1".to_vec()], ALPN::H2 => vec![b"h2".to_vec()], ALPN::H2H1 => vec![b"h2".to_vec(), b"http/1.1".to_vec()], + ALPN::Custom(custom) => vec![custom.protocol().to_vec()], } } @@ -126,6 +191,51 @@ ALPN::H1 => vec![b"http/1.1".to_vec()], ALPN::H2 => vec![b"h2".to_vec()], ALPN::H2H1 => vec![b"h2".to_vec(), b"http/1.1".to_vec()], + ALPN::Custom(custom) => vec![custom.protocol().to_vec()], } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_alpn_construction_and_versions() { + // Standard Protocols + assert_eq!(ALPN::H1.get_min_http_version(), 1); + assert_eq!(ALPN::H1.get_max_http_version(), 1); + 
assert_eq!(ALPN::H2.get_min_http_version(), 2); + assert_eq!(ALPN::H2.get_max_http_version(), 2); + + assert_eq!(ALPN::H2H1.get_min_http_version(), 1); + assert_eq!(ALPN::H2H1.get_max_http_version(), 2); + + // Custom Protocol + let custom_protocol = ALPN::Custom(CustomALPN::new("custom/1.0".into())); + assert_eq!(custom_protocol.get_min_http_version(), 0); + assert_eq!(custom_protocol.get_max_http_version(), 0); + } + #[test] + #[should_panic(expected = "Custom ALPN protocol must not be empty")] + fn test_empty_custom_alpn() { + let _ = ALPN::Custom(CustomALPN::new("".into())); + } + #[test] + #[should_panic(expected = "ALPN protocol name must be 255 bytes or fewer")] + fn test_large_custom_alpn() { + let large_alpn = vec![b'a'; 256]; + let _ = ALPN::Custom(CustomALPN::new(large_alpn)); + } + #[test] + #[should_panic(expected = "Custom ALPN cannot be a reserved protocol (http/1.1 or h2)")] + fn test_custom_h1_alpn() { + let _ = ALPN::Custom(CustomALPN::new("http/1.1".into())); + } + #[test] + #[should_panic(expected = "Custom ALPN cannot be a reserved protocol (http/1.1 or h2)")] + fn test_custom_h2_alpn() { + let _ = ALPN::Custom(CustomALPN::new("h2".into())); + } +} diff --git a/pingora-core/src/protocols/tls/noop_tls/mod.rs b/pingora-core/src/protocols/tls/noop_tls/mod.rs index b909a3b2..d7632e13 100644 --- a/pingora-core/src/protocols/tls/noop_tls/mod.rs +++ b/pingora-core/src/protocols/tls/noop_tls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/tls/rustls/client.rs b/pingora-core/src/protocols/tls/rustls/client.rs index 7ff701ab..a8e00c41 100644 --- a/pingora-core/src/protocols/tls/rustls/client.rs +++ b/pingora-core/src/protocols/tls/rustls/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/tls/rustls/mod.rs b/pingora-core/src/protocols/tls/rustls/mod.rs index f8bce5f6..c7c81fc8 100644 --- a/pingora-core/src/protocols/tls/rustls/mod.rs +++ b/pingora-core/src/protocols/tls/rustls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/tls/rustls/server.rs b/pingora-core/src/protocols/tls/rustls/server.rs index d36fac56..4367f75a 100644 --- a/pingora-core/src/protocols/tls/rustls/server.rs +++ b/pingora-core/src/protocols/tls/rustls/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -16,6 +16,7 @@ use crate::listeners::TlsAcceptCallbacks; use crate::protocols::tls::rustls::TlsStream; +use crate::protocols::tls::TlsRef; use crate::protocols::IO; use crate::{listeners::tls::Acceptor, protocols::Shutdown}; use async_trait::async_trait; @@ -68,7 +69,7 @@ pub async fn handshake(acceptor: &Acceptor, io: S) -> Result pub async fn handshake_with_callback( acceptor: &Acceptor, io: S, - _callbacks: &TlsAcceptCallbacks, + callbacks: &TlsAcceptCallbacks, ) -> Result> { let mut tls_stream = prepare_tls_stream(acceptor, io).await?; let done = Pin::new(&mut tls_stream).start_accept().await?; @@ -81,7 +82,14 @@ pub async fn handshake_with_callback( .await .explain_err(TLSHandshakeFailure, |e| format!("TLS accept() failed: {e}"))?; } - + { + let tls_ref = TlsRef; + if let Some(extension) = callbacks.handshake_complete_callback(&tls_ref).await { + if let Some(digest_mut) = tls_stream.ssl_digest_mut() { + digest_mut.extension.set(extension); + } + } + } Ok(tls_stream) } diff --git a/pingora-core/src/protocols/tls/rustls/stream.rs b/pingora-core/src/protocols/tls/rustls/stream.rs index a23f4b35..f2a0ddae 100644 --- a/pingora-core/src/protocols/tls/rustls/stream.rs +++ b/pingora-core/src/protocols/tls/rustls/stream.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -133,6 +133,12 @@ impl TlsStream { pub fn ssl_digest(&self) -> Option> { self.digest.clone() } + + /// Attempts to obtain a mutable reference to the SslDigest. + /// This method returns `None` if the SslDigest is currently held by other references. + pub(crate) fn ssl_digest_mut(&mut self) -> Option<&mut SslDigest> { + Arc::get_mut(self.digest.as_mut()?) 
+ } } impl Deref for TlsStream { diff --git a/pingora-core/src/protocols/tls/s2n/client.rs b/pingora-core/src/protocols/tls/s2n/client.rs index 3b7c2858..544a6790 100644 --- a/pingora-core/src/protocols/tls/s2n/client.rs +++ b/pingora-core/src/protocols/tls/s2n/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/tls/s2n/mod.rs b/pingora-core/src/protocols/tls/s2n/mod.rs index 0d78cb79..6118100c 100644 --- a/pingora-core/src/protocols/tls/s2n/mod.rs +++ b/pingora-core/src/protocols/tls/s2n/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/tls/s2n/server.rs b/pingora-core/src/protocols/tls/s2n/server.rs index bde5c927..a8498f5d 100644 --- a/pingora-core/src/protocols/tls/s2n/server.rs +++ b/pingora-core/src/protocols/tls/s2n/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/protocols/tls/s2n/stream.rs b/pingora-core/src/protocols/tls/s2n/stream.rs index 96790be9..059718ea 100644 --- a/pingora-core/src/protocols/tls/s2n/stream.rs +++ b/pingora-core/src/protocols/tls/s2n/stream.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/protocols/windows.rs b/pingora-core/src/protocols/windows.rs index 10d6ce70..37c9e6fc 100644 --- a/pingora-core/src/protocols/windows.rs +++ b/pingora-core/src/protocols/windows.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/server/bootstrap_services.rs b/pingora-core/src/server/bootstrap_services.rs new file mode 100644 index 00000000..10df272f --- /dev/null +++ b/pingora-core/src/server/bootstrap_services.rs @@ -0,0 +1,208 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#[cfg(unix)] +pub use super::transfer_fd::Fds; +use async_trait::async_trait; +use log::{debug, error, info}; +use parking_lot::Mutex; +use std::sync::Arc; +use tokio::sync::{broadcast, Mutex as TokioMutex}; + +#[cfg(feature = "sentry")] +use sentry::ClientOptions; + +#[cfg(unix)] +use crate::server::ListenFds; + +use crate::{ + prelude::Opt, + server::{configuration::ServerConf, ExecutionPhase, ShutdownWatch}, + services::{background::BackgroundService, ServiceReadyNotifier}, +}; + +/// Service that allows the bootstrap process to be delayed until after +/// dependencies are ready +pub struct BootstrapService { + inner: Arc>, +} + +/// Sentry is typically started as part of the bootstrap process, but if the +/// bootstrap service is used, we want to initialize Sentry before anything else +/// to make sure errors are captured. +pub struct SentryInitService { + inner: Arc>, +} + +impl BootstrapService { + pub fn new(inner: &Arc>) -> Self { + BootstrapService { + inner: Arc::clone(inner), + } + } +} + +impl SentryInitService { + pub fn new(inner: &Arc>) -> Self { + SentryInitService { + inner: Arc::clone(inner), + } + } +} + +/// Encapsulation of the data needed to bootstrap the server +pub struct Bootstrap { + completed: bool, + + test: bool, + upgrade: bool, + + upgrade_sock: String, + + execution_phase_watch: broadcast::Sender, + + #[cfg(unix)] + listen_fds: Option, + + #[cfg(feature = "sentry")] + #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] + /// The Sentry ClientOptions. 
+ /// + /// Panics and other events sentry captures will be sent to this DSN **only + /// in release mode** + pub sentry: Option, +} + +impl Bootstrap { + pub fn new( + options: &Option, + conf: &ServerConf, + execution_phase_watch: &broadcast::Sender, + ) -> Self { + let (test, upgrade) = options + .as_ref() + .map(|opt| (opt.test, opt.upgrade)) + .unwrap_or_default(); + + let upgrade_sock = conf.upgrade_sock.clone(); + + Bootstrap { + test, + upgrade, + upgrade_sock, + #[cfg(unix)] + listen_fds: None, + execution_phase_watch: execution_phase_watch.clone(), + completed: false, + #[cfg(feature = "sentry")] + sentry: None, + } + } + + #[cfg(feature = "sentry")] + pub fn set_sentry_config(&mut self, sentry_config: Option) { + self.sentry = sentry_config; + } + + /// Start sentry based on the configured options. To prevent multiple + /// initializations, this function will consume the sentry configuration + /// stored in the bootstrap + fn start_sentry(&mut self) { + // Only init sentry in release builds + #[cfg(all(not(debug_assertions), feature = "sentry"))] + let _guard = self.sentry.take().map(|opts| sentry::init(opts)); + } + + pub fn bootstrap(&mut self) { + // already bootstrapped + if self.completed { + return; + } + + info!("Bootstrap starting"); + + self.execution_phase_watch + .send(ExecutionPhase::Bootstrap) + .ok(); + + self.start_sentry(); + + if self.test { + info!("Server Test passed, exiting"); + std::process::exit(0); + } + + // load fds + #[cfg(unix)] + match self.load_fds(self.upgrade) { + Ok(_) => { + info!("Bootstrap done"); + } + Err(e) => { + // sentry log error on fd load failure + #[cfg(all(not(debug_assertions), feature = "sentry"))] + sentry::capture_error(&e); + + error!("Bootstrap failed on error: {:?}, exiting.", e); + std::process::exit(1); + } + } + + self.completed = true; + + self.execution_phase_watch + .send(ExecutionPhase::BootstrapComplete) + .ok(); + } + + #[cfg(unix)] + fn load_fds(&mut self, upgrade: bool) -> Result<(), 
nix::Error> { + let mut fds = Fds::new(); + if upgrade { + debug!("Trying to receive socks"); + fds.get_from_sock(self.upgrade_sock.as_str())? + } + self.listen_fds = Some(Arc::new(TokioMutex::new(fds))); + Ok(()) + } + + #[cfg(unix)] + pub fn get_fds(&self) -> Option { + self.listen_fds.clone() + } +} + +#[async_trait] +impl BackgroundService for BootstrapService { + async fn start_with_ready_notifier( + &self, + _shutdown: ShutdownWatch, + notifier: ServiceReadyNotifier, + ) { + self.inner.lock().bootstrap(); + notifier.notify_ready(); + } +} + +#[async_trait] +impl BackgroundService for SentryInitService { + async fn start_with_ready_notifier( + &self, + _shutdown: ShutdownWatch, + notifier: ServiceReadyNotifier, + ) { + self.inner.lock().start_sentry(); + notifier.notify_ready(); + } +} diff --git a/pingora-core/src/server/configuration/mod.rs b/pingora-core/src/server/configuration/mod.rs index 77828b6d..fb1742c5 100644 --- a/pingora-core/src/server/configuration/mod.rs +++ b/pingora-core/src/server/configuration/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ use clap::Parser; use log::{debug, trace}; use pingora_error::{Error, ErrorType::*, OrErr, Result}; use serde::{Deserialize, Serialize}; +use std::ffi::OsString; use std::fs; // default maximum upstream retries for retry-able proxy errors @@ -113,6 +114,12 @@ pub struct ServerConf { /// /// This setting is a fail-safe and defaults to 16. pub max_retries: usize, + /// Maximum number of retries for upgrade socket connect and accept operations. + /// This controls how many times send_fds_to will retry connecting and how many times + /// get_fds_from will retry accepting during graceful upgrades. + /// The retry interval is 1 second between attempts. + /// If not set, defaults to 5 retries. 
+ pub upgrade_sock_connect_accept_max_retries: Option, } impl Default for ServerConf { @@ -141,6 +148,7 @@ impl Default for ServerConf { grace_period_seconds: None, graceful_shutdown_timeout_seconds: None, max_retries: DEFAULT_MAX_RETRIES, + upgrade_sock_connect_accept_max_retries: None, } } } @@ -166,7 +174,7 @@ pub struct Opt { /// Not actually used. This flag is there so that the server is not upset seeing this flag /// passed from `cargo test` sometimes - #[clap(long, hidden = true)] + #[clap(long, hide = true)] pub nocapture: bool, /// Test the configuration and exit @@ -263,6 +271,14 @@ impl Opt { pub fn parse_args() -> Self { Opt::parse() } + + pub fn parse_from_args(args: I) -> Self + where + I: IntoIterator, + T: Into + Clone, + { + Opt::parse_from(args) + } } #[cfg(test)] @@ -300,6 +316,7 @@ mod tests { grace_period_seconds: None, graceful_shutdown_timeout_seconds: None, max_retries: 1, + upgrade_sock_connect_accept_max_retries: None, }; // cargo test -- --nocapture not_a_test_i_cannot_write_yaml_by_hand println!("{}", conf.to_yaml()); diff --git a/pingora-core/src/server/daemon.rs b/pingora-core/src/server/daemon.rs index c45a5eeb..7381fc93 100644 --- a/pingora-core/src/server/daemon.rs +++ b/pingora-core/src/server/daemon.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index 43e6b3f1..2021677c 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ //! 
Server process and configuration management +mod bootstrap_services; pub mod configuration; #[cfg(unix)] mod daemon; @@ -23,21 +24,29 @@ pub(crate) mod transfer_fd; use async_trait::async_trait; #[cfg(unix)] use daemon::daemonize; +use daggy::NodeIndex; use log::{debug, error, info, warn}; +use parking_lot::Mutex; use pingora_runtime::Runtime; use pingora_timeout::fast_timeout; #[cfg(feature = "sentry")] use sentry::ClientOptions; use std::sync::Arc; use std::thread; +use std::time::SystemTime; #[cfg(unix)] use tokio::signal::unix; -use tokio::sync::{broadcast, watch, Mutex}; +use tokio::sync::{broadcast, watch, Mutex as TokioMutex}; use tokio::time::{sleep, Duration}; +use crate::prelude::background_service; use crate::protocols::proxy_protocol; -use crate::services::Service; +use crate::server::bootstrap_services::{Bootstrap, BootstrapService, SentryInitService}; +use crate::services::{ + DependencyGraph, ServiceHandle, ServiceReadyNotifier, ServiceReadyWatch, ServiceWithDependents, +}; use configuration::{Opt, ServerConf}; +use std::collections::HashMap; #[cfg(unix)] pub use transfer_fd::Fds; @@ -68,6 +77,13 @@ fn configure_proxy_protocol(conf: &mut ServerConf) { proxy_protocol::set_proxy_protocol_enabled(conf.enable_proxy_protocol); } +/// Internal wrapper for services with dependency metadata. +pub(crate) struct ServiceWrapper { + ready_notifier: Option, + service: Box, + service_handle: ServiceHandle, +} + /// The execution phase the server is currently in. #[derive(Clone, Debug)] #[non_exhaustive] @@ -115,7 +131,7 @@ pub enum ExecutionPhase { /// to shutdown pub type ShutdownWatch = watch::Receiver; #[cfg(unix)] -pub type ListenFds = Arc>; +pub type ListenFds = Arc>; /// The type of shutdown process that has been requested. #[derive(Debug)] @@ -195,9 +211,11 @@ impl Default for RunArgs { /// services (see [crate::services]). The server object handles signals, reading configuration, /// zero downtime upgrade and error reporting. 
pub struct Server { - services: Vec>, - #[cfg(unix)] - listen_fds: Option, + // This is a way to add services that have to be run before any others + // without requiring dependencies to be set directly + init_services: Vec>, + + services: HashMap, shutdown_watch: watch::Sender, // TODO: we many want to drop this copy to let sender call closed() shutdown_recv: ShutdownWatch, @@ -207,16 +225,16 @@ pub struct Server { /// Users can subscribe to the phase with [`Self::watch_execution_phase()`]. execution_phase_watch: broadcast::Sender, + /// Specification of service level dependencies + dependencies: Arc>, + + /// Service initialization + bootstrap: Arc>, + /// The parsed server configuration pub configuration: Arc, /// The parser command line options pub options: Option, - #[cfg(feature = "sentry")] - #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] - /// The Sentry ClientOptions. - /// - /// Panics and other events sentry captures will be sent to this DSN **only in release mode** - pub sentry: Option, } // TODO: delete the pid when exit @@ -272,7 +290,7 @@ impl Server { .send(ExecutionPhase::GracefulUpgradeTransferringFds) .ok(); - if let Some(fds) = &self.listen_fds { + if let Some(fds) = self.listen_fds() { let fds = fds.lock().await; info!("Trying to send socks"); // XXX: this is blocking IO @@ -313,43 +331,108 @@ impl Server { } } + #[cfg(windows)] + async fn main_loop(&self, _run_args: RunArgs) -> ShutdownType { + // waiting for exit signal + + self.execution_phase_watch + .send(ExecutionPhase::Running) + .ok(); + + match tokio::signal::ctrl_c().await { + Ok(()) => { + info!("Ctrl+C received, gracefully exiting"); + // graceful shutdown if there are listening sockets + info!("Broadcasting graceful shutdown"); + match self.shutdown_watch.send(true) { + Ok(_) => { + info!("Graceful shutdown started!"); + } + Err(e) => { + error!("Graceful shutdown broadcast failed: {e}"); + } + } + info!("Broadcast graceful shutdown complete"); + + self.execution_phase_watch + 
.send(ExecutionPhase::GracefulTerminate) + .ok(); + + ShutdownType::Graceful + } + Err(e) => { + error!("Unable to listen for shutdown signal: {}", e); + ShutdownType::Quick + } + } + } + + #[cfg(feature = "sentry")] + #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] + /// The Sentry ClientOptions. + /// + /// Panics and other events sentry captures will be sent to this DSN **only in release mode** + pub fn set_sentry_config(&mut self, sentry_config: ClientOptions) { + self.bootstrap.lock().set_sentry_config(Some(sentry_config)); + } + + /// Get the configured file descriptors for listening + #[cfg(unix)] + fn listen_fds(&self) -> Option { + self.bootstrap.lock().get_fds() + } + + #[allow(clippy::too_many_arguments)] fn run_service( - mut service: Box, + mut service: Box, #[cfg(unix)] fds: Option, shutdown: ShutdownWatch, threads: usize, work_stealing: bool, listeners_per_fd: usize, + ready_notifier: ServiceReadyNotifier, + dependency_watches: Vec, ) -> Runtime // NOTE: we need to keep the runtime outside async since // otherwise the runtime will be dropped. 
{ let service_runtime = Server::create_runtime(service.name(), threads, work_stealing); + let service_name = service.name().to_string(); service_runtime.get_handle().spawn(async move { + // Wait for all dependencies to be ready + let mut time_waited_opt: Option = None; + for mut watch in dependency_watches { + let start = SystemTime::now(); + + if watch.wait_for(|&ready| ready).await.is_err() { + error!( + "Service '{}' dependency channel closed before ready", + service_name + ); + } + + *time_waited_opt.get_or_insert_default() += start.elapsed().unwrap_or_default() + } + + if let Some(time_waited) = time_waited_opt { + service.on_startup_delay(time_waited); + } + + // Start the actual service, passing the ready notifier service .start_service( #[cfg(unix)] fds, shutdown, listeners_per_fd, + ready_notifier, ) .await; - info!("service exited.") + info!("service '{}' exited.", service_name); }); service_runtime } - #[cfg(unix)] - fn load_fds(&mut self, upgrade: bool) -> Result<(), nix::Error> { - let mut fds = Fds::new(); - if upgrade { - debug!("Trying to receive socks"); - fds.get_from_sock(self.configuration.as_ref().upgrade_sock.as_str())? 
- } - self.listen_fds = Some(Arc::new(Mutex::new(fds))); - Ok(()) - } - /// Create a new [`Server`], using the [`Opt`] and [`ServerConf`] values provided /// /// This method is intended for pingora frontends that are NOT using the built-in @@ -369,17 +452,23 @@ impl Server { let (tx, rx) = watch::channel(false); + let execution_phase_watch = broadcast::channel(100).0; + let bootstrap = Arc::new(Mutex::new(Bootstrap::new( + &opt, + &conf, + &execution_phase_watch, + ))); + Server { - services: vec![], - #[cfg(unix)] - listen_fds: None, + services: Default::default(), + init_services: Default::default(), shutdown_watch: tx, shutdown_recv: rx, - execution_phase_watch: broadcast::channel(100).0, + execution_phase_watch, configuration: Arc::new(conf), options: opt, - #[cfg(feature = "sentry")] - sentry: None, + dependencies: Arc::new(Mutex::new(DependencyGraph::new())), + bootstrap, } } @@ -394,6 +483,7 @@ impl Server { let opt = opt.into(); let (tx, rx) = watch::channel(false); + let execution_phase_watch = broadcast::channel(100).0; let mut conf = if let Some(opt) = opt.as_ref() { opt.conf.as_ref().map_or_else( || { @@ -413,30 +503,115 @@ impl Server { }?; configure_proxy_protocol(&mut conf); + let bootstrap = Arc::new(Mutex::new(Bootstrap::new( + &opt, + &conf, + &execution_phase_watch, + ))); + Ok(Server { - services: vec![], - #[cfg(unix)] - listen_fds: None, + services: Default::default(), + init_services: Default::default(), shutdown_watch: tx, shutdown_recv: rx, - execution_phase_watch: broadcast::channel(100).0, + execution_phase_watch, configuration: Arc::new(conf), options: opt, - #[cfg(feature = "sentry")] - sentry: None, + dependencies: Arc::new(Mutex::new(DependencyGraph::new())), + bootstrap, }) } + /// Add a service that all other services will wait on before starting. 
+ fn add_init_service(&mut self, service: impl ServiceWithDependents + 'static) { + let boxed_service = Box::new(service); + self.init_services.push(boxed_service); + } + + /// Add the init services as dependencies for all existing services + fn apply_init_service_dependencies(&mut self) { + let services = self + .services + .values() + .map(|service| service.service_handle.clone()) + .collect::>(); + let global_deps = self + .init_services + .drain(..) + .collect::>() + .into_iter() + .map(|dep| self.add_boxed_service(dep)) + .collect::>(); + for service in services { + service.add_dependencies(&global_deps); + } + } + /// Add a service to this server. /// - /// A service is anything that implements [`Service`]. - pub fn add_service(&mut self, service: impl Service + 'static) { - self.services.push(Box::new(service)); + /// Returns a [`ServiceHandle`] that can be used to declare dependencies. + /// + /// # Example + /// + /// ```rust,ignore + /// let db_id = server.add_service(database_service); + /// let api_id = server.add_service(api_service); + /// + /// // Declare that API depends on database + /// api_id.add_dependency(&db_id); + /// ``` + pub fn add_service(&mut self, service: impl ServiceWithDependents + 'static) -> ServiceHandle { + self.add_boxed_service(Box::new(service)) + } + + /// Add a pre-boxed service to this server. + /// + /// Returns a [`ServiceHandle`] that can be used to declare dependencies. 
+ /// + /// # Example + /// + /// ```rust,ignore + /// let db_id = server.add_service(database_service); + /// let api_id = server.add_service(api_service); + /// + /// // Declare that API depends on database + /// api_id.add_dependency(&db_id); + /// ``` + pub fn add_boxed_service( + &mut self, + service_box: Box, + ) -> ServiceHandle { + let name = service_box.name().to_string(); + + // Create a readiness notifier for this service + let (tx, rx) = watch::channel(false); + + let id = self.dependencies.lock().add_node(name.clone(), rx.clone()); + + let service_handle = ServiceHandle::new(id, name, rx, &self.dependencies); + + let wrapper = ServiceWrapper { + ready_notifier: Some(ServiceReadyNotifier::new(tx)), + service: service_box, + service_handle: service_handle.clone(), + }; + + self.services.insert(id, wrapper); + + service_handle } - /// Similar to [`Self::add_service()`], but take a list of services - pub fn add_services(&mut self, services: Vec>) { - self.services.extend(services); + /// Similar to [`Self::add_service()`], but take a list of services. + /// + /// Returns a `Vec` for all added services. + pub fn add_services( + &mut self, + services: Vec>, + ) -> Vec { + services + .into_iter() + .map(|service| self.add_boxed_service(service)) + .collect() } /// Prepare the server to start @@ -444,41 +619,28 @@ impl Server { /// When trying to zero downtime upgrade from an older version of the server which is already /// running, this function will try to get all its listening sockets in order to take them over. 
pub fn bootstrap(&mut self) { - info!("Bootstrap starting"); - debug!("{:#?}", self.options); - - self.execution_phase_watch - .send(ExecutionPhase::Bootstrap) - .ok(); + self.bootstrap.lock().bootstrap(); + } - /* only init sentry in release builds */ - #[cfg(all(not(debug_assertions), feature = "sentry"))] - let _guard = self.sentry.as_ref().map(|opts| sentry::init(opts.clone())); + /// Create a service that will run to prepare the service to start + /// + /// The created service will handle the zero-downtime upgrade from an older version of the server + /// to this one. It will try to get all its listening sockets in order to take them over. + /// + /// Other bootstrapping functionality like sentry initialization will also be handled, but as a + /// service that will complete before any other service starts. + pub fn bootstrap_as_a_service(&mut self) -> ServiceHandle { + let bootstrap_service = + background_service("Bootstrap Service", BootstrapService::new(&self.bootstrap)); - if self.options.as_ref().is_some_and(|o| o.test) { - info!("Server Test passed, exiting"); - std::process::exit(0); - } + let sentry_service = background_service( + "Sentry Init Service", + SentryInitService::new(&self.bootstrap), + ); - // load fds - #[cfg(unix)] - match self.load_fds(self.options.as_ref().is_some_and(|o| o.upgrade)) { - Ok(_) => { - info!("Bootstrap done"); - } - Err(e) => { - // sentry log error on fd load failure - #[cfg(all(not(debug_assertions), feature = "sentry"))] - sentry::capture_error(&e); + self.add_init_service(sentry_service); - error!("Bootstrap failed on error: {:?}, exiting.", e); - std::process::exit(1); - } - } - - self.execution_phase_watch - .send(ExecutionPhase::BootstrapComplete) - .ok(); + self.add_service(bootstrap_service) } /// Start the server using [Self::run] and default [RunArgs]. @@ -505,6 +667,8 @@ impl Server { /// Instead it will either start the daemon process and exit, or panic /// if daemonization fails. 
pub fn run(mut self, run_args: RunArgs) { + self.apply_init_service_dependencies(); + info!("Server starting"); let conf = self.configuration.as_ref(); @@ -522,24 +686,79 @@ impl Server { panic!("Daemonizing under windows is not supported"); } - /* only init sentry in release builds */ - #[cfg(all(not(debug_assertions), feature = "sentry"))] - let _guard = self.sentry.as_ref().map(|opts| sentry::init(opts.clone())); - // Holds tuples of runtimes and their service name. let mut runtimes: Vec<(Runtime, String)> = Vec::new(); - while let Some(service) = self.services.pop() { - let threads = service.threads().unwrap_or(conf.threads); - let name = service.name().to_string(); + // Get services in topological order (dependencies first) + let startup_order = match self.dependencies.lock().topological_sort() { + Ok(order) => order, + Err(e) => { + error!("Failed to determine service startup order: {}", e); + std::process::exit(1); + } + }; + + // Log service names in startup order + let service_names: Vec = startup_order + .iter() + .map(|(_, service)| service.name.clone()) + .collect(); + info!("Starting services in dependency order: {:?}", service_names); + + // Start services in dependency order + for (service_id, service) in startup_order { + let mut wrapper = match self.services.remove(&service_id) { + Some(w) => w, + None => { + warn!( + "Service ID {:?}-{} in startup order but not found", + service_id, service.name + ); + continue; + } + }; + + let threads = wrapper.service.threads().unwrap_or(conf.threads); + let name = wrapper.service.name().to_string(); + + // Extract dependency watches from the ServiceHandle + let dependencies = self + .dependencies + .lock() + .get_dependencies(wrapper.service_handle.id); + + // Get the readiness notifier for this service by taking it from the Option. + // Since service_id is the index, we can directly access it. + // We take() the notifier, leaving None in its place. 
+ let ready_notifier = wrapper + .ready_notifier + .take() + .expect("Service notifier should exist"); + + if !dependencies.is_empty() { + info!( + "Service '{name}' will wait for dependencies: {:?}", + dependencies.iter().map(|s| &s.name).collect::>() + ); + } else { + info!("Starting service: {}", name); + } + + let dependency_watches = dependencies + .iter() + .map(|s| s.ready_watch.clone()) + .collect::>(); + let runtime = Server::run_service( - service, + wrapper.service, #[cfg(unix)] - self.listen_fds.clone(), + self.listen_fds(), self.shutdown_recv.clone(), threads, conf.work_stealing, self.configuration.listener_tasks_per_fd, + ready_notifier, + dependency_watches, ); runtimes.push((runtime, name)); } @@ -552,7 +771,9 @@ impl Server { .get_handle() .block_on(self.main_loop(run_args)); #[cfg(windows)] - let shutdown_type = ShutdownType::Graceful; + let shutdown_type = server_runtime + .get_handle() + .block_on(self.main_loop(run_args)); self.execution_phase_watch .send(ExecutionPhase::ShutdownStarted) diff --git a/pingora-core/src/server/transfer_fd/mod.rs b/pingora-core/src/server/transfer_fd/mod.rs index 3fb7259b..3f852aec 100644 --- a/pingora-core/src/server/transfer_fd/mod.rs +++ b/pingora-core/src/server/transfer_fd/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ impl Fds { let (vec_key, vec_fds) = self.serialize(); let mut ser_buf: [u8; 2048] = [0; 2048]; let ser_key_size = serialize_vec_string(&vec_key, &mut ser_buf); - send_fds_to(vec_fds, &ser_buf[..ser_key_size], path) + send_fds_to(vec_fds, &ser_buf[..ser_key_size], path, None) } pub fn get_from_sock

(&mut self, path: &P) -> Result<(), Error> @@ -76,7 +76,7 @@ impl Fds { P: ?Sized + NixPath + std::fmt::Display, { let mut de_buf: [u8; 2048] = [0; 2048]; - let (fds, bytes) = get_fds_from(path, &mut de_buf)?; + let (fds, bytes) = get_fds_from(path, &mut de_buf, None)?; let keys = deserialize_vec_string(&de_buf[..bytes])?; self.deserialize(keys, fds); Ok(()) @@ -97,10 +97,15 @@ fn deserialize_vec_string(buf: &[u8]) -> Result, Error> { } #[cfg(target_os = "linux")] -pub fn get_fds_from

(path: &P, payload: &mut [u8]) -> Result<(Vec, usize), Error> +pub fn get_fds_from

( + path: &P, + payload: &mut [u8], + max_retry: Option, +) -> Result<(Vec, usize), Error> where P: ?Sized + NixPath + std::fmt::Display, { + let max_retry = max_retry.unwrap_or(MAX_RETRY); const MAX_FDS: usize = 32; let listen_fd = socket::socket( @@ -124,18 +129,18 @@ where }; socket::bind(listen_fd, &unix_addr).unwrap(); - /* sock is created before we change user, need to give permission to all */ + /* sock is created before we change user, need to give permission */ stat::fchmodat( None, path, - stat::Mode::all(), + stat::Mode::from_bits_truncate(0o666), stat::FchmodatFlags::FollowSymlink, ) .unwrap(); socket::listen(listen_fd, 8).unwrap(); - let fd = match accept_with_retry(listen_fd) { + let fd = match accept_with_retry_timeout(listen_fd, max_retry) { Ok(fd) => fd, Err(e) => { error!("Giving up reading socket from: {path}, error: {e:?}"); @@ -175,7 +180,11 @@ where } #[cfg(not(target_os = "linux"))] -pub fn get_fds_from

(_path: &P, _payload: &mut [u8]) -> Result<(Vec, usize), Error> +pub fn get_fds_from

( + _path: &P, + _payload: &mut [u8], + _max_retry: Option, +) -> Result<(Vec, usize), Error> where P: ?Sized + NixPath + std::fmt::Display, { @@ -189,13 +198,13 @@ const MAX_RETRY: usize = 5; const RETRY_INTERVAL: time::Duration = time::Duration::from_secs(1); #[cfg(target_os = "linux")] -fn accept_with_retry(listen_fd: i32) -> Result { +fn accept_with_retry_timeout(listen_fd: i32, max_retry: usize) -> Result { let mut retried = 0; loop { match socket::accept(listen_fd) { Ok(fd) => return Ok(fd), Err(e) => { - if retried > MAX_RETRY { + if retried > max_retry { return Err(e); } match e { @@ -217,10 +226,16 @@ fn accept_with_retry(listen_fd: i32) -> Result { } #[cfg(target_os = "linux")] -pub fn send_fds_to

(fds: Vec, payload: &[u8], path: &P) -> Result +pub fn send_fds_to

( + fds: Vec, + payload: &[u8], + path: &P, + max_retry: Option, +) -> Result where P: ?Sized + NixPath + std::fmt::Display, { + let max_retry = max_retry.unwrap_or(MAX_RETRY); const MAX_NONBLOCKING_POLLS: usize = 20; const NONBLOCKING_POLL_INTERVAL: time::Duration = time::Duration::from_millis(500); @@ -245,10 +260,10 @@ where Errno::ENOENT | Errno::ECONNREFUSED | Errno::EACCES => { /*the server is not ready yet*/ retried += 1; - if retried > MAX_RETRY { + if retried > max_retry { error!( "Max retry: {} reached. Giving up sending socket to: {}, error: {:?}", - MAX_RETRY, path, e + max_retry, path, e ); break Err(e); } @@ -317,7 +332,12 @@ where } #[cfg(not(target_os = "linux"))] -pub fn send_fds_to

(_fds: Vec, _payload: &[u8], _path: &P) -> Result +pub fn send_fds_to

( + _fds: Vec, + _payload: &[u8], + _path: &P, + _max_retry: Option, +) -> Result where P: ?Sized + NixPath + std::fmt::Display, { @@ -386,7 +406,8 @@ mod tests { // receiver need to start in another thread since it is blocking let child = thread::spawn(move || { let mut buf: [u8; 32] = [0; 32]; - let (fds, bytes) = get_fds_from("/tmp/pingora_fds_receive.sock", &mut buf).unwrap(); + let (fds, bytes) = + get_fds_from("/tmp/pingora_fds_receive.sock", &mut buf, None).unwrap(); debug!("{:?}", fds); assert_eq!(1, fds.len()); assert_eq!(32, bytes); @@ -396,7 +417,7 @@ mod tests { let fds = vec![dumb_fd]; let buf: [u8; 128] = [1; 128]; - match send_fds_to(fds, &buf, "/tmp/pingora_fds_receive.sock") { + match send_fds_to(fds, &buf, "/tmp/pingora_fds_receive.sock", None) { Ok(sent) => { assert!(sent > 0); } @@ -443,4 +464,67 @@ mod tests { fds.send_to_sock("/tmp/pingora_fds_receive2.sock").unwrap(); child.join().unwrap(); } + + #[test] + fn test_send_fds_to_respects_configurable_timeout() { + init_log(); + use std::time::Instant; + + let dumb_fd = socket::socket( + AddressFamily::Unix, + SockType::Stream, + SockFlag::empty(), + None, + ) + .unwrap(); + + let fds = vec![dumb_fd]; + let buf: [u8; 32] = [1; 32]; + + // Try to send with a custom max_retries of 2 + let start = Instant::now(); + let result = send_fds_to(fds, &buf, "/tmp/pingora_test_config_send.sock", Some(2)); + let elapsed = start.elapsed(); + + // Should fail after 2 retries with RETRY_INTERVAL (1 second) between each + // Total time should be approximately 2 seconds + assert!(result.is_err()); + assert!( + elapsed.as_secs() >= 2, + "Expected at least 2 seconds, got {:?}", + elapsed + ); + assert!( + elapsed.as_secs() < 4, + "Expected less than 4 seconds, got {:?}", + elapsed + ); + } + + #[test] + fn test_get_fds_from_respects_configurable_timeout() { + init_log(); + use std::time::Instant; + + let mut buf: [u8; 32] = [0; 32]; + + // Try to receive with a custom max_retries of 2 + let start = Instant::now(); 
+ let result = get_fds_from("/tmp/pingora_test_config_receive.sock", &mut buf, Some(2)); + let elapsed = start.elapsed(); + + // Should fail after 2 retries with RETRY_INTERVAL (1 second) between each + // Total time should be approximately 2 seconds + assert!(result.is_err()); + assert!( + elapsed.as_secs() >= 2, + "Expected at least 2 seconds, got {:?}", + elapsed + ); + assert!( + elapsed.as_secs() < 4, + "Expected less than 4 seconds, got {:?}", + elapsed + ); + } } diff --git a/pingora-core/src/services/background.rs b/pingora-core/src/services/background.rs index 14582334..7edd3761 100644 --- a/pingora-core/src/services/background.rs +++ b/pingora-core/src/services/background.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,18 +22,39 @@ use async_trait::async_trait; use std::sync::Arc; -use super::Service; +use super::{ServiceReadyNotifier, ServiceWithDependents}; #[cfg(unix)] use crate::server::ListenFds; use crate::server::ShutdownWatch; /// The background service interface +/// +/// You can implement a background service with or without the ready notifier, +/// but you shouldn't implement both. Under the hood, the pingora service will +/// call the `start_with_ready_notifier` function. By default this function will +/// call the regular `start` function. #[async_trait] pub trait BackgroundService { + /// This function is called when the pingora server tries to start all the + /// services. The background service should signal readiness by calling + /// `ready_notifier.notify_ready()` once initialization is complete. + /// The service can return at anytime or wait for the `shutdown` signal. 
+ /// + /// By default this method will immediately signal readiness and call + /// through to the regular `start` function + async fn start_with_ready_notifier( + &self, + shutdown: ShutdownWatch, + ready_notifier: ServiceReadyNotifier, + ) { + ready_notifier.notify_ready(); + self.start(shutdown).await; + } + /// This function is called when the pingora server tries to start all the /// services. The background service can return at anytime or wait for the /// `shutdown` signal. - async fn start(&self, mut shutdown: ShutdownWatch); + async fn start(&self, mut _shutdown: ShutdownWatch) {} } /// A generic type of background service @@ -63,17 +84,21 @@ impl GenBackgroundService { } #[async_trait] -impl Service for GenBackgroundService +impl ServiceWithDependents for GenBackgroundService where A: BackgroundService + Send + Sync + 'static, { + // Use default start_service implementation which signals ready immediately + // and then calls start_service + async fn start_service( &mut self, #[cfg(unix)] _fds: Option, shutdown: ShutdownWatch, _listeners_per_fd: usize, + ready: ServiceReadyNotifier, ) { - self.task.start(shutdown).await; + self.task.start_with_ready_notifier(shutdown, ready).await; } fn name(&self) -> &str { @@ -85,7 +110,7 @@ where } } -// Helper function to create a background service with a human readable name +/// Helper function to create a background service with a human readable name pub fn background_service(name: &str, task: SV) -> GenBackgroundService { GenBackgroundService::new(format!("BG {name}"), Arc::new(task)) } diff --git a/pingora-core/src/services/listening.rs b/pingora-core/src/services/listening.rs index d14f1246..4be5c4d9 100644 --- a/pingora-core/src/services/listening.rs +++ b/pingora-core/src/services/listening.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,11 @@ use crate::apps::ServerApp; use crate::listeners::tls::TlsSettings; -use crate::listeners::{Listeners, ServerAddress, TcpSocketOptions, TransportStack}; +#[cfg(feature = "connection_filter")] +use crate::listeners::AcceptAllFilter; +use crate::listeners::{ + ConnectionFilter, Listeners, ServerAddress, TcpSocketOptions, TransportStack, +}; use crate::protocols::Stream; #[cfg(unix)] use crate::server::ListenFds; @@ -43,6 +47,8 @@ pub struct Service { app_logic: Option, /// The number of preferred threads. `None` to follow global setting. pub threads: Option, + #[cfg(feature = "connection_filter")] + connection_filter: Arc, } impl Service { @@ -53,6 +59,8 @@ impl Service { listeners: Listeners::new(), app_logic: Some(app_logic), threads: None, + #[cfg(feature = "connection_filter")] + connection_filter: Arc::new(AcceptAllFilter), } } @@ -64,9 +72,45 @@ impl Service { listeners, app_logic: Some(app_logic), threads: None, + #[cfg(feature = "connection_filter")] + connection_filter: Arc::new(AcceptAllFilter), } } + /// Set a custom connection filter for this service. + /// + /// The connection filter will be applied to all incoming connections + /// on all endpoints of this service. Connections that don't pass the + /// filter will be dropped immediately at the TCP level, before TLS + /// handshake or any HTTP processing. + /// + /// # Feature Flag + /// + /// This method requires the `connection_filter` feature to be enabled. + /// When the feature is disabled, this method is a no-op. 
+ /// + /// # Example + /// + /// ```rust,no_run + /// # use std::sync::Arc; + /// # use pingora_core::listeners::{ConnectionFilter, AcceptAllFilter}; + /// # struct MyService; + /// # impl MyService { + /// # fn new() -> Self { MyService } + /// # } + /// let mut service = MyService::new(); + /// let filter = Arc::new(AcceptAllFilter); + /// service.set_connection_filter(filter); + /// ``` + #[cfg(feature = "connection_filter")] + pub fn set_connection_filter(&mut self, filter: Arc) { + self.connection_filter = filter.clone(); + self.listeners.set_connection_filter(filter); + } + + #[cfg(not(feature = "connection_filter"))] + pub fn set_connection_filter(&mut self, _filter: Arc) {} + /// Get the [`Listeners`], mostly to add more endpoints. pub fn endpoints(&mut self) -> &mut Listeners { &mut self.listeners diff --git a/pingora-core/src/services/mod.rs b/pingora-core/src/services/mod.rs index 51bda994..7c450428 100644 --- a/pingora-core/src/services/mod.rs +++ b/pingora-core/src/services/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,6 +22,15 @@ //! - services that are just running in the background. use async_trait::async_trait; +use daggy::Walker; +use daggy::{petgraph::visit::Topo, Dag, NodeIndex}; +use log::{error, info, warn}; +use parking_lot::Mutex; +use std::borrow::Borrow; +use std::sync::Arc; +use std::sync::Weak; +use std::time::Duration; +use tokio::sync::watch; #[cfg(unix)] use crate::server::ListenFds; @@ -30,22 +39,274 @@ use crate::server::ShutdownWatch; pub mod background; pub mod listening; -/// The service interface +/// A notification channel for signaling when a service has become ready. +/// +/// Services can use this to notify other services that may depend on them +/// that they have successfully started and are ready to serve requests. 
+/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_core::services::ServiceReadyNotifier; +/// +/// async fn my_service(ready_notifier: ServiceReadyNotifier) { +/// // Perform initialization... +/// +/// // Signal that the service is ready +/// ready_notifier.notify_ready(); +/// +/// // Continue with main service loop... +/// } +/// ``` +pub struct ServiceReadyNotifier { + sender: watch::Sender, +} + +impl Drop for ServiceReadyNotifier { + /// In the event that the notifier is dropped before notifying that the + /// service is ready, we opt to signal ready anyway + fn drop(&mut self) { + // Ignore errors - if there are no receivers, that's fine + let _ = self.sender.send(true); + } +} + +impl ServiceReadyNotifier { + /// Creates a new ServiceReadyNotifier from a watch sender. + /// You will not need to create one of these for normal usage, but being + /// able to is useful for testing. + pub fn new(sender: watch::Sender) -> Self { + Self { sender } + } + + /// Notifies dependent services that this service is ready. + /// + /// Consumes the notifier to ensure ready is only signaled once. + pub fn notify_ready(self) { + // Dropping the notifier will signal that the service is ready + drop(self); + } +} + +/// A receiver for watching when a service becomes ready. +pub type ServiceReadyWatch = watch::Receiver; + +/// A handle to a service in the server. +/// +/// This is returned by [`crate::server::Server::add_service()`] and provides +/// methods to declare that other services depend on this one. 
+/// +/// # Example +/// +/// ```rust,ignore +/// let db_handle = server.add_service(database_service); +/// let cache_handle = server.add_service(cache_service); +/// +/// let api_handle = server.add_service(api_service); +/// api_handle.add_dependency(&db_handle); +/// api_handle.add_dependency(&cache_handle); +/// ``` +#[derive(Debug, Clone)] +pub struct ServiceHandle { + pub(crate) id: NodeIndex, + name: String, + ready_watch: ServiceReadyWatch, + dependencies: Weak>, +} + +/// Internal representation of a dependency relationship. +#[derive(Debug, Clone)] +pub(crate) struct ServiceDependency { + pub name: String, + pub ready_watch: ServiceReadyWatch, +} + +impl ServiceHandle { + /// Creates a new ServiceHandle with the given ID, name, and readiness watcher. + pub(crate) fn new( + id: NodeIndex, + name: String, + ready_watch: ServiceReadyWatch, + dependencies: &Arc>, + ) -> Self { + Self { + id, + name, + ready_watch, + dependencies: Arc::downgrade(dependencies), + } + } + + #[cfg(test)] + fn get_dependencies(&self) -> Vec { + let Some(deps_lock) = self.dependencies.upgrade() else { + return Vec::new(); + }; + + let deps = deps_lock.lock(); + deps.get_dependencies(self.id) + } + + /// Returns the name of the service. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns a clone of the readiness watcher for this service. + #[allow(dead_code)] + pub(crate) fn ready_watch(&self) -> ServiceReadyWatch { + self.ready_watch.clone() + } + + /// Declares that this service depends on another service. + /// + /// This service will not start until the specified dependency has started + /// and signaled readiness. 
+ ///
+ /// # Example
+ ///
+ /// ```rust,ignore
+ /// let db_id = server.add_service(database_service);
+ /// let api_id = server.add_service(api_service);
+ ///
+ /// // API service depends on database
+ /// api_id.add_dependency(&db_id);
+ /// ```
+ pub fn add_dependency(&self, dependency: impl Borrow) {
+ let Some(deps_lock) = self.dependencies.upgrade() else {
+ warn!("Attempted to add a dependency after the dependency tree was dropped");
+ return;
+ };
+
+ let mut deps = deps_lock.lock();
+ if let Err(e) = deps.add_dependency(self.id, dependency.borrow().id) {
+ error!("Error creating dependency edge: {e}");
+ }
+ }
+
+ /// Declares that this service depends on the given other services.
+ ///
+ /// This service will not start until the specified dependencies have
+ /// started and signaled readiness.
+ ///
+ /// # Example
+ ///
+ /// ```rust,ignore
+ /// let db_id = server.add_service(database_service);
+ /// let cache_id = server.add_service(cache_service);
+ /// let api_id = server.add_service(api_service);
+ ///
+ /// // API service depends on database and cache
+ /// api_id.add_dependencies(&[&db_id, &cache_id]);
+ /// ```
+ pub fn add_dependencies<'a, D>(&self, dependencies: impl IntoIterator)
+ where
+ D: Borrow + 'a,
+ {
+ for dependency in dependencies {
+ self.add_dependency(dependency);
+ }
+ }
+}
+
+/// Helper for validating service dependency graphs using daggy.
+pub(crate) struct DependencyGraph {
+ /// The directed acyclic graph structure from daggy.
+ dag: Dag,
+}
+
+impl DependencyGraph {
+ /// Creates a new dependency graph.
+ pub(crate) fn new() -> Self {
+ Self { dag: Dag::new() }
+ }
+
+ /// Adds a service node to the graph.
+ ///
+ /// This should be called for all services first, before adding edges.
+ pub(crate) fn add_node(&mut self, name: String, ready_watch: ServiceReadyWatch) -> NodeIndex {
+ self.dag.add_node(ServiceDependency { name, ready_watch })
+ }
+ /// Adds a dependency edge from one service to another.
+ /// + /// Returns an error if adding this dependency would create a cycle or reference + /// a non-existent service. + pub(crate) fn add_dependency( + &mut self, + dependent_service_node_idx: NodeIndex, + dependency_service_node_idx: NodeIndex, + ) -> Result<(), String> { + // Try to add edge (from dependency to dependent) + // daggy will return an error if this would create a cycle + if let Err(cycle) = + self.dag + .add_edge(dependency_service_node_idx, dependent_service_node_idx, ()) + { + return Err(format!( + "Circular service dependency detected between {} and {} creating cycle: {cycle}", + self.dag[dependency_service_node_idx].name, + self.dag[dependent_service_node_idx].name + )); + } + + Ok(()) + } + + /// Returns services in topological order (dependencies before dependents). + /// + /// This ordering ensures that services are started in the correct order. + /// Returns service IDs in the correct startup order. + pub(crate) fn topological_sort(&self) -> Result, String> { + // Use daggy's built-in topological walker + let mut sorted = Vec::new(); + let mut topo = Topo::new(&self.dag); + + while let Some(service_id) = topo.next(&self.dag) { + sorted.push((service_id, self.dag[service_id].clone())); + } + + Ok(sorted) + } + + pub(crate) fn get_dependencies(&self, service_id: NodeIndex) -> Vec { + self.dag + .parents(service_id) + .iter(&self.dag) + .map(|(_, n)| self.dag[n].clone()) + .collect() + } +} + +impl Default for DependencyGraph { + fn default() -> Self { + Self::new() + } +} + #[async_trait] -pub trait Service: Sync + Send { +pub trait ServiceWithDependents: Send + Sync { /// This function will be called when the server is ready to start the service. /// + /// Override this method if you need to control exactly when the service signals readiness + /// (e.g., after async initialization is complete). + /// + /// # Arguments + /// /// - `fds` (Unix only): a collection of listening file descriptors. 
During zero downtime restart
- /// the `fds` would contain the listening sockets passed from the old service, services should
- /// take the sockets they need to use then. If the sockets the service looks for don't appear in
- /// the collection, the service should create its own listening sockets and then put them into
- /// the collection in order for them to be passed to the next server.
+ /// the `fds` would contain the listening sockets passed from the old service, services should
+ /// take the sockets they need to use then. If the sockets the service looks for don't appear in
+ /// the collection, the service should create its own listening sockets and then put them into
+ /// the collection in order for them to be passed to the next server.
 /// - `shutdown`: the shutdown signal this server would receive.
+ /// - `listeners_per_fd`: number of listener tasks to spawn per file descriptor.
+ /// - `ready_notifier`: notifier to signal when the service is ready. Services with
+ /// dependents should call `ready_notifier.notify_ready()` once they are fully initialized.
 async fn start_service(
 &mut self,
 #[cfg(unix)] fds: Option,
- mut shutdown: ShutdownWatch,
+ shutdown: ShutdownWatch,
 listeners_per_fd: usize,
+ ready_notifier: ServiceReadyNotifier,
 );
 
 /// The name of the service, just for logging and naming the threads assigned to this service
@@ -59,4 +320,379 @@ pub trait Service: Sync + Send {
 fn threads(&self) -> Option {
 None
 }
+
+ /// This is currently called to inform the service about the delay it
+ /// experienced while waiting on its dependencies. Default behavior
+ /// is to log the time.
+ ///
+ /// TODO.
It would be nice if this function was called intermittently by
+ /// the server while the service was waiting to give live updates while the
+ /// service was waiting and allow the service to decide whether to keep
+ /// waiting, continue anyway, or exit
+ fn on_startup_delay(&self, time_waited: Duration) {
+ info!(
+ "Service {} spent {}ms waiting on dependencies",
+ self.name(),
+ time_waited.as_millis()
+ );
+ }
+}
+
+#[async_trait]
+impl ServiceWithDependents for S
+where
+ S: Service,
+{
+ async fn start_service(
+ &mut self,
+ #[cfg(unix)] fds: Option,
+ shutdown: ShutdownWatch,
+ listeners_per_fd: usize,
+ ready_notifier: ServiceReadyNotifier,
+ ) {
+ // Signal ready immediately
+ ready_notifier.notify_ready();
+
+ S::start_service(
+ self,
+ #[cfg(unix)]
+ fds,
+ shutdown,
+ listeners_per_fd,
+ )
+ .await
+ }
+
+ fn name(&self) -> &str {
+ S::name(self)
+ }
+
+ fn threads(&self) -> Option {
+ S::threads(self)
+ }
+
+ fn on_startup_delay(&self, time_waited: Duration) {
+ S::on_startup_delay(self, time_waited)
+ }
+}
+
+/// The service interface
+#[async_trait]
+pub trait Service: Sync + Send {
+ /// Start the service without readiness notification.
+ ///
+ /// This is a simpler version of [`ServiceWithDependents::start_service()`] for services that don't need
+ /// to control when they signal readiness. The default implementation does nothing.
+ ///
+ /// Most services should override this method instead of [`ServiceWithDependents::start_service()`].
+ ///
+ /// # Arguments
+ ///
+ /// - `fds` (Unix only): a collection of listening file descriptors.
+ /// - `shutdown`: the shutdown signal this server would receive.
+ /// - `listeners_per_fd`: number of listener tasks to spawn per file descriptor.
+ async fn start_service(
+ &mut self,
+ #[cfg(unix)] _fds: Option,
+ _shutdown: ShutdownWatch,
+ _listeners_per_fd: usize,
+ ) {
+ // Default: do nothing
+ }
+
+ /// The name of the service, just for logging and naming the threads assigned to this service
+ ///
+ /// Note that due to the limit of the underlying system, only the first 16 chars will be used
+ fn name(&self) -> &str;
+
+ /// The preferred number of threads to run this service
+ ///
+ /// If `None`, the global setting will be used
+ fn threads(&self) -> Option {
+ None
+ }
+
+ /// This is currently called to inform the service about the delay it
+ /// experienced while waiting on its dependencies. Default behavior
+ /// is to log the time.
+ ///
+ /// TODO. It would be nice if this function was called intermittently by
+ /// the server while the service was waiting to give live updates while the
+ /// service was waiting and allow the service to decide whether to keep
+ /// waiting, continue anyway, or exit
+ fn on_startup_delay(&self, time_waited: Duration) {
+ info!(
+ "Service {} spent {}ms waiting on dependencies",
+ self.name(),
+ time_waited.as_millis()
+ );
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_service_handle_creation() {
+ let deps: Arc> = Arc::new(Mutex::new(DependencyGraph::new()));
+ let (tx, rx) = watch::channel(false);
+ let service_id = ServiceHandle::new(0.into(), "test_service".to_string(), rx, &deps);
+
+ assert_eq!(service_id.id, 0.into());
+ assert_eq!(service_id.name(), "test_service");
+
+ // Should be able to clone the watch
+ let watch_clone = service_id.ready_watch();
+ assert!(!*watch_clone.borrow());
+
+ // Signaling ready should be observable through cloned watch
+ tx.send(true).ok();
+ assert!(*watch_clone.borrow());
+ }
+
+ #[test]
+ fn test_service_handle_add_dependency() {
+ let graph: Arc> = Arc::new(Mutex::new(DependencyGraph::new()));
+ let (tx1, rx1) = watch::channel(false);
+ let (tx1_clone, rx1_clone) = (tx1.clone(),
rx1.clone()); + let (_tx2, rx2) = watch::channel(false); + let (_tx2_clone, rx2_clone) = (_tx2.clone(), rx2.clone()); + + // Add nodes to the graph first + let dep_node = { + let mut g = graph.lock(); + g.add_node("dependency".to_string(), rx1) + }; + let main_node = { + let mut g = graph.lock(); + g.add_node("main".to_string(), rx2) + }; + + let dep_service = ServiceHandle::new(dep_node, "dependency".to_string(), rx1_clone, &graph); + let main_service = ServiceHandle::new(main_node, "main".to_string(), rx2_clone, &graph); + + // Add dependency + main_service.add_dependency(&dep_service); + + // Get dependencies and verify + let deps = main_service.get_dependencies(); + assert_eq!(deps.len(), 1); + assert_eq!(deps[0].name, "dependency"); + + // Verify watch is working + assert!(!*deps[0].ready_watch.borrow()); + tx1_clone.send(true).ok(); + assert!(*deps[0].ready_watch.borrow()); + } + + #[test] + fn test_service_handle_multiple_dependencies() { + let graph: Arc> = Arc::new(Mutex::new(DependencyGraph::new())); + let (_tx1, rx1) = watch::channel(false); + let rx1_clone = rx1.clone(); + let (_tx2, rx2) = watch::channel(false); + let rx2_clone = rx2.clone(); + let (_tx3, rx3) = watch::channel(false); + let rx3_clone = rx3.clone(); + + // Add nodes to the graph first + let dep1_node = { + let mut g = graph.lock(); + g.add_node("dep1".to_string(), rx1) + }; + let dep2_node = { + let mut g = graph.lock(); + g.add_node("dep2".to_string(), rx2) + }; + let main_node = { + let mut g = graph.lock(); + g.add_node("main".to_string(), rx3) + }; + + let dep1 = ServiceHandle::new(dep1_node, "dep1".to_string(), rx1_clone, &graph); + let dep2 = ServiceHandle::new(dep2_node, "dep2".to_string(), rx2_clone, &graph); + let main_service = ServiceHandle::new(main_node, "main".to_string(), rx3_clone, &graph); + + // Add multiple dependencies + main_service.add_dependency(&dep1); + main_service.add_dependency(&dep2); + + // Get dependencies and verify + let deps = 
main_service.get_dependencies(); + assert_eq!(deps.len(), 2); + + let dep_names: Vec<&str> = deps.iter().map(|d| d.name.as_str()).collect(); + assert!(dep_names.contains(&"dep1")); + assert!(dep_names.contains(&"dep2")); + } + + #[test] + fn test_single_service_no_dependencies() { + let mut graph = DependencyGraph::new(); + let (_tx, rx) = watch::channel(false); + let _node = graph.add_node("service1".to_string(), rx); + + let order = graph.topological_sort().unwrap(); + assert_eq!(order.len(), 1); + assert_eq!(order[0].1.name, "service1"); + } + + #[test] + fn test_simple_dependency_chain() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + let (_tx2, rx2) = watch::channel(false); + let (_tx3, rx3) = watch::channel(false); + + let node1 = graph.add_node("service1".to_string(), rx1); + let node2 = graph.add_node("service2".to_string(), rx2); + let node3 = graph.add_node("service3".to_string(), rx3); + + // service2 depends on service1, service3 depends on service2 + graph.add_dependency(node2, node1).unwrap(); + graph.add_dependency(node3, node2).unwrap(); + + let order = graph.topological_sort().unwrap(); + assert_eq!(order.len(), 3); + // Verify order: service1, service2, service3 + assert_eq!(order[0].1.name, "service1"); + assert_eq!(order[1].1.name, "service2"); + assert_eq!(order[2].1.name, "service3"); + } + + #[test] + fn test_diamond_dependency() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + let (_tx2, rx2) = watch::channel(false); + let (_tx3, rx3) = watch::channel(false); + + let db = graph.add_node("db".to_string(), rx1); + let cache = graph.add_node("cache".to_string(), rx2); + let api = graph.add_node("api".to_string(), rx3); + + // api depends on both db and cache + graph.add_dependency(api, db).unwrap(); + graph.add_dependency(api, cache).unwrap(); + + let order = graph.topological_sort().unwrap(); + // api should come last, but db and cache order doesn't matter + 
assert_eq!(order.len(), 3); + assert_eq!(order[2].1.name, "api"); + let first_two: Vec<&str> = order[0..2].iter().map(|(_, d)| d.name.as_str()).collect(); + assert!(first_two.contains(&"db")); + assert!(first_two.contains(&"cache")); + } + + #[test] + #[should_panic(expected = "node indices out of bounds")] + fn test_missing_dependency() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + + let node1 = graph.add_node("service1".to_string(), rx1); + let nonexistent = NodeIndex::new(999); + + // Try to add dependency on non-existent node - this should panic + let _ = graph.add_dependency(node1, nonexistent); + } + + #[test] + fn test_circular_dependency_self() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + + let node1 = graph.add_node("service1".to_string(), rx1); + + // Try to make service depend on itself + let result = graph.add_dependency(node1, node1); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Circular")); + } + + #[test] + fn test_circular_dependency_two_services() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + let (_tx2, rx2) = watch::channel(false); + + // Add both nodes first + let node1 = graph.add_node("service1".to_string(), rx1); + let node2 = graph.add_node("service2".to_string(), rx2); + + // Try to add circular dependencies + graph.add_dependency(node1, node2).unwrap(); + let result = graph.add_dependency(node2, node1); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Circular")); + } + + #[test] + fn test_circular_dependency_three_services() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + let (_tx2, rx2) = watch::channel(false); + let (_tx3, rx3) = watch::channel(false); + + // Add all nodes first + let node1 = graph.add_node("service1".to_string(), rx1); + let node2 = graph.add_node("service2".to_string(), rx2); + let node3 = 
graph.add_node("service3".to_string(), rx3); + + // Add dependencies that would form a cycle + graph.add_dependency(node1, node2).unwrap(); + graph.add_dependency(node2, node3).unwrap(); + let result = graph.add_dependency(node3, node1); + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Circular")); + } + + #[test] + fn test_complex_valid_graph() { + let mut graph = DependencyGraph::new(); + let (_tx1, rx1) = watch::channel(false); + let (_tx2, rx2) = watch::channel(false); + let (_tx3, rx3) = watch::channel(false); + let (_tx4, rx4) = watch::channel(false); + let (_tx5, rx5) = watch::channel(false); + + // Build a complex dependency graph: + // db, cache - no deps + // auth -> db + // api -> db, cache, auth + // frontend -> api + let db = graph.add_node("db".to_string(), rx1); + let cache = graph.add_node("cache".to_string(), rx2); + let auth = graph.add_node("auth".to_string(), rx3); + let api = graph.add_node("api".to_string(), rx4); + let frontend = graph.add_node("frontend".to_string(), rx5); + + graph.add_dependency(auth, db).unwrap(); + graph.add_dependency(api, db).unwrap(); + graph.add_dependency(api, cache).unwrap(); + graph.add_dependency(api, auth).unwrap(); + graph.add_dependency(frontend, api).unwrap(); + + let order = graph.topological_sort().unwrap(); + + // Verify ordering constraints using names + let db_pos = order.iter().position(|(_, d)| d.name == "db").unwrap(); + let cache_pos = order.iter().position(|(_, d)| d.name == "cache").unwrap(); + let auth_pos = order.iter().position(|(_, d)| d.name == "auth").unwrap(); + let api_pos = order.iter().position(|(_, d)| d.name == "api").unwrap(); + let frontend_pos = order + .iter() + .position(|(_, d)| d.name == "frontend") + .unwrap(); + + assert!(db_pos < auth_pos); + assert!(auth_pos < api_pos); + assert!(db_pos < api_pos); + assert!(cache_pos < api_pos); + assert!(api_pos < frontend_pos); + } } diff --git a/pingora-core/src/upstreams/mod.rs b/pingora-core/src/upstreams/mod.rs 
index 2348bc85..b66fc26a 100644 --- a/pingora-core/src/upstreams/mod.rs +++ b/pingora-core/src/upstreams/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index 84ed43d5..c9ae0a66 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ use crate::connectors::{l4::BindTo, L4Connect}; use crate::protocols::l4::socket::SocketAddr; use crate::protocols::tls::CaType; +#[cfg(feature = "openssl_derived")] +use crate::protocols::tls::HandshakeCompleteHook; #[cfg(feature = "s2n")] use crate::protocols::tls::PskType; #[cfg(unix)] @@ -46,6 +48,23 @@ use tokio::net::TcpSocket; pub use crate::protocols::tls::ALPN; +/// A hook function that may generate user data for [`crate::protocols::raw_connect::ProxyDigest`]. +/// +/// Takes the request and response headers from the proxy connection establishment, and may produce +/// arbitrary data to be stored in ProxyDigest's user_data field. +/// +/// This can be useful when, for example, you want to store some parameter(s) from the request or +/// response headers from when the proxy connection was first established. 
+pub type ProxyDigestUserDataHook = Arc< + dyn Fn( + &http::request::Parts, // request headers + &pingora_http::ResponseHeader, // response headers + ) -> Option> + + Send + + Sync + + 'static, +>; + /// The interface to trace the connection pub trait Tracing: Send + Sync + std::fmt::Debug { /// This method is called when successfully connected to a remote server @@ -261,6 +280,29 @@ pub trait Peer: Display + Clone { .upstream_tcp_sock_tweak_hook .as_ref() } + + /// Returns a [`ProxyDigestUserDataHook`] that may generate user data for + /// [`crate::protocols::raw_connect::ProxyDigest`] when establishing a new proxy connection. + fn proxy_digest_user_data_hook(&self) -> Option<&ProxyDigestUserDataHook> { + self.get_peer_options()? + .proxy_digest_user_data_hook + .as_ref() + } + + /// Returns a hook that should be run on TLS handshake completion. + /// + /// Any value returned from the returned hook (other than `None`) will be stored in the + /// `extension` field of `SslDigest`. This allows you to attach custom application-specific + /// data to the TLS connection, which will be accessible from the HTTP layer via the + /// `SslDigest` attached to the session digest. + /// + /// Currently only enabled for openssl variants with meaningful `TlsRef`s. + #[cfg(feature = "openssl_derived")] + fn upstream_tls_handshake_complete_hook(&self) -> Option<&HandshakeCompleteHook> { + self.get_peer_options()? + .upstream_tls_handshake_complete_hook + .as_ref() + } } /// A simple TCP or TLS peer without many complicated settings. @@ -391,6 +433,13 @@ pub struct PeerOptions { pub max_blinding_delay: Option, // how many concurrent h2 stream are allowed in the same connection pub max_h2_streams: usize, + /// Allow invalid Content-Length in HTTP/1 responses (non-RFC compliant). + /// + /// When enabled, invalid Content-Length responses are treated as close-delimited responses. + /// + /// **Note:** This field is unstable and may be removed or changed in future versions. 
+ /// It exists primarily for compatibility with legacy servers that send malformed headers. + pub allow_h1_response_invalid_content_length: bool, pub extra_proxy_headers: BTreeMap>, // The list of curve the tls connection should advertise // if `None`, the default curves will be used @@ -406,6 +455,15 @@ pub struct PeerOptions { #[derivative(Debug = "ignore")] pub upstream_tcp_sock_tweak_hook: Option Result<()> + Send + Sync + 'static>>, + #[derivative(Debug = "ignore")] + pub proxy_digest_user_data_hook: Option, + /// Hook that allows returning an optional `SslDigestExtension`. + /// Any returned value will be saved into the `SslDigest`. + /// + /// Currently only enabled for openssl variants with meaningful `TlsRef`s. + #[cfg(feature = "openssl_derived")] + #[derivative(Debug = "ignore")] + pub upstream_tls_handshake_complete_hook: Option, } impl PeerOptions { @@ -436,6 +494,7 @@ impl PeerOptions { #[cfg(feature = "s2n")] max_blinding_delay: None, max_h2_streams: 1, + allow_h1_response_invalid_content_length: false, extra_proxy_headers: BTreeMap::new(), curves: None, second_keyshare: true, // default true and noop when not using PQ curves @@ -443,6 +502,9 @@ impl PeerOptions { tracer: None, custom_l4: None, upstream_tcp_sock_tweak_hook: None, + proxy_digest_user_data_hook: None, + #[cfg(feature = "openssl_derived")] + upstream_tls_handshake_complete_hook: None, } } @@ -588,6 +650,17 @@ impl HttpPeer { } } + /// Create a new [`HttpPeer`] with client certificate and key for mutual TLS. 
+ pub fn new_mtls( + address: A, + sni: String, + client_cert_key: Arc, + ) -> Self { + let mut peer = Self::new(address, true, sni); + peer.client_cert_key = Some(client_cert_key); + peer + } + fn peer_hash(&self) -> u64 { let mut hasher = AHasher::default(); self.hash(&mut hasher); @@ -610,6 +683,8 @@ impl Hash for HttpPeer { #[cfg(feature = "s2n")] self.get_psk().hash(state); self.group_key.hash(state); + // max h2 stream settings + self.options.max_h2_streams.hash(state); } } diff --git a/pingora-core/src/utils/mod.rs b/pingora-core/src/utils/mod.rs index 2479c0b7..66ad444e 100644 --- a/pingora-core/src/utils/mod.rs +++ b/pingora-core/src/utils/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/utils/tls/boringssl_openssl.rs b/pingora-core/src/utils/tls/boringssl_openssl.rs index f78d5aeb..1f18adfb 100644 --- a/pingora-core/src/utils/tls/boringssl_openssl.rs +++ b/pingora-core/src/utils/tls/boringssl_openssl.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/utils/tls/mod.rs b/pingora-core/src/utils/tls/mod.rs index 887293b3..c345073e 100644 --- a/pingora-core/src/utils/tls/mod.rs +++ b/pingora-core/src/utils/tls/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/src/utils/tls/rustls.rs b/pingora-core/src/utils/tls/rustls.rs index d336e1fe..429b3724 100644 --- a/pingora-core/src/utils/tls/rustls.rs +++ b/pingora-core/src/utils/tls/rustls.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/src/utils/tls/s2n.rs b/pingora-core/src/utils/tls/s2n.rs index f52d86b1..4dffd32b 100644 --- a/pingora-core/src/utils/tls/s2n.rs +++ b/pingora-core/src/utils/tls/s2n.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/tests/client_hello_integration.rs b/pingora-core/tests/client_hello_integration.rs index 1744524c..75df8664 100644 --- a/pingora-core/tests/client_hello_integration.rs +++ b/pingora-core/tests/client_hello_integration.rs @@ -222,7 +222,6 @@ mod tests { #[tokio::test] async fn test_async_extraction() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let notify = Arc::new(Notify::new()); @@ -260,4 +259,3 @@ mod tests { let _ = tokio::join!(server_task, client_task); } } - diff --git a/pingora-core/tests/server_phase_fastshutdown.rs b/pingora-core/tests/server_phase_fastshutdown.rs index 83eb3e9b..def35552 100644 --- a/pingora-core/tests/server_phase_fastshutdown.rs +++ b/pingora-core/tests/server_phase_fastshutdown.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-core/tests/server_phase_gracefulshutdown.rs b/pingora-core/tests/server_phase_gracefulshutdown.rs index 7c231e93..9d123f1e 100644 --- a/pingora-core/tests/server_phase_gracefulshutdown.rs +++ b/pingora-core/tests/server_phase_gracefulshutdown.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/tests/test_basic.rs b/pingora-core/tests/test_basic.rs index 60f95026..0c9f87f9 100644 --- a/pingora-core/tests/test_basic.rs +++ b/pingora-core/tests/test_basic.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-core/tests/utils/mod.rs b/pingora-core/tests/utils/mod.rs index 7062b349..a5016c0b 100644 --- a/pingora-core/tests/utils/mod.rs +++ b/pingora-core/tests/utils/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-error/Cargo.toml b/pingora-error/Cargo.toml index aec7939d..6aae2aee 100644 --- a/pingora-error/Cargo.toml +++ b/pingora-error/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-error" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" diff --git a/pingora-error/src/immut_str.rs b/pingora-error/src/immut_str.rs index a03ef353..a9e1b6da 100644 --- a/pingora-error/src/immut_str.rs +++ b/pingora-error/src/immut_str.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-error/src/lib.rs b/pingora-error/src/lib.rs index c2d25ad5..c561bccf 100644 --- a/pingora-error/src/lib.rs +++ b/pingora-error/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ pub struct Error { /// if the error is retry-able pub retry: RetryType, /// chain to the cause of this error - pub cause: Option>, + pub cause: Option>, /// an arbitrary string that explains the context when the error happens pub context: Option, } @@ -88,7 +88,7 @@ impl From for RetryType { impl ErrorSource { /// for displaying the error source - pub fn as_str(&self) -> &str { + pub fn as_str(&self) -> &'static str { match self { Self::Upstream => "Upstream", Self::Downstream => "Downstream", @@ -159,7 +159,7 @@ impl ErrorType { } /// for displaying the error type - pub fn as_str(&self) -> &str { + pub fn as_str(&self) -> &'static str { match self { ErrorType::ConnectTimedout => "ConnectTimedout", ErrorType::ConnectRefused => "ConnectRefused", diff --git a/pingora-header-serde/Cargo.toml b/pingora-header-serde/Cargo.toml index 8c25636e..181a60b8 100644 --- a/pingora-header-serde/Cargo.toml +++ b/pingora-header-serde/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-header-serde" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -27,6 +27,6 @@ zstd-safe = { version = "7.1.0", features = ["std"] } http = { workspace = true } bytes = { workspace = true } httparse = { workspace = true } -pingora-error = { version = "0.6.0", path = "../pingora-error" } -pingora-http = { version = "0.6.0", path = "../pingora-http" } +pingora-error = { version = "0.8.0", path = "../pingora-error" } 
+pingora-http = { version = "0.8.0", path = "../pingora-http" } thread_local = "1.0" diff --git a/pingora-header-serde/src/dict.rs b/pingora-header-serde/src/dict.rs index 792698c1..3fb788d4 100644 --- a/pingora-header-serde/src/dict.rs +++ b/pingora-header-serde/src/dict.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-header-serde/src/lib.rs b/pingora-header-serde/src/lib.rs index d93330e0..71122bf3 100644 --- a/pingora-header-serde/src/lib.rs +++ b/pingora-header-serde/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -42,7 +42,8 @@ pub struct HeaderSerde { buf: ThreadLocal>>, } -const MAX_HEADER_SIZE: usize = 64 * 1024; +const MAX_HEADER_BUF_SIZE: usize = 128 * 1024; // 128KB + const COMPRESS_LEVEL: i32 = 3; impl HeaderSerde { @@ -76,7 +77,7 @@ impl HeaderSerde { // TODO: should convert to h1 if the incoming header is for h2 let mut buf = self .buf - .get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_SIZE))) + .get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_BUF_SIZE))) .borrow_mut(); buf.clear(); // reset the buf resp_header_to_buf(header, &mut buf); @@ -87,7 +88,7 @@ impl HeaderSerde { pub fn deserialize(&self, data: &[u8]) -> Result { let mut buf = self .buf - .get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_SIZE))) + .get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_BUF_SIZE))) .borrow_mut(); buf.clear(); // reset the buf self.compression @@ -219,6 +220,7 @@ fn buf_to_http_header(buf: &[u8]) -> Result { #[inline] fn parsed_to_header(parsed: &httparse::Response) -> Result { // code should always be there + // TODO: allow reading the parsed http version? 
let mut resp = ResponseHeader::build(parsed.code.unwrap(), Some(parsed.headers.len()))?; for header in parsed.headers.iter() { diff --git a/pingora-header-serde/src/thread_zstd.rs b/pingora-header-serde/src/thread_zstd.rs index 99aaf617..4510d2b4 100644 --- a/pingora-header-serde/src/thread_zstd.rs +++ b/pingora-header-serde/src/thread_zstd.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -114,14 +114,14 @@ impl CompressionInner { } #[inline] - fn get_com_context(&self) -> RefMut> { + fn get_com_context(&self) -> RefMut<'_, CCtx<'static>> { self.com_context .get_or(|| RefCell::new(CCtx::create())) .borrow_mut() } #[inline] - fn get_de_context(&self) -> RefMut> { + fn get_de_context(&self) -> RefMut<'_, DCtx<'static>> { self.de_context .get_or(|| RefCell::new(DCtx::create())) .borrow_mut() diff --git a/pingora-header-serde/src/trainer.rs b/pingora-header-serde/src/trainer.rs index 9e0ac5dc..aa016d45 100644 --- a/pingora-header-serde/src/trainer.rs +++ b/pingora-header-serde/src/trainer.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-http/Cargo.toml b/pingora-http/Cargo.toml index 459b206e..82f1b65e 100644 --- a/pingora-http/Cargo.toml +++ b/pingora-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-http" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -19,7 +19,7 @@ path = "src/lib.rs" [dependencies] http = { workspace = true } bytes = { workspace = true } -pingora-error = { version = "0.6.0", path = "../pingora-error" } +pingora-error = { version = "0.8.0", path = "../pingora-error" } [features] default = [] diff --git a/pingora-http/src/case_header_name.rs b/pingora-http/src/case_header_name.rs index 3e2b7acf..28d62c27 100644 --- a/pingora-http/src/case_header_name.rs +++ b/pingora-http/src/case_header_name.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -85,6 +85,7 @@ fn titled_header_name(header_name: &HeaderName) -> Bytes { pub(crate) fn titled_header_name_str(header_name: &HeaderName) -> Option<&'static str> { Some(match *header_name { + header::ACCEPT_RANGES => "Accept-Ranges", header::AGE => "Age", header::CACHE_CONTROL => "Cache-Control", header::CONNECTION => "Connection", diff --git a/pingora-http/src/lib.rs b/pingora-http/src/lib.rs index 103abe9c..954be81b 100644 --- a/pingora-http/src/lib.rs +++ b/pingora-http/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -30,7 +30,7 @@ use http::response::Builder as RespBuilder; use http::response::Parts as RespParts; use http::uri::Uri; use pingora_error::{ErrorType::*, OrErr, Result}; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; pub use http::method::Method; pub use http::status::StatusCode; @@ -43,6 +43,7 @@ pub use case_header_name::IntoCaseHeaderName; pub mod prelude { pub use crate::RequestHeader; + pub use crate::ResponseHeader; } /* an ordered header map to store the original case of each header name @@ -56,6 +57,11 @@ This idea is inspaired by hyper @nox */ type CaseMap = HMap; +pub enum HeaderNameVariant<'a> { + Case(&'a CaseHeaderName), + Titled(&'a str), +} + /// The HTTP request header type. /// /// This type is similar to [http::request::Parts] but preserves header name case. @@ -87,6 +93,12 @@ impl Deref for RequestHeader { } } +impl DerefMut for RequestHeader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.base + } +} + impl RequestHeader { fn new_no_case(size_hint: Option) -> Self { let mut base = ReqBuilder::new().body(()).unwrap().into_parts().0; @@ -200,6 +212,33 @@ impl RequestHeader { self.header_name_map.is_some() } + pub fn map Result<()>>( + &self, + mut f: F, + ) -> Result<()> { + let key_map = self.header_name_map.as_ref(); + let value_map = &self.base.headers; + + if let Some(key_map) = key_map { + let iter = key_map.iter().zip(value_map.iter()); + for ((header, case_header), (header2, val)) in iter { + if header != header2 { + // in case the header iteration order changes in future versions of HMap + panic!("header iter mismatch {}, {}", header, header2) + } + f(HeaderNameVariant::Case(case_header), val)?; + } + } else { + for (header, value) in value_map { + let titled_header = + case_header_name::titled_header_name_str(header).unwrap_or(header.as_str()); + f(HeaderNameVariant::Titled(titled_header), value)?; + } + } + + Ok(()) + } + /// Set the request method pub fn set_method(&mut self, method: Method) { self.base.method 
= method; @@ -349,6 +388,12 @@ impl Deref for ResponseHeader { } } +impl DerefMut for ResponseHeader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.base + } +} + impl Clone for ResponseHeader { fn clone(&self) -> Self { Self { @@ -491,6 +536,33 @@ impl ResponseHeader { self.header_name_map.is_some() } + pub fn map Result<()>>( + &self, + mut f: F, + ) -> Result<()> { + let key_map = self.header_name_map.as_ref(); + let value_map = &self.base.headers; + + if let Some(key_map) = key_map { + let iter = key_map.iter().zip(value_map.iter()); + for ((header, case_header), (header2, val)) in iter { + if header != header2 { + // in case the header iteration order changes in future versions of HMap + panic!("header iter mismatch {}, {}", header, header2) + } + f(HeaderNameVariant::Case(case_header), val)?; + } + } else { + for (header, value) in value_map { + let titled_header = + case_header_name::titled_header_name_str(header).unwrap_or(header.as_str()); + f(HeaderNameVariant::Titled(titled_header), value)?; + } + } + + Ok(()) + } + /// Set the status code pub fn set_status(&mut self, status: impl TryInto) -> Result<()> { self.base.status = status @@ -546,6 +618,7 @@ fn clone_req_parts(me: &ReqParts) -> ReqParts { .into_parts() .0; parts.headers = me.headers.clone(); + parts.extensions = me.extensions.clone(); parts } @@ -558,6 +631,7 @@ fn clone_resp_parts(me: &RespParts) -> RespParts { .into_parts() .0; parts.headers = me.headers.clone(); + parts.extensions = me.extensions.clone(); parts } diff --git a/pingora-ketama/Cargo.toml b/pingora-ketama/Cargo.toml index dbc867d5..be17dbba 100644 --- a/pingora-ketama/Cargo.toml +++ b/pingora-ketama/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-ketama" -version = "0.6.0" +version = "0.8.0" description = "Rust port of the nginx consistent hash function" authors = ["Pingora Team "] license = "Apache-2.0" @@ -11,14 +11,15 @@ keywords = ["hash", "hashing", "consistent", "pingora"] [dependencies] crc32fast = 
"1.3" +i_key_sort = { version = "0.10.1", optional = true, features = ["allow_multithreading"] } [dev-dependencies] -criterion = "0.4" +criterion = "0.7" csv = "1.2" dhat = "0.3" -env_logger = "0.9" +env_logger = "0.11" log = { workspace = true } -rand = "0.8" +rand = "0.9.2" [[bench]] name = "simple" @@ -30,3 +31,4 @@ harness = false [features] heap-prof = [] +v2 = ["i_key_sort"] diff --git a/pingora-ketama/benches/simple.rs b/pingora-ketama/benches/simple.rs index 253cf330..ac93ee4d 100644 --- a/pingora-ketama/benches/simple.rs +++ b/pingora-ketama/benches/simple.rs @@ -1,8 +1,10 @@ use pingora_ketama::{Bucket, Continuum}; use criterion::{criterion_group, criterion_main, Criterion}; -use rand::distributions::Alphanumeric; -use rand::{thread_rng, Rng}; +use rand::{ + distr::{Alphanumeric, SampleString}, + rng, +}; #[cfg(feature = "heap-prof")] #[global_allocator] @@ -19,11 +21,8 @@ fn buckets() -> Vec { } fn random_string() -> String { - thread_rng() - .sample_iter(&Alphanumeric) - .take(30) - .map(char::from) - .collect() + let mut rand = rng(); + Alphanumeric.sample_string(&mut rand, 30) } pub fn criterion_benchmark(c: &mut Criterion) { diff --git a/pingora-ketama/examples/health_aware_selector.rs b/pingora-ketama/examples/health_aware_selector.rs index f749213d..1e44723b 100644 --- a/pingora-ketama/examples/health_aware_selector.rs +++ b/pingora-ketama/examples/health_aware_selector.rs @@ -32,7 +32,7 @@ struct HealthAwareNodeSelector<'a> { } impl HealthAwareNodeSelector<'_> { - fn new(r: Continuum, tries: usize, nhr: &NodeHealthRepository) -> HealthAwareNodeSelector { + fn new(r: Continuum, tries: usize, nhr: &NodeHealthRepository) -> HealthAwareNodeSelector<'_> { HealthAwareNodeSelector { ring: r, max_tries: tries, diff --git a/pingora-ketama/src/lib.rs b/pingora-ketama/src/lib.rs index d18d4fa1..335501b1 100644 --- a/pingora-ketama/src/lib.rs +++ b/pingora-ketama/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. 
+// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -62,11 +62,17 @@ use std::io::Write; use std::net::SocketAddr; use crc32fast::Hasher; +#[cfg(feature = "v2")] +use i_key_sort::sort::one_key_cmp::OneKeyAndCmpSort; + +/// This constant is copied from nginx. It will create 160 points per weight +/// unit. For example, a weight of 2 will create 320 points on the ring. +pub const DEFAULT_POINT_MULTIPLE: u32 = 160; /// A [Bucket] represents a server for consistent hashing /// /// A [Bucket] contains a [SocketAddr] to the server and a weight associated with it. -#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] pub struct Bucket { // The node name. // TODO: UDS @@ -94,28 +100,197 @@ impl Bucket { // A point on the continuum. #[derive(Clone, Debug, Eq, PartialEq)] -struct Point { +struct PointV1 { // the index to the actual address node: u32, hash: u32, } // We only want to compare the hash when sorting, so we implement these traits by hand. -impl Ord for Point { +impl Ord for PointV1 { fn cmp(&self, other: &Self) -> Ordering { self.hash.cmp(&other.hash) } } -impl PartialOrd for Point { +impl PartialOrd for PointV1 { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Point { +impl PointV1 { fn new(node: u32, hash: u32) -> Self { - Point { node, hash } + PointV1 { node, hash } + } +} + +/// A point on the continuum. +/// +/// We are trying to save memory here, so this struct is equivalent to a struct +/// this this definition, but doesn't require using the "untrustworthy" compact +/// repr. This does mean we have to do the memory layout manually though, but +/// the benchmarks show there is no performance hit for it. 
+/// +/// #[repr(Rust, packed)] +/// struct Point { +/// node: u16, +/// hash: u32, +/// } +#[cfg(feature = "v2")] +#[derive(Copy, Clone, Eq, PartialEq)] +#[repr(transparent)] +struct PointV2([u8; 6]); + +#[cfg(feature = "v2")] +impl PointV2 { + fn new(node: u16, hash: u32) -> Self { + let mut this = [0; 6]; + + this[0..4].copy_from_slice(&hash.to_ne_bytes()); + this[4..6].copy_from_slice(&node.to_ne_bytes()); + + Self(this) + } + + /// Return the hash of the point which is stored in the first 4 bytes (big endian). + fn hash(&self) -> u32 { + u32::from_ne_bytes(self.0[0..4].try_into().expect("There are exactly 4 bytes")) + } + + /// Return the node of the point which is stored in the last 2 bytes (big endian). + fn node(&self) -> u16 { + u16::from_ne_bytes(self.0[4..6].try_into().expect("There are exactly 2 bytes")) + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)] +pub enum Version { + #[default] + V1, + #[cfg(feature = "v2")] + V2 { point_multiple: u32 }, +} + +impl Version { + fn point_multiple(&self) -> u32 { + match self { + Version::V1 => DEFAULT_POINT_MULTIPLE, + #[cfg(feature = "v2")] + Version::V2 { point_multiple } => *point_multiple, + } + } +} + +enum RingBuilder { + V1(Vec), + #[cfg(feature = "v2")] + V2(Vec), +} + +impl RingBuilder { + fn new(version: Version, total_weight: u32) -> Self { + match version { + Version::V1 => RingBuilder::V1(Vec::with_capacity( + (total_weight * DEFAULT_POINT_MULTIPLE) as usize, + )), + #[cfg(feature = "v2")] + Version::V2 { point_multiple } => { + RingBuilder::V2(Vec::with_capacity((total_weight * point_multiple) as usize)) + } + } + } + + fn push(&mut self, node: u16, hash: u32) { + match self { + RingBuilder::V1(ring) => { + ring.push(PointV1::new(node as u32, hash)); + } + #[cfg(feature = "v2")] + RingBuilder::V2(ring) => { + ring.push(PointV2::new(node, hash)); + } + } + } + + #[allow(unused)] + fn sort(&mut self, addresses: &[SocketAddr]) { + match self { + RingBuilder::V1(ring) => { + // Sort and 
remove any duplicates. + ring.sort_unstable(); + ring.dedup_by(|a, b| a.hash == b.hash); + } + #[cfg(feature = "v2")] + RingBuilder::V2(ring) => { + ring.sort_by_one_key_then_by( + true, + |p| p.hash(), + |p1, p2| addresses[p1.node() as usize].cmp(&addresses[p2.node() as usize]), + ); + + //secondary_radix_sort(ring, |p| p.hash(), |p| addresses[p.node() as usize]); + ring.dedup_by(|a, b| a.0[0..4] == b.0[0..4]); + } + } + } +} + +impl From for VersionedRing { + fn from(ring: RingBuilder) -> Self { + match ring { + RingBuilder::V1(ring) => VersionedRing::V1(ring.into_boxed_slice()), + #[cfg(feature = "v2")] + RingBuilder::V2(ring) => VersionedRing::V2(ring.into_boxed_slice()), + } + } +} + +enum VersionedRing { + V1(Box<[PointV1]>), + #[cfg(feature = "v2")] + V2(Box<[PointV2]>), +} + +impl VersionedRing { + /// Find the associated index for the given input. + pub fn node_idx(&self, hash: u32) -> usize { + // The `Result` returned here is either a match or the error variant + // returns where the value would be inserted. + let search_result = match self { + VersionedRing::V1(ring) => ring.binary_search_by(|p| p.hash.cmp(&hash)), + #[cfg(feature = "v2")] + VersionedRing::V2(ring) => ring.binary_search_by(|p| p.hash().cmp(&hash)), + }; + + match search_result { + Ok(i) => i, + Err(i) => { + // We wrap around to the front if this value would be + // inserted at the end. + if i == self.len() { + 0 + } else { + i + } + } + } + } + + pub fn get(&self, index: usize) -> Option { + match self { + VersionedRing::V1(ring) => ring.get(index).map(|p| p.node as usize), + #[cfg(feature = "v2")] + VersionedRing::V2(ring) => ring.get(index).map(|p| p.node() as usize), + } + } + + pub fn len(&self) -> usize { + match self { + VersionedRing::V1(ring) => ring.len(), + #[cfg(feature = "v2")] + VersionedRing::V2(ring) => ring.len(), + } } } @@ -124,27 +299,27 @@ impl Point { /// A [Continuum] represents a ring of buckets where a node is associated with various points on /// the ring. 
pub struct Continuum { - ring: Box<[Point]>, + ring: VersionedRing, addrs: Box<[SocketAddr]>, } impl Continuum { - /// Create a new [Continuum] with the given list of buckets. pub fn new(buckets: &[Bucket]) -> Self { - // This constant is copied from nginx. It will create 160 points per weight unit. For - // example, a weight of 2 will create 320 points on the ring. - const POINT_MULTIPLE: u32 = 160; + Self::new_with_version(buckets, Version::default()) + } + /// Create a new [Continuum] with the given list of buckets. + pub fn new_with_version(buckets: &[Bucket], version: Version) -> Self { if buckets.is_empty() { return Continuum { - ring: Box::new([]), + ring: VersionedRing::V1(Box::new([])), addrs: Box::new([]), }; } // The total weight is multiplied by the factor of points to create many points per node. let total_weight: u32 = buckets.iter().fold(0, |sum, b| sum + b.weight); - let mut ring = Vec::with_capacity((total_weight * POINT_MULTIPLE) as usize); + let mut ring = RingBuilder::new(version, total_weight); let mut addrs = Vec::with_capacity(buckets.len()); for bucket in buckets { @@ -165,7 +340,7 @@ impl Continuum { hasher.update(hash_bytes.as_ref()); // A higher weight will add more points for this node. - let num_points = bucket.weight * POINT_MULTIPLE; + let num_points = bucket.weight * version.point_multiple(); // This is appended to the crc32 hash for each point. let mut prev_hash: u32 = 0; @@ -176,52 +351,40 @@ impl Continuum { hasher.update(&prev_hash.to_le_bytes()); let hash = hasher.finalize(); - ring.push(Point::new(node as u32, hash)); + ring.push(node as u16, hash); prev_hash = hash; } } + let addrs = addrs.into_boxed_slice(); + // Sort and remove any duplicates. - ring.sort_unstable(); - ring.dedup_by(|a, b| a.hash == b.hash); + ring.sort(&addrs); Continuum { - ring: ring.into_boxed_slice(), - addrs: addrs.into_boxed_slice(), + ring: ring.into(), + addrs, } } /// Find the associated index for the given input. 
pub fn node_idx(&self, input: &[u8]) -> usize { let hash = crc32fast::hash(input); - - // The `Result` returned here is either a match or the error variant returns where the - // value would be inserted. - match self.ring.binary_search_by(|p| p.hash.cmp(&hash)) { - Ok(i) => i, - Err(i) => { - // We wrap around to the front if this value would be inserted at the end. - if i == self.ring.len() { - 0 - } else { - i - } - } - } + self.ring.node_idx(hash) } /// Hash the given `hash_key` to the server address. pub fn node(&self, hash_key: &[u8]) -> Option { self.ring .get(self.node_idx(hash_key)) // should we unwrap here? - .map(|p| self.addrs[p.node as usize]) + .map(|n| self.addrs[n]) } /// Get an iterator of nodes starting at the original hashed node of the `hash_key`. /// /// This function is useful to find failover servers if the original ones are offline, which is /// cheaper than rebuilding the entire hash ring. - pub fn node_iter(&self, hash_key: &[u8]) -> NodeIterator { + pub fn node_iter(&self, hash_key: &[u8]) -> NodeIterator<'_> { NodeIterator { idx: self.node_idx(hash_key), continuum: self, @@ -234,7 +397,7 @@ impl Continuum { // only update idx for non-empty ring otherwise we will panic on modulo 0 *idx = (*idx + 1) % self.ring.len(); } - point.map(|p| &self.addrs[p.node as usize]) + point.map(|n| &self.addrs[n]) } } diff --git a/pingora-ketama/tests/backwards_compat.rs b/pingora-ketama/tests/backwards_compat.rs new file mode 100644 index 00000000..3224cf42 --- /dev/null +++ b/pingora-ketama/tests/backwards_compat.rs @@ -0,0 +1,101 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +use old_version::{Bucket as OldBucket, Continuum as OldContinuum}; +#[allow(unused_imports)] +use pingora_ketama::{Bucket, Continuum, Version, DEFAULT_POINT_MULTIPLE}; +use rand::{random, random_range, rng, seq::IteratorRandom}; +use std::collections::BTreeSet; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +mod old_version; + +fn random_socket_addr() -> SocketAddr { + if random::() { + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from_bits(random()), random())) + } else { + SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from_bits(random()), + random(), + 0, + 0, + )) + } +} + +fn random_string(len: usize) -> String { + const CHARS: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + let mut rng = rng(); + (0..len) + .map(|_| CHARS.chars().choose(&mut rng).unwrap()) + .collect() +} + +/// The old version of pingora-ketama should _always_ return the same result as +/// v1 of the new version as long as the original input is sorted by by socket +/// address (and has no duplicates). 
this test generates a large number of +/// random socket addresses with varying weights and compares the output of +/// both +#[test] +fn test_v1_to_old_version() { + let (old_buckets, new_buckets): (BTreeSet<_>, BTreeSet<_>) = (0..2000) + .map(|_| (random_socket_addr(), random_range(1..10))) + .map(|(addr, weight)| (OldBucket::new(addr, weight), Bucket::new(addr, weight))) + .unzip(); + + let old_continuum = OldContinuum::new(&Vec::from_iter(old_buckets)); + let new_continuum = Continuum::new(&Vec::from_iter(new_buckets)); + + for _ in 0..20_000 { + let key = random_string(20); + let old_node = old_continuum.node(key.as_bytes()).unwrap(); + let new_node = new_continuum.node(key.as_bytes()).unwrap(); + + assert_eq!(old_node, new_node); + } +} + +/// The new version of pingora-ketama (v2) should return _almost_ exactly what +/// the old version does. The difference will be in collision handling +#[test] +#[cfg(feature = "v2")] +fn test_v2_to_old_version() { + let (old_buckets, new_buckets): (BTreeSet<_>, BTreeSet<_>) = (0..2000) + .map(|_| (random_socket_addr(), random_range(1..10))) + .map(|(addr, weight)| (OldBucket::new(addr, weight), Bucket::new(addr, weight))) + .unzip(); + + let old_continuum = OldContinuum::new(&Vec::from_iter(old_buckets)); + + let new_continuum = Continuum::new_with_version( + &Vec::from_iter(new_buckets), + Version::V2 { + point_multiple: DEFAULT_POINT_MULTIPLE, + }, + ); + + let test_count = 20_000; + let mut mismatches = 0; + + for _ in 0..test_count { + let key = random_string(20); + let old_node = old_continuum.node(key.as_bytes()).unwrap(); + let new_node = new_continuum.node(key.as_bytes()).unwrap(); + + if old_node != new_node { + mismatches += 1; + } + } + + assert!((mismatches as f64 / test_count as f64) < 0.001); +} diff --git a/pingora-ketama/tests/old_version/mod.rs b/pingora-ketama/tests/old_version/mod.rs new file mode 100644 index 00000000..b6f8dc7f --- /dev/null +++ b/pingora-ketama/tests/old_version/mod.rs @@ -0,0 +1,178 
@@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This mod is a direct copy of the old version of pingora-ketama. It is here +//! to ensure that the new version's compatible mode is produces identical +//! results as the old version. + +use std::cmp::Ordering; +use std::io::Write; +use std::net::SocketAddr; + +use crc32fast::Hasher; + +/// A [Bucket] represents a server for consistent hashing +/// +/// A [Bucket] contains a [SocketAddr] to the server and a weight associated with it. +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub struct Bucket { + // The node name. + // TODO: UDS + node: SocketAddr, + + // The weight associated with a node. A higher weight indicates that this node should + // receive more requests. + weight: u32, +} + +impl Bucket { + /// Return a new bucket with the given node and weight. + /// + /// The chance that a [Bucket] is selected is proportional to the relative weight of all [Bucket]s. + /// + /// # Panics + /// + /// This will panic if the weight is zero. + pub fn new(node: SocketAddr, weight: u32) -> Self { + assert!(weight != 0, "weight must be at least one"); + + Bucket { node, weight } + } +} + +// A point on the continuum. +#[derive(Clone, Debug, Eq, PartialEq)] +struct Point { + // the index to the actual address + node: u32, + hash: u32, +} + +// We only want to compare the hash when sorting, so we implement these traits by hand. 
+impl Ord for Point { + fn cmp(&self, other: &Self) -> Ordering { + self.hash.cmp(&other.hash) + } +} + +impl PartialOrd for Point { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Point { + fn new(node: u32, hash: u32) -> Self { + Point { node, hash } + } +} + +/// The consistent hashing ring +/// +/// A [Continuum] represents a ring of buckets where a node is associated with various points on +/// the ring. +pub struct Continuum { + ring: Box<[Point]>, + addrs: Box<[SocketAddr]>, +} + +impl Continuum { + /// Create a new [Continuum] with the given list of buckets. + pub fn new(buckets: &[Bucket]) -> Self { + // This constant is copied from nginx. It will create 160 points per weight unit. For + // example, a weight of 2 will create 320 points on the ring. + const POINT_MULTIPLE: u32 = 160; + + if buckets.is_empty() { + return Continuum { + ring: Box::new([]), + addrs: Box::new([]), + }; + } + + // The total weight is multiplied by the factor of points to create many points per node. + let total_weight: u32 = buckets.iter().fold(0, |sum, b| sum + b.weight); + let mut ring = Vec::with_capacity((total_weight * POINT_MULTIPLE) as usize); + let mut addrs = Vec::with_capacity(buckets.len()); + + for bucket in buckets { + let mut hasher = Hasher::new(); + + // We only do the following for backwards compatibility with nginx/memcache: + // - Convert SocketAddr to string + // - The hash input is as follows "HOST EMPTY PORT PREVIOUS_HASH". Spaces are only added + // for readability. 
+ // TODO: remove this logic and hash the literal SocketAddr once we no longer + // need backwards compatibility + + // with_capacity = max_len(ipv6)(39) + len(null)(1) + max_len(port)(5) + let mut hash_bytes = Vec::with_capacity(39 + 1 + 5); + write!(&mut hash_bytes, "{}", bucket.node.ip()).unwrap(); + write!(&mut hash_bytes, "\0").unwrap(); + write!(&mut hash_bytes, "{}", bucket.node.port()).unwrap(); + hasher.update(hash_bytes.as_ref()); + + // A higher weight will add more points for this node. + let num_points = bucket.weight * POINT_MULTIPLE; + + // This is appended to the crc32 hash for each point. + let mut prev_hash: u32 = 0; + addrs.push(bucket.node); + let node = addrs.len() - 1; + for _ in 0..num_points { + let mut hasher = hasher.clone(); + hasher.update(&prev_hash.to_le_bytes()); + + let hash = hasher.finalize(); + ring.push(Point::new(node as u32, hash)); + prev_hash = hash; + } + } + + // Sort and remove any duplicates. + ring.sort_unstable(); + ring.dedup_by(|a, b| a.hash == b.hash); + + Continuum { + ring: ring.into_boxed_slice(), + addrs: addrs.into_boxed_slice(), + } + } + + /// Find the associated index for the given input. + pub fn node_idx(&self, input: &[u8]) -> usize { + let hash = crc32fast::hash(input); + + // The `Result` returned here is either a match or the error variant returns where the + // value would be inserted. + match self.ring.binary_search_by(|p| p.hash.cmp(&hash)) { + Ok(i) => i, + Err(i) => { + // We wrap around to the front if this value would be inserted at the end. + if i == self.ring.len() { + 0 + } else { + i + } + } + } + } + + /// Hash the given `hash_key` to the server address. + pub fn node(&self, hash_key: &[u8]) -> Option { + self.ring + .get(self.node_idx(hash_key)) // should we unwrap here? 
+ .map(|p| self.addrs[p.node as usize]) + } +} diff --git a/pingora-limits/Cargo.toml b/pingora-limits/Cargo.toml index c019d636..64edfd10 100644 --- a/pingora-limits/Cargo.toml +++ b/pingora-limits/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-limits" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" description = "A library for rate limiting and event frequency estimation" diff --git a/pingora-limits/benches/benchmark.rs b/pingora-limits/benches/benchmark.rs index 699df3dc..4eaa881a 100644 --- a/pingora-limits/benches/benchmark.rs +++ b/pingora-limits/benches/benchmark.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-limits/src/estimator.rs b/pingora-limits/src/estimator.rs index 6f6576d4..bbf91022 100644 --- a/pingora-limits/src/estimator.rs +++ b/pingora-limits/src/estimator.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-limits/src/inflight.rs b/pingora-limits/src/inflight.rs index 9371a12f..c6a25a69 100644 --- a/pingora-limits/src/inflight.rs +++ b/pingora-limits/src/inflight.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-limits/src/lib.rs b/pingora-limits/src/lib.rs index 68492045..c020302b 100644 --- a/pingora-limits/src/lib.rs +++ b/pingora-limits/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-limits/src/rate.rs b/pingora-limits/src/rate.rs index 1f8604f9..bd1268b3 100644 --- a/pingora-limits/src/rate.rs +++ b/pingora-limits/src/rate.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,11 +38,28 @@ pub struct RateComponents { pub current_interval_fraction: f64, } -/// A stable rate estimator that reports the rate of events in the past `interval` time. -/// It returns the average rate between `interval` * 2 and `interval` while collecting the events -/// happening between `interval` and now. +/// A rate calculation function which uses a good estimate of the rate of events over the past +/// `interval` time. /// -/// This estimator ignores events that happen less than once per `interval` time. +/// Specifically, it linearly interpolates between the event counts of the previous and current +/// periods based on how far into the current period we are, as described in this post: +/// +#[allow(dead_code)] +pub static PROPORTIONAL_RATE_ESTIMATE_CALC_FN: fn(RateComponents) -> f64 = + |rate_info: RateComponents| { + let prev = rate_info.prev_samples as f64; + let curr = rate_info.curr_samples as f64; + let interval_secs = rate_info.interval.as_secs_f64(); + let interval_fraction = rate_info.current_interval_fraction; + + let weighted_count = prev * (1. - interval_fraction) + curr; + weighted_count / interval_secs + }; + +/// A stable rate estimator that reports the rate of events per period of `interval` time. +/// +/// It counts events for periods of `interval` and returns the average rate of the latest completed +/// period while counting events for the current (partial) period. 
pub struct Rate { // 2 slots so that we use one to collect the current events and the other to report rate red_slot: Estimator, @@ -104,6 +121,8 @@ impl Rate { } /// Return the per second rate estimation. + /// + /// This is the average rate of the latest completed period of length `interval`. pub fn rate(&self, key: &T) -> f64 { let past_ms = self.maybe_reset(); if past_ms >= self.reset_interval_ms * 2 { @@ -111,7 +130,7 @@ impl Rate { return 0f64; } - self.previous(self.red_or_blue()).get(key) as f64 / self.reset_interval_ms as f64 * 1000.0 + self.previous(self.red_or_blue()).get(key) as f64 * 1000.0 / self.reset_interval_ms as f64 } /// Report new events and return number of events seen so far in the current interval. @@ -277,50 +296,52 @@ mod tests { assert_eq!(r.rate_with(&key, rate_90_10_fn), 0f64); } - // this is the function described in this post - // https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ #[test] fn test_observe_rate_custom_proportional() { let r = Rate::new(Duration::from_secs(1)); let key = 1; - let rate_prop_fn = |rate_info: RateComponents| { - let prev = rate_info.prev_samples as f64; - let curr = rate_info.curr_samples as f64; - let interval_secs = rate_info.interval.as_secs_f64(); - let interval_fraction = rate_info.current_interval_fraction; - - let weighted_count = prev * (1. - interval_fraction) + curr * interval_fraction; - weighted_count / interval_secs - }; - // second: 0 let observed = r.observe(&key, 3); assert_eq!(observed, 3); let observed = r.observe(&key, 2); assert_eq!(observed, 5); - assert_eq_ish(r.rate_with(&key, rate_prop_fn), 0.); + assert_eq_ish(r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), 5.); // second 0.5 sleep(Duration::from_secs_f64(0.5)); - assert_eq_ish(r.rate_with(&key, rate_prop_fn), 5. 
* 0.5); + assert_eq_ish(r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), 5.); + // rate() just looks at the previous interval, ignores current interval + assert_eq_ish(r.rate(&key), 0.); // second: 1 sleep(Duration::from_secs_f64(0.5)); let observed = r.observe(&key, 4); assert_eq!(observed, 4); - assert_eq_ish(r.rate_with(&key, rate_prop_fn), 5.); + assert_eq_ish(r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), 9.); // second 1.75 sleep(Duration::from_secs_f64(0.75)); - assert_eq_ish(r.rate_with(&key, rate_prop_fn), 5. * 0.25 + 4. * 0.75); + assert_eq_ish( + r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), + 5. * 0.25 + 4., + ); // second: 2 sleep(Duration::from_secs_f64(0.25)); - assert_eq_ish(r.rate_with(&key, rate_prop_fn), 4.); + assert_eq_ish(r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), 4.); + assert_eq_ish(r.rate(&key), 4.); + + // second: 2.5 + sleep(Duration::from_secs_f64(0.5)); + assert_eq_ish( + r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), + 4. 
/ 2., + ); + assert_eq_ish(r.rate(&key), 4.); // second: 3 sleep(Duration::from_secs(1)); - assert_eq!(r.rate_with(&key, rate_prop_fn), 0f64); + assert_eq!(r.rate_with(&key, PROPORTIONAL_RATE_ESTIMATE_CALC_FN), 0f64); } } diff --git a/pingora-load-balancing/Cargo.toml b/pingora-load-balancing/Cargo.toml index 882e72ea..d6f5d41e 100644 --- a/pingora-load-balancing/Cargo.toml +++ b/pingora-load-balancing/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-load-balancing" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -18,11 +18,11 @@ path = "src/lib.rs" [dependencies] async-trait = { workspace = true } -pingora-http = { version = "0.6.0", path = "../pingora-http" } -pingora-error = { version = "0.6.0", path = "../pingora-error" } -pingora-core = { version = "0.6.0", path = "../pingora-core", default-features = false } -pingora-ketama = { version = "0.6.0", path = "../pingora-ketama" } -pingora-runtime = { version = "0.6.0", path = "../pingora-runtime" } +pingora-http = { version = "0.8.0", path = "../pingora-http" } +pingora-error = { version = "0.8.0", path = "../pingora-error" } +pingora-core = { version = "0.8.0", path = "../pingora-core", default-features = false } +pingora-ketama = { version = "0.8.0", path = "../pingora-ketama" } +pingora-runtime = { version = "0.8.0", path = "../pingora-runtime" } arc-swap = "1" fnv = "1" rand = "0.8" @@ -42,3 +42,4 @@ rustls = ["pingora-core/rustls", "any_tls"] s2n = ["pingora-core/s2n", "any_tls"] openssl_derived = ["any_tls"] any_tls = [] +v2 = ["pingora-ketama/v2"] diff --git a/pingora-load-balancing/src/background.rs b/pingora-load-balancing/src/background.rs index c99c188e..a34c50af 100644 --- a/pingora-load-balancing/src/background.rs +++ b/pingora-load-balancing/src/background.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,20 +18,24 @@ use std::time::{Duration, Instant}; use super::{BackendIter, BackendSelection, LoadBalancer}; use async_trait::async_trait; -use pingora_core::services::background::BackgroundService; +use pingora_core::services::{background::BackgroundService, ServiceReadyNotifier}; -#[async_trait] -impl BackgroundService for LoadBalancer +impl LoadBalancer where S::Iter: BackendIter, { - async fn start(&self, shutdown: pingora_core::server::ShutdownWatch) -> () { + pub async fn run( + &self, + shutdown: pingora_core::server::ShutdownWatch, + mut ready_opt: Option, + ) -> () { // 136 years const NEVER: Duration = Duration::from_secs(u32::MAX as u64); let mut now = Instant::now(); // run update and health check once let mut next_update = now; let mut next_health_check = now; + loop { if *shutdown.borrow() { return; @@ -43,6 +47,12 @@ where next_update = now + self.update_frequency.unwrap_or(NEVER); } + // After the first update, discovery and selection setup will be + // done, so we will notify dependents + if let Some(ready) = ready_opt.take() { + ServiceReadyNotifier::notify_ready(ready) + } + if next_health_check <= now { self.backends .run_health_check(self.parallel_health_check) @@ -59,3 +69,24 @@ where } } } + +/// Implement [BackgroundService] for [LoadBalancer]. For backward-compatibility +/// reasons, we implement both the `start` and `start_with_ready_notifier` +/// methods. 
+#[async_trait] +impl BackgroundService for LoadBalancer +where + S::Iter: BackendIter, +{ + async fn start_with_ready_notifier( + &self, + shutdown: pingora_core::server::ShutdownWatch, + ready: ServiceReadyNotifier, + ) -> () { + self.run(shutdown, Some(ready)).await + } + + async fn start(&self, shutdown: pingora_core::server::ShutdownWatch) -> () { + self.run(shutdown, None).await + } +} diff --git a/pingora-load-balancing/src/discovery.rs b/pingora-load-balancing/src/discovery.rs index 2896ec36..afeba278 100644 --- a/pingora-load-balancing/src/discovery.rs +++ b/pingora-load-balancing/src/discovery.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-load-balancing/src/health_check.rs b/pingora-load-balancing/src/health_check.rs index 268af5dd..5e97fb36 100644 --- a/pingora-load-balancing/src/health_check.rs +++ b/pingora-load-balancing/src/health_check.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,7 +17,10 @@ use crate::Backend; use arc_swap::ArcSwap; use async_trait::async_trait; +use pingora_core::connectors::http::custom; use pingora_core::connectors::{http::Connector as HttpConnector, TransportConnector}; +use pingora_core::custom_session; +use pingora_core::protocols::http::custom::client::Session; use pingora_core::upstreams::peer::{BasicPeer, HttpPeer, Peer}; use pingora_error::{Error, ErrorType::CustomCode, Result}; use pingora_http::{RequestHeader, ResponseHeader}; @@ -148,7 +151,10 @@ type Validator = Box Result<()> + Send + Sync>; /// HTTP health check /// /// This health check checks if it can receive the expected HTTP(s) response from the given backend. 
-pub struct HttpHealthCheck { +pub struct HttpHealthCheck +where + C: custom::Connector, +{ /// Number of successful checks to flip from unhealthy to healthy. pub consecutive_success: usize, /// Number of failed checks to flip from healthy to unhealthy. @@ -170,7 +176,7 @@ pub struct HttpHealthCheck { pub reuse_connection: bool, /// The request header to send to the backend pub req: RequestHeader, - connector: HttpConnector, + connector: HttpConnector, /// Optional field to define how to validate the response from the server. /// /// If not set, any response with a `200 OK` is considered a successful check. @@ -184,7 +190,7 @@ pub struct HttpHealthCheck { pub backend_summary_callback: Option, } -impl HttpHealthCheck { +impl HttpHealthCheck<()> { /// Create a new [HttpHealthCheck] with the following default settings /// * connect timeout: 1 second /// * read timeout: 1 second @@ -213,9 +219,43 @@ impl HttpHealthCheck { backend_summary_callback: None, } } +} + +impl HttpHealthCheck +where + C: custom::Connector, +{ + /// Create a new [HttpHealthCheck] with the following default settings + /// * connect timeout: 1 second + /// * read timeout: 1 second + /// * req: a GET to the `/` of the given host name + /// * consecutive_success: 1 + /// * consecutive_failure: 1 + /// * reuse_connection: false + /// * validator: `None`, any 200 response is considered successful + pub fn new_custom(host: &str, tls: bool, custom: HttpConnector) -> Self { + let mut req = RequestHeader::build("GET", b"/", None).unwrap(); + req.append_header("Host", host).unwrap(); + let sni = if tls { host.into() } else { String::new() }; + let mut peer_template = HttpPeer::new("0.0.0.0:1", tls, sni); + peer_template.options.connection_timeout = Some(Duration::from_secs(1)); + peer_template.options.read_timeout = Some(Duration::from_secs(1)); + HttpHealthCheck { + consecutive_success: 1, + consecutive_failure: 1, + peer_template, + connector: custom, + reuse_connection: false, + req, + validator: None, 
+                port_override: None,
+                health_changed_callback: None,
+                backend_summary_callback: None,
+            }
+        }

     /// Replace the internal http connector with the given [HttpConnector]
-    pub fn set_connector(&mut self, connector: HttpConnector) {
+    pub fn set_connector(&mut self, connector: HttpConnector<C>) {
         self.connector = connector;
     }

@@ -228,7 +268,10 @@ impl HttpHealthCheck {
     }

 #[async_trait]
-impl HealthCheck for HttpHealthCheck {
+impl<C> HealthCheck for HttpHealthCheck<C>
+where
+    C: custom::Connector,
+{
     fn health_threshold(&self, success: bool) -> usize {
         if success {
             self.consecutive_success
@@ -250,6 +293,8 @@ impl HealthCheck for HttpHealthCheck {
         session.write_request_header(req).await?;
         session.finish_request_body().await?;

+        custom_session!(session.finish_custom().await?);
+
         if let Some(read_timeout) = peer.options.read_timeout {
             session.set_read_timeout(Some(read_timeout));
         }
@@ -271,6 +316,9 @@ impl HealthCheck for HttpHealthCheck {
             // drain the body if any
         }

+        // TODO(slava): do it concurrently with body drain?
+        custom_session!(session.drain_custom_messages().await?);
+
         if self.reuse_connection {
             let idle_timeout = peer.idle_timeout();
             self.connector
diff --git a/pingora-load-balancing/src/lib.rs b/pingora-load-balancing/src/lib.rs
index 33fadd00..0e1bc6e5 100644
--- a/pingora-load-balancing/src/lib.rs
+++ b/pingora-load-balancing/src/lib.rs
@@ -1,4 +1,4 @@
-// Copyright 2025 Cloudflare, Inc.
+// Copyright 2026 Cloudflare, Inc.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -308,9 +308,15 @@ impl Backends {
 ///
 /// In order to run service discovery and health check at the designated frequencies, the [LoadBalancer]
 /// needs to be run as a [pingora_core::services::background::BackgroundService].
-pub struct LoadBalancer { +pub struct LoadBalancer +where + S: BackendSelection, +{ backends: Backends, selector: ArcSwap, + + config: Option, + /// How frequent the health check logic (if set) should run. /// /// If `None`, the health check logic will only run once at the beginning. @@ -323,7 +329,7 @@ pub struct LoadBalancer { pub parallel_health_check: bool, } -impl LoadBalancer +impl LoadBalancer where S: BackendSelection + 'static, S::Iter: BackendIter, @@ -346,25 +352,46 @@ where Ok(lb) } - /// Build a [LoadBalancer] with the given [Backends]. - pub fn from_backends(backends: Backends) -> Self { - let selector = ArcSwap::new(Arc::new(S::build(&backends.get_backend()))); + /// Build a [LoadBalancer] with the given [Backends] and the config. + pub fn from_backends_with_config(backends: Backends, config_opt: Option) -> Self { + let selector_raw = if let Some(config) = config_opt.as_ref() { + S::build_with_config(&backends.get_backend(), config) + } else { + S::build(&backends.get_backend()) + }; + + let selector = ArcSwap::new(Arc::new(selector_raw)); + LoadBalancer { backends, selector, + config: config_opt, health_check_frequency: None, update_frequency: None, parallel_health_check: false, } } + /// Build a [LoadBalancer] with the given [Backends]. + pub fn from_backends(backends: Backends) -> Self { + Self::from_backends_with_config(backends, None) + } + /// Run the service discovery and update the selection algorithm. /// /// This function will be called every `update_frequency` if this [LoadBalancer] instance /// is running as a background service. 
pub async fn update(&self) -> Result<()> { self.backends - .update(|backends| self.selector.store(Arc::new(S::build(&backends)))) + .update(|backends| { + let selector = if let Some(config) = &self.config { + S::build_with_config(&backends, config) + } else { + S::build(&backends) + }; + + self.selector.store(Arc::new(selector)) + }) .await } diff --git a/pingora-load-balancing/src/selection/algorithms.rs b/pingora-load-balancing/src/selection/algorithms.rs index 4dba4115..cd296c45 100644 --- a/pingora-load-balancing/src/selection/algorithms.rs +++ b/pingora-load-balancing/src/selection/algorithms.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-load-balancing/src/selection/consistent.rs b/pingora-load-balancing/src/selection/consistent.rs index 8a279632..fe1fe0cb 100644 --- a/pingora-load-balancing/src/selection/consistent.rs +++ b/pingora-load-balancing/src/selection/consistent.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ use super::*; use pingora_core::protocols::l4::socket::SocketAddr; -use pingora_ketama::{Bucket, Continuum}; +use pingora_ketama::{Bucket, Continuum, Version}; use std::collections::HashMap; /// Weighted Ketama consistent hashing @@ -26,10 +26,19 @@ pub struct KetamaHashing { backends: HashMap, } +#[derive(Clone, Debug, Copy, Default)] +pub struct KetamaConfig { + pub point_multiple: Option, +} + impl BackendSelection for KetamaHashing { type Iter = OwnedNodeIterator; - fn build(backends: &BTreeSet) -> Self { + type Config = KetamaConfig; + + fn build_with_config(backends: &BTreeSet, config: &Self::Config) -> Self { + let KetamaConfig { point_multiple } = *config; + let buckets: Vec<_> = backends .iter() .filter_map(|b| { @@ -45,12 +54,29 @@ impl BackendSelection for KetamaHashing { .iter() .map(|b| (b.addr.clone(), b.clone())) .collect(); + + #[allow(unused)] + let version = if let Some(point_multiple) = point_multiple { + match () { + #[cfg(feature = "v2")] + () => Version::V2 { point_multiple }, + #[cfg(not(feature = "v2"))] + () => Version::V1, + } + } else { + Version::V1 + }; + KetamaHashing { - ring: Continuum::new(&buckets), + ring: Continuum::new_with_version(&buckets, version), backends: new_backends, } } + fn build(backends: &BTreeSet) -> Self { + Self::build_with_config(backends, &KetamaConfig::default()) + } + fn iter(self: &Arc, key: &[u8]) -> Self::Iter { OwnedNodeIterator { idx: self.ring.node_idx(key), diff --git a/pingora-load-balancing/src/selection/mod.rs b/pingora-load-balancing/src/selection/mod.rs index d3300fdb..3e9d85ed 100644 --- a/pingora-load-balancing/src/selection/mod.rs +++ b/pingora-load-balancing/src/selection/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -24,9 +24,19 @@ use std::sync::Arc; use weighted::Weighted; /// [BackendSelection] is the interface to implement backend selection mechanisms. -pub trait BackendSelection { +pub trait BackendSelection: Sized { /// The [BackendIter] returned from iter() below. type Iter; + + /// The configuration type constructing [BackendSelection] + type Config: Send + Sync; + + /// Create a [BackendSelection] from a set of backends and the given configuration. The + /// default implementation ignores the configuration and simply calls [Self::build] + fn build_with_config(backends: &BTreeSet, _config: &Self::Config) -> Self { + Self::build(backends) + } + /// The function to create a [BackendSelection] implementation. fn build(backends: &BTreeSet) -> Self; /// Select backends for a given key. diff --git a/pingora-load-balancing/src/selection/weighted.rs b/pingora-load-balancing/src/selection/weighted.rs index 9799c378..d12c51f6 100644 --- a/pingora-load-balancing/src/selection/weighted.rs +++ b/pingora-load-balancing/src/selection/weighted.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -32,6 +32,8 @@ pub struct Weighted { impl BackendSelection for Weighted { type Iter = WeightedIterator; + type Config = (); + fn build(backends: &BTreeSet) -> Self { assert!( backends.len() <= u16::MAX as usize, diff --git a/pingora-lru/Cargo.toml b/pingora-lru/Cargo.toml index ffa88bc3..3eae82b9 100644 --- a/pingora-lru/Cargo.toml +++ b/pingora-lru/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-lru" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" diff --git a/pingora-lru/benches/bench_linked_list.rs b/pingora-lru/benches/bench_linked_list.rs index b8a0413f..5fc0e50a 100644 --- a/pingora-lru/benches/bench_linked_list.rs +++ b/pingora-lru/benches/bench_linked_list.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-lru/benches/bench_lru.rs b/pingora-lru/benches/bench_lru.rs index 53acc2e9..c0bdc776 100644 --- a/pingora-lru/benches/bench_lru.rs +++ b/pingora-lru/benches/bench_lru.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-lru/src/lib.rs b/pingora-lru/src/lib.rs index 74ec1ac4..23728c4f 100644 --- a/pingora-lru/src/lib.rs +++ b/pingora-lru/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -97,17 +97,26 @@ impl Lru { shard } - /// Increment the weight associated with a given key. + /// Increment the weight associated with a given key, up to an optional max weight. 
+ /// If a `max_weight` is provided, the weight cannot exceed this max weight. If the current + /// weight is higher than the max, it will be capped to the max. /// /// Return the total new weight. 0 indicates the key did not exist. - pub fn increment_weight(&self, key: u64, delta: usize) -> usize { + pub fn increment_weight(&self, key: u64, delta: usize, max_weight: Option) -> usize { let shard = get_shard(key, N); let unit = &mut self.units[shard].write(); - let new_weight = unit.increment_weight(key, delta); - if new_weight > 0 { - self.weight.fetch_add(delta, Ordering::Relaxed); + if let Some((old_weight, new_weight)) = unit.increment_weight(key, delta, max_weight) { + if new_weight >= old_weight { + self.weight + .fetch_add(new_weight - old_weight, Ordering::Relaxed); + } else { + self.weight + .fetch_sub(old_weight - new_weight, Ordering::Relaxed); + } + new_weight + } else { + 0 } - new_weight } /// Promote the key to the head of the LRU @@ -181,7 +190,7 @@ impl Lru { evicted } - /// Remove the given asset + /// Remove the given asset. pub fn remove(&self, key: u64) -> Option<(T, usize)> { let removed = self.units[get_shard(key, N)].write().remove(key); if let Some((_, weight)) = removed.as_ref() { @@ -191,7 +200,7 @@ impl Lru { removed } - /// Insert the item to the tail of this LRU + /// Insert the item to the tail of this LRU. /// /// Useful to recreate an LRU in most-to-least order pub fn insert_tail(&self, key: u64, data: T, weight: usize) -> bool { @@ -207,12 +216,17 @@ impl Lru { } } - /// Check existence of a key without changing the order in LRU + /// Check existence of a key without changing the order in LRU. pub fn peek(&self, key: u64) -> bool { self.units[get_shard(key, N)].read().peek(key).is_some() } - /// Return the current total weight + /// Check the weight of a key without changing the order in LRU. 
+    pub fn peek_weight(&self, key: u64) -> Option<usize> {
+        self.units[get_shard(key, N)].read().peek_weight(key)
+    }
+
+    /// Return the current total weight.
     pub fn weight(&self) -> usize {
         self.weight.load(Ordering::Relaxed)
     }
@@ -251,6 +265,11 @@ impl Lru {
     pub fn shard_len(&self, shard: usize) -> usize {
         self.units[shard].read().len()
     }
+
+    /// Get the weight (total size) inside a shard
+    pub fn shard_weight(&self, shard: usize) -> usize {
+        self.units[shard].read().used_weight
+    }
 }

 #[inline]
@@ -279,19 +298,20 @@ impl LruUnit {
         }
     }

+    /// Peek data associated with key, if it exists.
     pub fn peek(&self, key: u64) -> Option<&T> {
         self.lookup_table.get(&key).map(|n| &n.data)
     }

-    // admin into LRU, return old weight if there was any
+    /// Peek weight associated with key, if it exists.
+    pub fn peek_weight(&self, key: u64) -> Option<usize> {
+        self.lookup_table.get(&key).map(|n| n.weight)
+    }
+
+    /// Admit into LRU, return old weight if there was any.
     pub fn admit(&mut self, key: u64, data: T, weight: usize) -> usize {
         if let Some(node) = self.lookup_table.get_mut(&key) {
-            let old_weight = node.weight;
-            if weight != old_weight {
-                self.used_weight += weight;
-                self.used_weight -= old_weight;
-                node.weight = weight;
-            }
+            let old_weight = Self::adjust_weight(node, &mut self.used_weight, weight);
             node.data = data;
             self.order.promote(node.list_index);
             return old_weight;
@@ -307,15 +327,25 @@
         0
     }

-    /// Increase the weight of an existing key. Returns the new weight or the key.
-    pub fn increment_weight(&mut self, key: u64, delta: usize) -> usize {
+    /// Increase the weight of an existing key. Returns `Some((old_weight, new_weight))`, or
+    /// `None` if the key did not exist.
+    ///
+    /// If a `max_weight` is provided, the weight cannot exceed this max weight. If the current
+    /// weight is higher than the max, it will be capped to the max.
+ pub fn increment_weight( + &mut self, + key: u64, + delta: usize, + max_weight: Option, + ) -> Option<(usize, usize)> { if let Some(node) = self.lookup_table.get_mut(&key) { - node.weight += delta; - self.used_weight += delta; + let new_weight = + max_weight.map_or(node.weight + delta, |m| (node.weight + delta).min(m)); + let old_weight = Self::adjust_weight(node, &mut self.used_weight, new_weight); self.order.promote(node.list_index); - return node.weight; + return Some((old_weight, new_weight)); } - 0 + None } pub fn access(&mut self, key: u64) -> bool { @@ -386,6 +416,19 @@ impl LruUnit { iter: self.order.iter(), } } + + // Adjusts node weight to the new given weight. + // Returns old weight. + #[inline] + fn adjust_weight(node: &mut LruNode, used_weight: &mut usize, weight: usize) -> usize { + let old_weight = node.weight; + if weight != old_weight { + *used_weight += weight; + *used_weight -= old_weight; + node.weight = weight; + } + old_weight + } } struct LruUnitIter<'a, T> { @@ -552,15 +595,18 @@ mod test_lru { fn test_increment_weight() { let lru = Lru::<_, 2>::with_capacity(6, 10); lru.admit(1, 1, 1); - lru.increment_weight(1, 1); + lru.increment_weight(1, 1, None); assert_eq!(lru.weight(), 1 + 1); - lru.increment_weight(0, 1000); + lru.increment_weight(0, 1000, None); assert_eq!(lru.weight(), 1 + 1); lru.admit(2, 2, 2); - lru.increment_weight(2, 2); + lru.increment_weight(2, 2, None); assert_eq!(lru.weight(), 1 + 1 + 2 + 2); + + lru.increment_weight(2, 2, Some(3)); + assert_eq!(lru.weight(), 1 + 1 + 3); } #[test] @@ -742,15 +788,22 @@ mod test_lru_unit { fn test_increment_weight() { let mut lru = LruUnit::with_capacity(10); lru.admit(1, 1, 1); - lru.increment_weight(1, 1); + lru.increment_weight(1, 1, None); assert_eq!(lru.used_weight(), 1 + 1); - lru.increment_weight(0, 1000); + lru.increment_weight(0, 1000, None); assert_eq!(lru.used_weight(), 1 + 1); lru.admit(2, 2, 2); - lru.increment_weight(2, 2); + lru.increment_weight(2, 2, None); 
assert_eq!(lru.used_weight(), 1 + 1 + 2 + 2); + + lru.admit(3, 3, 3); + lru.increment_weight(3, 3, Some(5)); + assert_eq!(lru.used_weight(), 1 + 1 + 2 + 2 + 3 + 2); + + lru.increment_weight(3, 3, Some(3)); + assert_eq!(lru.used_weight(), 1 + 1 + 2 + 2 + 3); } #[test] diff --git a/pingora-lru/src/linked_list.rs b/pingora-lru/src/linked_list.rs index 7a9d37cc..ceb9a861 100644 --- a/pingora-lru/src/linked_list.rs +++ b/pingora-lru/src/linked_list.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-memory-cache/Cargo.toml b/pingora-memory-cache/Cargo.toml index bb449610..843194b2 100644 --- a/pingora-memory-cache/Cargo.toml +++ b/pingora-memory-cache/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-memory-cache" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -17,14 +17,14 @@ name = "pingora_memory_cache" path = "src/lib.rs" [dependencies] -TinyUFO = { version = "0.6.0", path = "../tinyufo" } +TinyUFO = { version = "0.8.0", path = "../tinyufo" } ahash = { workspace = true } tokio = { workspace = true, features = ["sync"] } async-trait = { workspace = true } -pingora-error = { version = "0.6.0", path = "../pingora-error" } +pingora-error = { version = "0.8.0", path = "../pingora-error" } log = { workspace = true } parking_lot = "0" -pingora-timeout = { version = "0.6.0", path = "../pingora-timeout" } +pingora-timeout = { version = "0.8.0", path = "../pingora-timeout" } [dev-dependencies] once_cell = { workspace = true } diff --git a/pingora-memory-cache/src/lib.rs b/pingora-memory-cache/src/lib.rs index b30a2d2b..84389d0d 100644 --- a/pingora-memory-cache/src/lib.rs +++ b/pingora-memory-cache/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-memory-cache/src/read_through.rs b/pingora-memory-cache/src/read_through.rs index 140f2362..96e4348e 100644 --- a/pingora-memory-cache/src/read_through.rs +++ b/pingora-memory-cache/src/read_through.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -200,10 +200,9 @@ where } }; - if my_read.is_some() { + if let Some(my_lock) = my_read { /* another task will do the lookup */ - let my_lock = my_read.unwrap(); /* if available_permits > 0, writer is done */ if my_lock.lock.available_permits() == 0 { /* block here to wait for writer to finish lookup */ @@ -268,10 +267,10 @@ where (Err(err), cache_state) } }; - if my_write.is_some() { + if let Some(my_write) = my_write { /* add permit so that reader can start. Any number of permits will do, * since readers will return permits right away. */ - my_write.unwrap().lock.add_permits(10); + my_write.lock.add_permits(10); { // remove the lock from locker diff --git a/pingora-openssl/Cargo.toml b/pingora-openssl/Cargo.toml index 6c472ef7..0cc322d1 100644 --- a/pingora-openssl/Cargo.toml +++ b/pingora-openssl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-openssl" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" diff --git a/pingora-openssl/src/ext.rs b/pingora-openssl/src/ext.rs index 25234b95..18e0fdfe 100644 --- a/pingora-openssl/src/ext.rs +++ b/pingora-openssl/src/ext.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-openssl/src/lib.rs b/pingora-openssl/src/lib.rs index 455be746..6fd2f912 100644 --- a/pingora-openssl/src/lib.rs +++ b/pingora-openssl/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-pool/Cargo.toml b/pingora-pool/Cargo.toml index 95b1344f..5d841a4c 100644 --- a/pingora-pool/Cargo.toml +++ b/pingora-pool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-pool" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -23,7 +23,7 @@ lru = { workspace = true } log = { workspace = true } parking_lot = "0.12" crossbeam-queue = "0.3" -pingora-timeout = { version = "0.6.0", path = "../pingora-timeout" } +pingora-timeout = { version = "0.8.0", path = "../pingora-timeout" } [dev-dependencies] tokio-test = "0.4" diff --git a/pingora-pool/src/connection.rs b/pingora-pool/src/connection.rs index 63f23c46..a30c08ee 100644 --- a/pingora-pool/src/connection.rs +++ b/pingora-pool/src/connection.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -316,7 +316,7 @@ impl ConnectionPool { pub async fn idle_timeout( &self, meta: &ConnectionMeta, - timeout: Duration, + timeout: Option, notify_evicted: Arc, mut notify_closed: watch::Receiver, watch_use: oneshot::Receiver, @@ -335,7 +335,8 @@ impl ConnectionPool { debug!("idle connection is being closed"); self.pop_closed(meta); } - _ = sleep(timeout) => { + // async expression is evaluated if timeout is None but it's never polled, set it to MAX + _ = sleep(timeout.unwrap_or(Duration::MAX)), if timeout.is_some() => { debug!("idle connection is being evicted"); self.pop_closed(meta); } diff --git a/pingora-pool/src/lib.rs b/pingora-pool/src/lib.rs index b3e88692..d16d57b8 100644 --- a/pingora-pool/src/lib.rs +++ b/pingora-pool/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-pool/src/lru.rs b/pingora-pool/src/lru.rs index a7529029..c6a72d8a 100644 --- a/pingora-pool/src/lru.rs +++ b/pingora-pool/src/lru.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index aeaa15eb..c685b8c4 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -1,9 +1,10 @@ [package] name = "pingora-proxy" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" +rust-version = "1.84" repository = "https://github.com/cloudflare/pingora" categories = ["asynchronous", "network-programming"] keywords = ["async", "http", "proxy", "pingora"] @@ -18,11 +19,11 @@ name = "pingora_proxy" path = "src/lib.rs" [dependencies] -pingora-error = { version = "0.6.0", path = "../pingora-error" } -pingora-core = { version = "0.6.0", path = "../pingora-core", default-features = false } -pingora-cache = { version = "0.6.0", path = "../pingora-cache", default-features = false } +pingora-error = { version = "0.8.0", path = "../pingora-error" } +pingora-core = { version = "0.8.0", path = "../pingora-core", default-features = false } +pingora-cache = { version = "0.8.0", path = "../pingora-cache", default-features = false } tokio = { workspace = true, features = ["macros", "net"] } -pingora-http = { version = "0.6.0", path = "../pingora-http" } +pingora-http = { version = "0.8.0", path = "../pingora-http" } http = { workspace = true } futures = "0.3" bytes = { workspace = true } @@ -30,7 +31,7 @@ async-trait = { workspace = true } log = { workspace = true } h2 = { workspace = true } once_cell = { workspace = true } -clap = { version = "3.2.25", features = ["derive"] } +clap = { version = "4", features = ["derive"] } regex = "1" rand = "0.8" @@ -39,17 +40,18 @@ reqwest = { version = "0.11", features = [ "gzip", "rustls-tls", ], default-features = false } +httparse = { workspace = true } tokio-test = "0.4" -env_logger = "0.9" +env_logger = "0.11" hyper = "0.14" tokio-tungstenite = "0.20.1" -pingora-limits = { version = "0.6.0", path = "../pingora-limits" } -pingora-load-balancing = { version = "0.6.0", path = "../pingora-load-balancing", 
default-features=false } +pingora-limits = { version = "0.8.0", path = "../pingora-limits" } +pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing", default-features=false } prometheus = "0" futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -serde_yaml = "0.8" +serde_yaml = "0.9" [target.'cfg(unix)'.dev-dependencies] hyperlocal = "0.8" @@ -57,12 +59,21 @@ hyperlocal = "0.8" [features] default = [] openssl = ["pingora-core/openssl", "pingora-cache/openssl", "openssl_derived"] -boringssl = ["pingora-core/boringssl", "pingora-cache/boringssl", "openssl_derived"] +boringssl = [ + "pingora-core/boringssl", + "pingora-cache/boringssl", + "openssl_derived", +] rustls = ["pingora-core/rustls", "pingora-cache/rustls", "any_tls"] s2n = ["pingora-core/s2n", "pingora-cache/s2n", "any_tls"] openssl_derived = ["any_tls"] any_tls = [] sentry = ["pingora-core/sentry"] +connection_filter = ["pingora-core/connection_filter"] + +[[example]] +name = "connection_filter" +required-features = ["connection_filter"] # or locally cargo doc --config "build.rustdocflags='--cfg doc_async_trait'" [package.metadata.docs.rs] diff --git a/pingora-proxy/examples/backoff_retry.rs b/pingora-proxy/examples/backoff_retry.rs index 717a41b4..0604b6ec 100644 --- a/pingora-proxy/examples/backoff_retry.rs +++ b/pingora-proxy/examples/backoff_retry.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -15,7 +15,6 @@ use std::time::Duration; use async_trait::async_trait; -use clap::Parser; use log::info; use pingora_core::server::Server; @@ -79,7 +78,7 @@ fn main() { env_logger::init(); // read command line arguments - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/connection_filter.rs b/pingora-proxy/examples/connection_filter.rs new file mode 100644 index 00000000..1c346c6f --- /dev/null +++ b/pingora-proxy/examples/connection_filter.rs @@ -0,0 +1,96 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use async_trait::async_trait; +use clap::Parser; +use log::info; +use pingora_core::listeners::ConnectionFilter; +use pingora_core::prelude::Opt; +use pingora_core::server::Server; +use pingora_core::upstreams::peer::HttpPeer; +use pingora_core::Result; +use pingora_proxy::{ProxyHttp, Session}; +use std::sync::Arc; + +/// This example demonstrates how to implement a connection filter +pub struct MyProxy; + +#[async_trait] +impl ProxyHttp for MyProxy { + type CTX = (); + + fn new_ctx(&self) -> Self::CTX {} + + async fn upstream_peer( + &self, + _session: &mut Session, + _ctx: &mut Self::CTX, + ) -> Result> { + // Forward to httpbin.org for testing + let peer = HttpPeer::new(("httpbin.org", 80), false, "httpbin.org".into()); + Ok(Box::new(peer)) + } +} + +/// Connection filter that blocks ALL connections (for testing) +#[derive(Debug, Clone)] +struct BlockAllFilter; + +#[async_trait] +impl ConnectionFilter for BlockAllFilter { + async fn should_accept(&self, addr: &std::net::SocketAddr) -> bool { + info!("BLOCKING connection from {} (BlockAllFilter active)", addr); + false + } +} + +// RUST_LOG=INFO cargo run --example connection_filter + +fn main() { + env_logger::init(); + + // read command line arguments + let opt = Opt::parse(); + let mut my_server = Server::new(Some(opt)).unwrap(); + my_server.bootstrap(); + + let mut my_proxy = pingora_proxy::http_proxy_service(&my_server.configuration, MyProxy); + + // Create a filter that blocks ALL connections + let filter = Arc::new(BlockAllFilter); + + info!("Setting BlockAllFilter on proxy service"); + my_proxy.set_connection_filter(filter.clone()); + + info!("Adding TCP endpoints AFTER setting filter"); + my_proxy.add_tcp("0.0.0.0:6195"); + my_proxy.add_tcp("0.0.0.0:6196"); + + info!("===================================="); + info!("Server starting with BlockAllFilter"); + info!("This filter blocks ALL connections!"); + info!("===================================="); + info!(""); + info!("Test with:"); + info!(" curl 
http://localhost:6195/get"); + info!(" curl http://localhost:6196/get"); + info!(""); + info!("ALL requests should be blocked!"); + info!("You should see 'BLOCKING connection' in the logs"); + info!("and curl should fail with 'Connection refused' or hang"); + info!(""); + + my_server.add_service(my_proxy); + my_server.run_forever(); +} diff --git a/pingora-proxy/examples/ctx.rs b/pingora-proxy/examples/ctx.rs index 3927f86b..bb281a55 100644 --- a/pingora-proxy/examples/ctx.rs +++ b/pingora-proxy/examples/ctx.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. use async_trait::async_trait; -use clap::Parser; use log::info; use std::sync::Mutex; @@ -82,7 +81,7 @@ fn main() { env_logger::init(); // read command line arguments - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/gateway.rs b/pingora-proxy/examples/gateway.rs index 5c6723f6..83b7c1ca 100644 --- a/pingora-proxy/examples/gateway.rs +++ b/pingora-proxy/examples/gateway.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -14,7 +14,6 @@ use async_trait::async_trait; use bytes::Bytes; -use clap::Parser; use log::info; use prometheus::register_int_counter; @@ -117,7 +116,7 @@ fn main() { env_logger::init(); // read command line arguments - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/grpc_web_module.rs b/pingora-proxy/examples/grpc_web_module.rs index 43385ec1..085adb92 100644 --- a/pingora-proxy/examples/grpc_web_module.rs +++ b/pingora-proxy/examples/grpc_web_module.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. use async_trait::async_trait; -use clap::Parser; use pingora_core::server::Server; use pingora_core::upstreams::peer::HttpPeer; @@ -77,7 +76,7 @@ fn main() { env_logger::init(); // read command line arguments - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/load_balancer.rs b/pingora-proxy/examples/load_balancer.rs index 0b04c61f..b1375633 100644 --- a/pingora-proxy/examples/load_balancer.rs +++ b/pingora-proxy/examples/load_balancer.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. 
use async_trait::async_trait; -use clap::Parser; use log::info; use pingora_core::services::background::background_service; use std::{sync::Arc, time::Duration}; @@ -62,7 +61,7 @@ fn main() { env_logger::init(); // read command line arguments - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/modify_response.rs b/pingora-proxy/examples/modify_response.rs index 4a7b480a..ea10f03f 100644 --- a/pingora-proxy/examples/modify_response.rs +++ b/pingora-proxy/examples/modify_response.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ use async_trait::async_trait; use bytes::Bytes; -use clap::Parser; use serde::{Deserialize, Serialize}; use std::net::ToSocketAddrs; @@ -117,7 +116,7 @@ impl ProxyHttp for Json2Yaml { fn main() { env_logger::init(); - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/multi_lb.rs b/pingora-proxy/examples/multi_lb.rs index a0b629c8..c8c76753 100644 --- a/pingora-proxy/examples/multi_lb.rs +++ b/pingora-proxy/examples/multi_lb.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-proxy/examples/use_module.rs b/pingora-proxy/examples/use_module.rs index d59e741e..26c10ca6 100644 --- a/pingora-proxy/examples/use_module.rs +++ b/pingora-proxy/examples/use_module.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. use async_trait::async_trait; -use clap::Parser; use pingora_core::modules::http::HttpModules; use pingora_core::server::configuration::Opt; @@ -115,7 +114,7 @@ fn main() { env_logger::init(); // read command line arguments - let opt = Opt::parse(); + let opt = Opt::parse_args(); let mut my_server = Server::new(Some(opt)).unwrap(); my_server.bootstrap(); diff --git a/pingora-proxy/examples/virtual_l4.rs b/pingora-proxy/examples/virtual_l4.rs new file mode 100644 index 00000000..ecef1814 --- /dev/null +++ b/pingora-proxy/examples/virtual_l4.rs @@ -0,0 +1,169 @@ +//! This example demonstrates to how to implement a custom L4 connector +//! together with a virtual socket. + +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; + +use async_trait::async_trait; +use pingora_core::connectors::L4Connect; +use pingora_core::prelude::HttpPeer; +use pingora_core::protocols::l4::socket::SocketAddr as L4SocketAddr; +use pingora_core::protocols::l4::stream::Stream; +use pingora_core::protocols::l4::virt::{VirtualSocket, VirtualSocketStream}; +use pingora_core::server::RunArgs; +use pingora_core::server::{configuration::ServerConf, Server}; +use pingora_core::services::listening::Service; +use pingora_core::upstreams::peer::PeerOptions; +use pingora_error::Result; +use pingora_proxy::{http_proxy_service_with_name, prelude::*, HttpProxy, ProxyHttp}; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Static virtual socket that serves a single HTTP request with a static response. +/// +/// In real world use cases you would implement [`VirtualSocket`] for streams +/// that implement `AsyncRead + AsyncWrite`. 
+#[derive(Debug)] +struct StaticVirtualSocket { + content: Vec, + read_pos: usize, +} + +impl StaticVirtualSocket { + fn new() -> Self { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!"; + Self { + content: response.to_vec(), + read_pos: 0, + } + } +} + +impl AsyncRead for StaticVirtualSocket { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + debug_assert!(self.read_pos <= self.content.len()); + + let remaining = self.content.len() - self.read_pos; + if remaining == 0 { + return std::task::Poll::Ready(Ok(())); + } + + let to_read = std::cmp::min(remaining, buf.remaining()); + buf.put_slice(&self.content[self.read_pos..self.read_pos + to_read]); + self.read_pos += to_read; + + std::task::Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for StaticVirtualSocket { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + // Discard all writes + std::task::Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + +impl VirtualSocket for StaticVirtualSocket { + fn set_socket_option( + &self, + _opt: pingora_core::protocols::l4::virt::VirtualSockOpt, + ) -> std::io::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +struct VirtualConnector; + +#[async_trait] +impl L4Connect for VirtualConnector { + async fn connect(&self, _addr: &L4SocketAddr) -> pingora_error::Result { + Ok(Stream::from(VirtualSocketStream::new(Box::new( + StaticVirtualSocket::new(), + )))) + } +} + +struct VirtualProxy { + connector: Arc, +} + +impl VirtualProxy { + fn new() -> Self { + Self { + connector: 
Arc::new(VirtualConnector), + } + } +} + +#[async_trait::async_trait] +impl ProxyHttp for VirtualProxy { + type CTX = (); + + fn new_ctx(&self) -> Self::CTX {} + + // Route everything to example.org unless the Host header is "virtual.test", + // in which case target the special virtual address 203.0.113.1:18080. + async fn upstream_peer( + &self, + _session: &mut Session, + _ctx: &mut Self::CTX, + ) -> Result> { + let mut options = PeerOptions::new(); + options.custom_l4 = Some(self.connector.clone()); + + Ok(Box::new(HttpPeer { + _address: L4SocketAddr::Inet(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), + 80, + )), + scheme: pingora_core::upstreams::peer::Scheme::HTTP, + sni: "example.org".to_string(), + proxy: None, + client_cert_key: None, + group_key: 0, + options, + })) + } +} + +fn main() { + // Minimal server config + let conf = Arc::new(ServerConf::default()); + + // Build the service and set the default L4 connector + let mut svc: Service> = + http_proxy_service_with_name(&conf, VirtualProxy::new(), "virtual-proxy"); + + // Listen + let addr = "127.0.0.1:6196"; + svc.add_tcp(addr); + + let mut server = Server::new(None).unwrap(); + server.add_service(svc); + let run = RunArgs::default(); + + eprintln!("Listening on {addr}, try: curl http://{addr}/"); + server.run(run); +} diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index d0f41f71..f89f53d3 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -37,14 +37,18 @@ use async_trait::async_trait; use bytes::Bytes; +use futures::future::BoxFuture; use futures::future::FutureExt; -use http::{header, version::Version}; +use http::{header, version::Version, Method}; use log::{debug, error, trace, warn}; use once_cell::sync::Lazy; use pingora_http::{RequestHeader, ResponseHeader}; use std::fmt::Debug; use std::str; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::time::Duration; use tokio::sync::{mpsc, Notify}; use tokio::time; @@ -53,10 +57,13 @@ use pingora_cache::NoCacheReason; use pingora_core::apps::{ HttpPersistentSettings, HttpServerApp, HttpServerOptions, ReusedHttpStream, }; +use pingora_core::connectors::http::custom; use pingora_core::connectors::{http::Connector, ConnectorOptions}; use pingora_core::modules::http::compression::ResponseCompressionBuilder; use pingora_core::modules::http::{HttpModuleCtx, HttpModules}; use pingora_core::protocols::http::client::HttpSession as ClientSession; +use pingora_core::protocols::http::custom::CustomMessageWrite; +use pingora_core::protocols::http::subrequest::server::SubrequestHandle; use pingora_core::protocols::http::v1::client::HttpSession as HttpSessionV1; use pingora_core::protocols::http::v2::server::H2Options; use pingora_core::protocols::http::HttpTask; @@ -73,6 +80,7 @@ const TASK_BUFFER_SIZE: usize = 4; mod proxy_cache; mod proxy_common; +mod proxy_custom; mod proxy_h1; mod proxy_h2; mod proxy_purge; @@ -81,41 +89,116 @@ pub mod subrequest; use subrequest::{BodyMode, Ctx as SubrequestCtx}; -pub use proxy_cache::range_filter::{range_header_filter, RangeType}; +pub use proxy_cache::range_filter::{range_header_filter, MultiRangeInfo, RangeType}; pub use proxy_purge::PurgeStatus; pub use proxy_trait::{FailToProxy, ProxyHttp}; pub mod prelude { - pub use crate::{http_proxy_service, ProxyHttp, Session}; + pub use crate::{http_proxy, http_proxy_service, ProxyHttp, Session}; } +pub type ProcessCustomSession = Arc< + 
dyn Fn(Arc>, Stream, &ShutdownWatch) -> BoxFuture<'static, Option> + + Send + + Sync + + Unpin + + 'static, +>; + /// The concrete type that holds the user defined HTTP proxy. /// /// Users don't need to interact with this object directly. -pub struct HttpProxy { +pub struct HttpProxy +where + C: custom::Connector, // Upstream custom connector +{ inner: SV, // TODO: name it better than inner - client_upstream: Connector, + client_upstream: Connector, shutdown: Notify, + shutdown_flag: Arc, pub server_options: Option, pub h2_options: Option, pub downstream_modules: HttpModules, max_retries: usize, + process_custom_session: Option>, } -impl HttpProxy { - fn new(inner: SV, conf: Arc) -> Self { +impl HttpProxy { + /// Create a new [`HttpProxy`] with the given [`ProxyHttp`] implementation and [`ServerConf`]. + /// + /// After creating an `HttpProxy`, you should call [`HttpProxy::handle_init_modules()`] to + /// initialize the downstream modules before processing requests. + /// + /// For most use cases, prefer using [`http_proxy_service()`] which wraps the `HttpProxy` in a + /// [`Service`]. This constructor is useful when you need to integrate `HttpProxy` into a custom + /// accept loop (e.g., for SNI-based routing decisions before TLS termination). 
+ /// + /// # Example + /// + /// ```ignore + /// use pingora_proxy::HttpProxy; + /// use std::sync::Arc; + /// + /// let mut proxy = HttpProxy::new(my_proxy_app, server_conf); + /// proxy.handle_init_modules(); + /// let proxy = Arc::new(proxy); + /// // Use proxy.process_new_http() in your custom accept loop + /// ``` + pub fn new(inner: SV, conf: Arc) -> Self { HttpProxy { inner, client_upstream: Connector::new(Some(ConnectorOptions::from_server_conf(&conf))), shutdown: Notify::new(), + shutdown_flag: Arc::new(AtomicBool::new(false)), server_options: None, h2_options: None, downstream_modules: HttpModules::new(), max_retries: conf.max_retries, + process_custom_session: None, + } + } +} + +impl HttpProxy +where + C: custom::Connector, +{ + fn new_custom( + inner: SV, + conf: Arc, + connector: C, + on_custom: Option>, + server_options: Option, + ) -> Self + where + SV: ProxyHttp + Send + Sync + 'static, + SV::CTX: Send + Sync, + { + let client_upstream = + Connector::new_custom(Some(ConnectorOptions::from_server_conf(&conf)), connector); + + HttpProxy { + inner, + client_upstream, + shutdown: Notify::new(), + shutdown_flag: Arc::new(AtomicBool::new(false)), + server_options, + downstream_modules: HttpModules::new(), + max_retries: conf.max_retries, + process_custom_session: on_custom, + h2_options: None, } } - fn handle_init_modules(&mut self) + /// Initialize the downstream modules for this proxy. + /// + /// This method must be called after creating an [`HttpProxy`] with [`HttpProxy::new()`] + /// and before processing any requests. It invokes [`ProxyHttp::init_downstream_modules()`] + /// to set up any HTTP modules configured by the user's proxy implementation. + /// + /// Note: When using [`http_proxy_service()`] or [`http_proxy_service_with_name()`], + /// this method is called automatically. 
+ pub fn handle_init_modules(&mut self) where SV: ProxyHttp, { @@ -168,6 +251,27 @@ impl HttpProxy { "Request header: {:?}", downstream_session.req_header().as_ref() ); + // CONNECT method proxying is not default supported by the proxy http logic itself, + // since the tunneling process changes the request-response flow. + // https://datatracker.ietf.org/doc/html/rfc9110#name-connect + // Also because the method impacts message framing in a way is currently unaccounted for + // (https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2) + // it is safest to disallow use of the method by default. + if !self + .server_options + .as_ref() + .is_some_and(|opts| opts.allow_connect_method_proxying) + && downstream_session.req_header().method == Method::CONNECT + { + downstream_session + .respond_error(405) + .await + .unwrap_or_else(|e| { + error!("failed to send error response to downstream: {e}"); + }); + downstream_session.shutdown().await; + return None; + } Some(downstream_session) } @@ -217,7 +321,7 @@ impl HttpProxy { if matches!(e.etype, H2Downgrade | InvalidH2) { if peer .get_alpn() - .map_or(true, |alpn| alpn.get_min_http_version() == 1) + .is_none_or(|alpn| alpn.get_min_http_version() == 1) { // Add the peer to prefer h1 so that all following requests // will use h1 @@ -231,6 +335,16 @@ impl HttpProxy { (server_reused, error) } + ClientSession::Custom(mut c) => { + let (server_reused, error) = self + .proxy_to_custom_upstream(session, &mut c, client_reused, &peer, ctx) + .await; + let session = ClientSession::Custom(c); + self.client_upstream + .release_http_session(session, &*peer, peer.idle_timeout()) + .await; + (server_reused, error) + } }; ( server_reused, @@ -265,7 +379,7 @@ impl HttpProxy { .await?; None } - HttpTask::Body(data, eos) => self + HttpTask::Body(data, eos) | HttpTask::UpgradedBody(data, eos) => self .inner .upstream_response_body_filter(session, data, *eos, ctx)?, HttpTask::Trailer(Some(trailers)) => { @@ -287,13 +401,19 @@ impl 
HttpProxy { mut session: Session, ctx: &mut SV::CTX, reuse: bool, - error: Option<&Error>, + error: Option>, ) -> Option where SV: ProxyHttp + Send + Sync, SV::CTX: Send + Sync, { - self.inner.logging(&mut session, error, ctx).await; + self.inner + .logging(&mut session, error.as_deref(), ctx) + .await; + + if let Some(e) = error { + session.downstream_session.on_proxy_failure(e); + } if reuse { // TODO: log error @@ -341,12 +461,20 @@ pub struct Session { pub subrequest_spawner: Option, // Downstream filter modules pub downstream_modules_ctx: HttpModuleCtx, + /// Upstream response body bytes received (payload only). Set by proxy layer. + /// TODO: move this into an upstream session digest for future fields. + upstream_body_bytes_received: usize, + /// Upstream write pending time. Set by proxy layer (HTTP/1.x only). + upstream_write_pending_time: Duration, + /// Flag that is set when the shutdown process has begun. + shutdown_flag: Arc, } impl Session { fn new( downstream_session: impl Into>, downstream_modules: &HttpModules, + shutdown_flag: Arc, ) -> Self { Session { downstream_session: downstream_session.into(), @@ -358,22 +486,35 @@ impl Session { subrequest_ctx: None, subrequest_spawner: None, // optionally set later on downstream_modules_ctx: downstream_modules.build_ctx(), + upstream_body_bytes_received: 0, + upstream_write_pending_time: Duration::ZERO, + shutdown_flag, } } /// Create a new [Session] from the given [Stream] /// - /// This function is mostly used for testing and mocking. + /// This function is mostly used for testing and mocking, given the downstream modules and + /// shutdown flags will never be set. 
pub fn new_h1(stream: Stream) -> Self { let modules = HttpModules::new(); - Self::new(Box::new(HttpSession::new_http1(stream)), &modules) + Self::new( + Box::new(HttpSession::new_http1(stream)), + &modules, + Arc::new(AtomicBool::new(false)), + ) } /// Create a new [Session] from the given [Stream] with modules /// - /// This function is mostly used for testing and mocking. + /// This function is mostly used for testing and mocking, given the shutdown flag will never be + /// set. pub fn new_h1_with_modules(stream: Stream, downstream_modules: &HttpModules) -> Self { - Self::new(Box::new(HttpSession::new_http1(stream)), downstream_modules) + Self::new( + Box::new(HttpSession::new_http1(stream)), + downstream_modules, + Arc::new(AtomicBool::new(false)), + ) } pub fn as_downstream_mut(&mut self) -> &mut HttpSession { @@ -411,6 +552,16 @@ impl Session { self.downstream_session.write_response_header(resp).await } + /// Similar to `write_response_header()`, this fn will clone the `resp` internally + pub async fn write_response_header_ref( + &mut self, + resp: &ResponseHeader, + end_of_stream: bool, + ) -> Result<(), Box> { + self.write_response_header(Box::new(resp.clone()), end_of_stream) + .await + } + /// Write the given HTTP response body chunk to the downstream /// /// Different from directly calling [HttpSession::write_response_body], this function also @@ -434,6 +585,7 @@ impl Session { } pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { + let mut seen_upgraded = self.was_upgraded(); for task in tasks.iter_mut() { match task { HttpTask::Header(resp, end) => { @@ -445,6 +597,11 @@ impl Session { self.downstream_modules_ctx .response_body_filter(data, *end)?; } + HttpTask::UpgradedBody(data, end) => { + seen_upgraded = true; + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } HttpTask::Trailer(trailers) => { if let Some(buf) = self .downstream_modules_ctx @@ -455,6 +612,7 @@ impl Session { // // Note, this will not work if 
end of stream has already // been seen or we've written content-length bytes. + // (Trailers should never come after upgraded body) *task = HttpTask::Body(Some(buf), true); } } @@ -468,7 +626,11 @@ impl Session { // Note, this will not work if end of stream has already // been seen or we've written content-length bytes. if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { - *task = HttpTask::Body(Some(buf), true); + if seen_upgraded { + *task = HttpTask::UpgradedBody(Some(buf), true); + } else { + *task = HttpTask::Body(Some(buf), true); + } } } _ => { /* Failed */ } @@ -487,6 +649,49 @@ impl Session { pub fn upstream_headers_mutated_for_cache(&self) -> bool { self.upstream_headers_mutated_for_cache } + + /// Get the total upstream response body bytes received (payload only) recorded by the proxy layer. + pub fn upstream_body_bytes_received(&self) -> usize { + self.upstream_body_bytes_received + } + + /// Set the total upstream response body bytes received (payload only). Intended for internal use by proxy layer. + pub(crate) fn set_upstream_body_bytes_received(&mut self, n: usize) { + self.upstream_body_bytes_received = n; + } + + /// Get the upstream write pending time recorded by the proxy layer. Returns [`Duration::ZERO`] for HTTP/2. + pub fn upstream_write_pending_time(&self) -> Duration { + self.upstream_write_pending_time + } + + /// Set the upstream write pending time. Intended for internal use by proxy layer. + pub(crate) fn set_upstream_write_pending_time(&mut self, d: Duration) { + self.upstream_write_pending_time = d; + } + + /// Is the proxy process in the process of shutting down (e.g. due to graceful upgrade)? 
+ pub fn is_process_shutting_down(&self) -> bool { + self.shutdown_flag.load(Ordering::Acquire) + } + + pub fn downstream_custom_message( + &mut self, + ) -> Result< + Option> + Unpin + Send + Sync + 'static>>, + > { + if let Some(custom_session) = self.downstream_session.as_custom_mut() { + custom_session + .take_custom_message_reader() + .map(Some) + .ok_or(Error::explain( + ReadError, + "can't extract custom reader from downstream", + )) + } else { + Ok(None) + } + } } impl AsRef for Session { @@ -529,7 +734,10 @@ static BAD_GATEWAY: Lazy = Lazy::new(|| { resp }); -impl HttpProxy { +impl HttpProxy +where + C: custom::Connector, +{ async fn process_request( self: &Arc, mut session: Session, @@ -597,7 +805,7 @@ impl HttpProxy { if let Some((reuse, err)) = self.proxy_cache(&mut session, &mut ctx).await { // cache hit - return self.finish(session, &mut ctx, reuse, err.as_deref()).await; + return self.finish(session, &mut ctx, reuse, err).await; } // either uncacheable, or cache miss @@ -619,7 +827,7 @@ impl HttpProxy { session.cache.disable(NoCacheReason::DeclinedToUpstream); } if session.response_written().is_none() { - match session.write_response_header_ref(&BAD_GATEWAY).await { + match session.write_response_header_ref(&BAD_GATEWAY, true).await { Ok(()) => {} Err(e) => { return self @@ -690,6 +898,8 @@ impl HttpProxy { // serve stale if error // Check both error and cache before calling the function because await is not cheap + // allow unwrap until if let chains + #[allow(clippy::unnecessary_unwrap)] let serve_stale_result = if proxy_error.is_some() && session.cache.can_serve_stale_error() { self.handle_stale_if_error(&mut session, &mut ctx, proxy_error.as_ref().unwrap()) .await @@ -725,13 +935,13 @@ impl HttpProxy { res.error_code, retries, false, // we never retry here - self.inner.request_summary(&session, &ctx) + self.inner.request_summary(&session, &ctx), ); } } // logging() will be called in finish() - self.finish(session, &mut ctx, server_reuse, 
final_error.as_deref()) + self.finish(session, &mut ctx, server_reuse, final_error) .await } @@ -758,6 +968,8 @@ impl HttpProxy { self.inner.logging(&mut session, Some(&e), ctx).await; self.cleanup_sub_req(&mut session); + session.downstream_session.on_proxy_failure(e); + if res.can_reuse_downstream { let persistent_settings = HttpPersistentSettings::for_session(&session); session @@ -792,10 +1004,11 @@ pub trait Subrequest { } #[async_trait] -impl Subrequest for HttpProxy +impl Subrequest for HttpProxy where SV: ProxyHttp + Send + Sync + 'static, ::CTX: Send + Sync, + C: custom::Connector, { async fn process_subrequest( self: Arc, @@ -805,7 +1018,11 @@ where debug!("starting subrequest"); let mut session = match self.handle_new_request(session).await { - Some(downstream_session) => Session::new(downstream_session, &self.downstream_modules), + Some(downstream_session) => Session::new( + downstream_session, + &self.downstream_modules, + self.shutdown_flag.clone(), + ), None => return, // bad request }; @@ -826,6 +1043,29 @@ pub struct SubrequestSpawner { app: Arc, } +/// A [`PreparedSubrequest`] that is ready to run. +pub struct PreparedSubrequest { + app: Arc, + session: Box, + sub_req_ctx: Box, +} + +impl PreparedSubrequest { + pub async fn run(self) { + self.app + .process_subrequest(self.session, self.sub_req_ctx) + .await + } + + pub fn session(&self) -> &HttpSession { + self.session.as_ref() + } + + pub fn session_mut(&mut self) -> &mut HttpSession { + self.session.deref_mut() + } +} + impl SubrequestSpawner { /// Create a new [`SubrequestSpawner`]. pub fn new(app: Arc) -> SubrequestSpawner { @@ -855,36 +1095,74 @@ impl SubrequestSpawner { .await; }) } + + /// Create a subrequest that listens to `HttpTask`s sent from the returned `Sender` + /// and sends `HttpTask`s to the returned `Receiver`. + /// + /// To run that subrequest, call `run()`. 
+ // TODO: allow configuring the subrequest session before use + pub fn create_subrequest( + &self, + session: &HttpSession, + ctx: SubrequestCtx, + ) -> (PreparedSubrequest, SubrequestHandle) { + let new_app = self.app.clone(); // Clone the Arc + let (mut session, handle) = subrequest::create_session(session); + if ctx.body_mode() == BodyMode::NoBody { + session + .as_subrequest_mut() + .expect("created subrequest session") + .clear_request_body_headers(); + } + let sub_req_ctx = Box::new(ctx); + ( + PreparedSubrequest { + app: new_app, + session: Box::new(session), + sub_req_ctx, + }, + handle, + ) + } } #[async_trait] -impl HttpServerApp for HttpProxy +impl HttpServerApp for HttpProxy where SV: ProxyHttp + Send + Sync + 'static, ::CTX: Send + Sync, + C: custom::Connector, { async fn process_new_http( self: &Arc, session: HttpSession, - _shutdown: &ShutdownWatch, + shutdown: &ShutdownWatch, ) -> Option { let session = Box::new(session); // TODO: keepalive pool, use stack - let session = match self.handle_new_request(session).await { - Some(downstream_session) => Session::new(downstream_session, &self.downstream_modules), + let mut session = match self.handle_new_request(session).await { + Some(downstream_session) => Session::new( + downstream_session, + &self.downstream_modules, + self.shutdown_flag.clone(), + ), None => return None, // bad request }; + if *shutdown.borrow() { + // stop downstream from reusing if this service is shutting down soon + session.set_keepalive(None); + } + let ctx = self.inner.new_ctx(); self.process_request(session, ctx).await } async fn http_cleanup(&self) { + self.shutdown_flag.store(true, Ordering::Release); // Notify all keepalived requests blocking on read_request() to abort self.shutdown.notify_waiters(); - - // TODO: impl shutting down flag so that we don't need to read stack.is_shutting_down() } fn server_options(&self) -> Option<&HttpServerOptions> { @@ -894,14 +1172,69 @@ where fn h2_options(&self) -> Option { 
self.h2_options.clone() } + async fn process_custom_session( + self: Arc, + stream: Stream, + shutdown: &ShutdownWatch, + ) -> Option { + let app = self.clone(); + + let Some(process_custom_session) = app.process_custom_session.as_ref() else { + warn!("custom was called on an empty on_custom"); + return None; + }; + + process_custom_session(self.clone(), stream, shutdown).await + } + + // TODO implement h2_options } use pingora_core::services::listening::Service; +/// Create an [`HttpProxy`] without wrapping it in a [`Service`]. +/// +/// This is useful when you need to integrate `HttpProxy` into a custom accept loop, +/// for example when implementing SNI-based routing that decides between TLS passthrough +/// and TLS termination on a single port. +/// +/// The returned `HttpProxy` is fully initialized and ready to process requests via +/// [`HttpServerApp::process_new_http()`]. +/// +/// # Example +/// +/// ```ignore +/// use pingora_proxy::http_proxy; +/// use std::sync::Arc; +/// +/// // Create the proxy +/// let proxy = Arc::new(http_proxy(&server_conf, my_proxy_app)); +/// +/// // In your custom accept loop: +/// loop { +/// let (stream, addr) = listener.accept().await?; +/// +/// // Peek SNI, decide routing... +/// if should_terminate_tls { +/// let tls_stream = my_acceptor.accept(stream).await?; +/// let session = HttpSession::new_http1(Box::new(tls_stream)); +/// proxy.process_new_http(session, &shutdown).await; +/// } +/// } +/// ``` +pub fn http_proxy(conf: &Arc, inner: SV) -> HttpProxy +where + SV: ProxyHttp, +{ + let mut proxy = HttpProxy::new(inner, conf.clone()); + proxy.handle_init_modules(); + proxy +} + /// Create a [Service] from the user implemented [ProxyHttp]. /// /// The returned [Service] can be hosted by a [pingora_core::server::Server] directly. 
-pub fn http_proxy_service(conf: &Arc, inner: SV) -> Service> +pub fn http_proxy_service(conf: &Arc, inner: SV) -> Service> where SV: ProxyHttp, { @@ -915,7 +1248,7 @@ pub fn http_proxy_service_with_name( conf: &Arc, inner: SV, name: &str, -) -> Service> +) -> Service> where SV: ProxyHttp, { @@ -923,3 +1256,142 @@ where proxy.handle_init_modules(); Service::new(name.to_string(), proxy) } + +/// Create a [Service] from the user implemented [ProxyHttp]. +/// +/// The returned [Service] can be hosted by a [pingora_core::server::Server] directly. +pub fn http_proxy_service_with_name_custom( + conf: &Arc, + inner: SV, + name: &str, + connector: C, + on_custom: ProcessCustomSession, +) -> Service> +where + SV: ProxyHttp + Send + Sync + 'static, + SV::CTX: Send + Sync + 'static, + C: custom::Connector, +{ + let mut proxy = HttpProxy::new_custom(inner, conf.clone(), connector, Some(on_custom), None); + proxy.handle_init_modules(); + + Service::new(name.to_string(), proxy) +} + +/// A builder for a [Service] that can be used to create a [HttpProxy] instance +/// +/// The [ProxyServiceBuilder] can be used to construct a [HttpProxy] service with a custom name, +/// connector, and custom session handler. +/// +pub struct ProxyServiceBuilder +where + SV: ProxyHttp + Send + Sync + 'static, + SV::CTX: Send + Sync + 'static, + C: custom::Connector, +{ + conf: Arc, + inner: SV, + name: String, + connector: C, + custom: Option>, + server_options: Option, +} + +impl ProxyServiceBuilder +where + SV: ProxyHttp + Send + Sync + 'static, + SV::CTX: Send + Sync + 'static, +{ + /// Create a new [ProxyServiceBuilder] with the given [ServerConf] and [ProxyHttp] + /// implementation. + /// + /// The returned builder can be used to construct a [HttpProxy] service with a custom name, + /// connector, and custom session handler. + /// + /// The [ProxyServiceBuilder] will default to using the [ProxyHttp] implementation and no custom + /// session handler. 
+ /// + pub fn new(conf: &Arc, inner: SV) -> Self { + ProxyServiceBuilder { + conf: conf.clone(), + inner, + name: "Pingora HTTP Proxy Service".into(), + connector: (), + custom: None, + server_options: None, + } + } +} + +impl ProxyServiceBuilder +where + SV: ProxyHttp + Send + Sync + 'static, + SV::CTX: Send + Sync + 'static, + C: custom::Connector, +{ + /// Sets the name of the [HttpProxy] service. + pub fn name(mut self, name: impl AsRef) -> Self { + self.name = name.as_ref().to_owned(); + self + } + + /// Set a custom connector and custom session handler for the [ProxyServiceBuilder]. + /// + /// The custom connector is used to establish a connection to the upstream server. + /// + /// The custom session handler is used to handle custom protocol specific logic + /// between the proxy and the upstream server. + /// + /// Returns a new [ProxyServiceBuilder] with the custom connector and session handler. + pub fn custom( + self, + connector: C2, + on_custom: ProcessCustomSession, + ) -> ProxyServiceBuilder { + let Self { + conf, + inner, + name, + server_options, + .. + } = self; + ProxyServiceBuilder { + conf, + inner, + name, + connector, + custom: Some(on_custom), + server_options, + } + } + + /// Set the server options for the [ProxyServiceBuilder]. + /// + /// Returns a new [ProxyServiceBuilder] with the server options set. + pub fn server_options(mut self, options: HttpServerOptions) -> Self { + self.server_options = Some(options); + self + } + + /// Builds a new [Service] from the [ProxyServiceBuilder]. + /// + /// This function takes ownership of the [ProxyServiceBuilder] and returns a new [Service] with + /// a fully initialized [HttpProxy]. + /// + /// The returned [Service] is ready to be used by a [pingora_core::server::Server]. 
+ pub fn build(self) -> Service> { + let Self { + conf, + inner, + name, + connector, + custom, + server_options, + } = self; + + let mut proxy = HttpProxy::new_custom(inner, conf, connector, custom, server_options); + + proxy.handle_init_modules(); + Service::new(name, proxy) + } +} diff --git a/pingora-proxy/src/proxy_cache.rs b/pingora-proxy/src/proxy_cache.rs index e3d1ba52..43b2ace9 100644 --- a/pingora-proxy/src/proxy_cache.rs +++ b/pingora-proxy/src/proxy_cache.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,19 +13,22 @@ // limitations under the License. use super::*; -use http::header::{CONTENT_LENGTH, CONTENT_TYPE}; +use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING}; use http::{Method, StatusCode}; use pingora_cache::key::CacheHashKey; use pingora_cache::lock::LockStatus; use pingora_cache::max_file_size::ERR_RESPONSE_TOO_LARGE; -use pingora_cache::{ForcedFreshness, HitStatus, RespCacheable::*}; +use pingora_cache::{ForcedFreshness, HitHandler, HitStatus, RespCacheable::*}; use pingora_core::protocols::http::conditional_filter::to_304; use pingora_core::protocols::http::v1::common::header_value_content_length; use pingora_core::ErrorType; use range_filter::RangeBodyFilter; use std::time::SystemTime; -impl HttpProxy { +impl HttpProxy +where + C: custom::Connector, +{ // return bool: server_session can be reused, and error if any pub(crate) async fn proxy_cache( self: &Arc, @@ -153,7 +156,7 @@ impl HttpProxy { session.cache.cache_found(meta, handler, hit_status); } - if hit_status_opt.map_or(true, HitStatus::is_treated_as_miss) { + if hit_status_opt.is_none_or(HitStatus::is_treated_as_miss) { // cache miss if session.cache.is_cache_locked() { // Another request is filling the cache; try waiting til that's done and retry. 
@@ -300,6 +303,7 @@ impl HttpProxy { // return a 416 with an empty body for simplicity let header_only = header_only || matches!(range_type, RangeType::Invalid); + debug!("header: {header:?}"); // TODO: use ProxyUseCache to replace the logic below match self.inner.response_filter(session, &mut header, ctx).await { @@ -352,18 +356,42 @@ impl HttpProxy { } debug!("finished sending cached header to downstream"); + // If the function returns an Err, there was an issue seeking from the hit handler. + // + // Returning false means that no seeking or state change was done, either because the + // hit handler doesn't support the seek or because multipart doesn't apply. + fn seek_multipart( + hit_handler: &mut HitHandler, + range_filter: &mut RangeBodyFilter, + ) -> Result { + if !range_filter.is_multipart_range() || !hit_handler.can_seek_multipart() { + return Ok(false); + } + let r = range_filter.next_cache_multipart_range(); + hit_handler.seek_multipart(r.start, Some(r.end))?; + // we still need RangeBodyFilter's help to transform the byte + // range into a multipart response. 
+ range_filter.set_current_cursor(r.start); + Ok(true) + } + if !header_only { let mut maybe_range_filter = match &range_type { RangeType::Single(r) => { - if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) { - return (false, Some(e)); + if session.cache.hit_handler().can_seek() { + if let Err(e) = session.cache.hit_handler().seek(r.start, Some(r.end)) { + return (false, Some(e)); + } + None + } else { + Some(RangeBodyFilter::new_range(range_type.clone())) } - None } RangeType::Multi(_) => { - // TODO: seek hit handler for multipart - let mut range_filter = RangeBodyFilter::new(); - range_filter.set(range_type.clone()); + let mut range_filter = RangeBodyFilter::new_range(range_type.clone()); + if let Err(e) = seek_multipart(session.cache.hit_handler(), &mut range_filter) { + return (false, Some(e)); + } Some(range_filter) } RangeType::Invalid => unreachable!(), @@ -374,6 +402,37 @@ impl HttpProxy { Ok(raw_body) => { let end = raw_body.is_none(); + if end { + if let Some(range_filter) = maybe_range_filter.as_mut() { + if range_filter.should_cache_seek_again() { + let e = match seek_multipart( + session.cache.hit_handler(), + range_filter, + ) { + Ok(true) => { + // called seek(), read again + continue; + } + Ok(false) => { + // body reader can no longer seek multipart, + // but cache wants to continue seeking + // the body will just end in this case if we pass the + // None through + // (TODO: how might hit handlers want to recover from + // this situation)? + Error::explain( + InternalError, + "hit handler cannot seek for multipart again", + ) + // the body will just end in this case. 
+ } + Err(e) => e, + }; + return (false, Some(e)); + } + } + } + let mut body = if let Some(range_filter) = maybe_range_filter.as_mut() { range_filter.filter_body(raw_body) } else { @@ -403,7 +462,7 @@ impl HttpProxy { return (false, Some(e)); } - if !end && body.as_ref().map_or(true, |b| b.is_empty()) { + if !end && body.as_ref().is_none_or(|b| b.is_empty()) { // Don't write empty body which will end session, // still more hit handler bytes to read continue; @@ -602,45 +661,50 @@ impl HttpProxy { } } } - HttpTask::Body(data, end_stream) => match data { - Some(d) => { - if session.cache.enabled() { - // TODO: do this async - // fail if writing the body would exceed the max_file_size_bytes - let body_size_allowed = - session.cache.track_body_bytes_for_max_file_size(d.len()); - if !body_size_allowed { - debug!("chunked response exceeded max cache size, remembering that it is uncacheable"); - session - .cache - .response_became_uncacheable(NoCacheReason::ResponseTooLarge); - - return Error::e_explain( - ERR_RESPONSE_TOO_LARGE, - format!( - "writing data of size {} bytes would exceed max file size of {} bytes", - d.len(), - session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size") - ), - ); - } + HttpTask::Body(data, end_stream) | HttpTask::UpgradedBody(data, end_stream) => { + // It is not normally advisable to cache upgraded responses + // e.g. 
they are essentially close-delimited, so they are easily truncated + // but the framework still allows for it + match data { + Some(d) => { + if session.cache.enabled() { + // TODO: do this async + // fail if writing the body would exceed the max_file_size_bytes + let body_size_allowed = + session.cache.track_body_bytes_for_max_file_size(d.len()); + if !body_size_allowed { + debug!("chunked response exceeded max cache size, remembering that it is uncacheable"); + session + .cache + .response_became_uncacheable(NoCacheReason::ResponseTooLarge); + + return Error::e_explain( + ERR_RESPONSE_TOO_LARGE, + format!( + "writing data of size {} bytes would exceed max file size of {} bytes", + d.len(), + session.cache.max_file_size_bytes().expect("max file size bytes must be set to exceed size") + ), + ); + } - // this will panic if more data is sent after we see end_stream - // but should be impossible in real world - let miss_handler = session.cache.miss_handler().unwrap(); + // this will panic if more data is sent after we see end_stream + // but should be impossible in real world + let miss_handler = session.cache.miss_handler().unwrap(); - miss_handler.write_body(d.clone(), *end_stream).await?; - if *end_stream { - session.cache.finish_miss_handler().await?; + miss_handler.write_body(d.clone(), *end_stream).await?; + if *end_stream { + session.cache.finish_miss_handler().await?; + } } } - } - None => { - if session.cache.enabled() && *end_stream { - session.cache.finish_miss_handler().await?; + None => { + if session.cache.enabled() && *end_stream { + session.cache.finish_miss_handler().await?; + } } } - }, + } HttpTask::Trailer(_) => {} // h1 trailer is not supported yet HttpTask::Done => { if session.cache.enabled() { @@ -866,14 +930,9 @@ impl HttpProxy { ); true } - /* We have 3 options when a lock is held too long - * 1. release the lock and let every request complete for it again - * 2. let every request cache miss - * 3. 
let every request through while disabling cache - * #1 could repeat the situation but protect the origin from load - * #2 could amplify disk writes and storage for temp file - * #3 is the simplest option for now */ - LockStatus::Timeout => { + // If this reader has spent too long waiting on locks, let the request + // through while disabling cache (to avoid amplifying disk writes). + LockStatus::WaitTimeout => { warn!( "Cache lock timeout, {}", self.inner.request_summary(session, ctx) @@ -882,6 +941,10 @@ impl HttpProxy { // not cacheable, just go to the origin. false } + // When a singular cache lock has been held for too long, + // we should allow requests to recompete for the lock + // to protect upstreams from load. + LockStatus::AgeTimeout => true, // software bug, this status should be impossible to reach LockStatus::Waiting => panic!("impossible LockStatus::Waiting"), } @@ -902,6 +965,12 @@ fn cache_hit_header(cache: &HttpCache) -> Box { let age = cache.cache_meta().age().as_secs(); header.insert_header(http::header::AGE, age).unwrap(); } + log::debug!("cache header: {header:?} {:?}", cache.phase()); + + // currently storage cache is always considered an h1 upstream + // (header-serde serializes as h1.0 or h1.1) + // set this header to be h1.1 + header.set_version(Version::HTTP_11); /* Add chunked header to tell downstream to use chunked encoding * during the absent of content-length in h2 */ @@ -928,7 +997,11 @@ pub mod range_filter { str::from_utf8(input).ok()?.parse().ok() } - fn parse_range_header(range: &[u8], content_length: usize) -> RangeType { + fn parse_range_header( + range: &[u8], + content_length: usize, + max_multipart_ranges: Option, + ) -> RangeType { use regex::Regex; // Match individual range parts, (e.g. 
"0-100", "-5", "1-") @@ -955,15 +1028,21 @@ pub mod range_filter { return RangeType::None; }; + // "bytes=" with an empty (or whitespace-only) range-set is syntactically a + // range request with zero satisfiable range-specs, so return 416. + if ranges_str.trim().is_empty() { + return RangeType::Invalid; + } + // Get the actual range string (e.g."100-200,300-400") let mut range_count = 0; for _ in ranges_str.split(',') { range_count += 1; - // TODO: make configurable - const MAX_RANGES: usize = 200; - if range_count >= MAX_RANGES { - // If we get more than MAX_RANGES ranges, return None for now to save parsing time - return RangeType::None; + if let Some(max_ranges) = max_multipart_ranges { + if range_count >= max_ranges { + // If we get more than max configured ranges, return None for now to save parsing time + return RangeType::None; + } } } let mut ranges: Vec> = Vec::with_capacity(range_count); @@ -1044,40 +1123,50 @@ pub mod range_filter { #[test] fn test_parse_range() { assert_eq!( - parse_range_header(b"bytes=0-1", 10), + parse_range_header(b"bytes=0-1", 10, None), RangeType::new_single(0, 2) ); assert_eq!( - parse_range_header(b"bYTes=0-9", 10), + parse_range_header(b"bYTes=0-9", 10, None), RangeType::new_single(0, 10) ); assert_eq!( - parse_range_header(b"bytes=0-12", 10), + parse_range_header(b"bytes=0-12", 10, None), RangeType::new_single(0, 10) ); assert_eq!( - parse_range_header(b"bytes=0-", 10), + parse_range_header(b"bytes=0-", 10, None), RangeType::new_single(0, 10) ); - assert_eq!(parse_range_header(b"bytes=2-1", 10), RangeType::Invalid); - assert_eq!(parse_range_header(b"bytes=10-11", 10), RangeType::Invalid); assert_eq!( - parse_range_header(b"bytes=-2", 10), + parse_range_header(b"bytes=2-1", 10, None), + RangeType::Invalid + ); + assert_eq!( + parse_range_header(b"bytes=10-11", 10, None), + RangeType::Invalid + ); + assert_eq!( + parse_range_header(b"bytes=-2", 10, None), RangeType::new_single(8, 10) ); assert_eq!( - 
parse_range_header(b"bytes=-12", 10), + parse_range_header(b"bytes=-12", 10, None), RangeType::new_single(0, 10) ); - assert_eq!(parse_range_header(b"bytes=-", 10), RangeType::Invalid); - assert_eq!(parse_range_header(b"bytes=", 10), RangeType::None); + assert_eq!(parse_range_header(b"bytes=-", 10, None), RangeType::Invalid); + assert_eq!(parse_range_header(b"bytes=", 10, None), RangeType::Invalid); + assert_eq!( + parse_range_header(b"bytes= ", 10, None), + RangeType::Invalid + ); } // Add some tests for multi-range too #[test] fn test_parse_range_header_multi() { assert_eq!( - parse_range_header(b"bytes=0-1,4-5", 10) + parse_range_header(b"bytes=0-1,4-5", 10, None) .get_multirange_info() .expect("Should have multipart info for Multipart range request") .ranges, @@ -1085,7 +1174,7 @@ pub mod range_filter { ); // Last range is invalid because the content-length is too small assert_eq!( - parse_range_header(b"bytEs=0-99,200-299,400-499", 320) + parse_range_header(b"bytEs=0-99,200-299,400-499", 320, None) .get_multirange_info() .expect("Should have multipart info for Multipart range request") .ranges, @@ -1099,7 +1188,7 @@ pub mod range_filter { ); // Same as above but appropriate content length assert_eq!( - parse_range_header(b"bytEs=0-99,200-299,400-499", 500) + parse_range_header(b"bytEs=0-99,200-299,400-499", 500, None) .get_multirange_info() .expect("Should have multipart info for Multipart range request") .ranges, @@ -1116,29 +1205,35 @@ pub mod range_filter { ] ); // Looks like a range request but it is continuous, we decline to range - assert_eq!(parse_range_header(b"bytes=0-,-2", 10), RangeType::None,); + assert_eq!( + parse_range_header(b"bytes=0-,-2", 10, None), + RangeType::None, + ); // Should not have multirange info set - assert!(parse_range_header(b"bytes=0-,-2", 10) + assert!(parse_range_header(b"bytes=0-,-2", 10, None) .get_multirange_info() .is_none()); // Overlapping ranges, these ranges are currently declined - 
assert_eq!(parse_range_header(b"bytes=0-3,2-5", 10), RangeType::None,); - assert!(parse_range_header(b"bytes=0-3,2-5", 10) + assert_eq!( + parse_range_header(b"bytes=0-3,2-5", 10, None), + RangeType::None, + ); + assert!(parse_range_header(b"bytes=0-3,2-5", 10, None) .get_multirange_info() .is_none()); // Content length is 2, so only range is 0-2. assert_eq!( - parse_range_header(b"bytes=0-5,10-", 2), + parse_range_header(b"bytes=0-5,10-", 2, None), RangeType::new_single(0, 2) ); - assert!(parse_range_header(b"bytes=0-5,10-", 2) + assert!(parse_range_header(b"bytes=0-5,10-", 2, None) .get_multirange_info() .is_none()); // We should ignore the last incorrect range and return the other acceptable ranges assert_eq!( - parse_range_header(b"bytes=0-5, 10-20, 30-18", 200) + parse_range_header(b"bytes=0-5, 10-20, 30-18", 200, None) .get_multirange_info() .expect("Should have multipart info for Multipart range request") .ranges, @@ -1146,7 +1241,7 @@ pub mod range_filter { ); // All invalid ranges assert_eq!( - parse_range_header(b"bytes=5-0, 20-15, 30-25", 200), + parse_range_header(b"bytes=5-0, 20-15, 30-25", 200, None), RangeType::Invalid ); @@ -1168,7 +1263,10 @@ pub mod range_filter { // Test 200 range limit for parsing. 
let ranges = generate_range_header(201); - assert_eq!(parse_range_header(&ranges, 1000), RangeType::None) + assert_eq!( + parse_range_header(&ranges, 1000, Some(200)), + RangeType::None + ) } // For Multipart Requests, we need to know the boundary, content length and type across @@ -1207,7 +1305,7 @@ pub mod range_filter { let mut rng: rand::prelude::ThreadRng = rand::thread_rng(); format!("{:016x}", rng.gen::()) } - fn calculate_multipart_length(&self) -> usize { + pub fn calculate_multipart_length(&self) -> usize { let mut total_length = 0; let content_type = self.content_type.as_ref(); for range in self.ranges.clone() { @@ -1270,7 +1368,11 @@ pub mod range_filter { } // Handles both single-range and multipart-range requests - pub fn range_header_filter(req: &RequestHeader, resp: &mut ResponseHeader) -> RangeType { + pub fn range_header_filter( + req: &RequestHeader, + resp: &mut ResponseHeader, + max_multipart_ranges: Option, + ) -> RangeType { // The Range header field is evaluated after evaluating the precondition // header fields defined in [RFC7232], and only if the result in absence // of the Range header field would be a 200 (OK) response @@ -1278,15 +1380,6 @@ pub mod range_filter { return RangeType::None; } - // "A server MUST ignore a Range header field received with a request method other than GET." - if req.method != http::Method::GET && req.method != http::Method::HEAD { - return RangeType::None; - } - - let Some(range_header) = req.headers.get(RANGE) else { - return RangeType::None; - }; - // Content-Length is not required by RFC but it is what nginx does and easier to implement // with this header present. let Some(content_length_bytes) = resp.headers.get(CONTENT_LENGTH) else { @@ -1297,37 +1390,65 @@ pub mod range_filter { return RangeType::None; }; - // if-range wants to understand if the Last-Modified / ETag value matches exactly for use - // with resumable downloads. 
- // https://datatracker.ietf.org/doc/html/rfc9110#name-if-range - // Note that the RFC wants strong validation, and suggests that - // "A valid entity-tag can be distinguished from a valid HTTP-date - // by examining the first three characters for a DQUOTE," - // but this current etag matching behavior most closely mirrors nginx. - if let Some(if_range) = req.headers.get(IF_RANGE) { - let ir = if_range.as_bytes(); - let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') { - resp.headers.get(ETAG).is_some_and(|etag| etag == if_range) - } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) { - last_modified == if_range - } else { - false - }; - if !matches { + // At this point the response is allowed to be served as ranges + // TODO: we can also check Accept-Range header from resp. Nginx gives uses the option + // see proxy_force_ranges + + fn request_range_type( + req: &RequestHeader, + resp: &ResponseHeader, + content_length: usize, + max_multipart_ranges: Option, + ) -> RangeType { + // "A server MUST ignore a Range header field received with a request method other than GET." + if req.method != http::Method::GET && req.method != http::Method::HEAD { return RangeType::None; } - } - // TODO: we can also check Accept-Range header from resp. Nginx gives uses the option - // see proxy_force_ranges + let Some(range_header) = req.headers.get(RANGE) else { + return RangeType::None; + }; - let mut range_type = parse_range_header(range_header.as_bytes(), content_length); + // if-range wants to understand if the Last-Modified / ETag value matches exactly for use + // with resumable downloads. + // https://datatracker.ietf.org/doc/html/rfc9110#name-if-range + // Note that the RFC wants strong validation, and suggests that + // "A valid entity-tag can be distinguished from a valid HTTP-date + // by examining the first three characters for a DQUOTE," + // but this current etag matching behavior most closely mirrors nginx. 
+ if let Some(if_range) = req.headers.get(IF_RANGE) { + let ir = if_range.as_bytes(); + let matches = if ir.len() >= 2 && ir.last() == Some(&b'"') { + resp.headers.get(ETAG).is_some_and(|etag| etag == if_range) + } else if let Some(last_modified) = resp.headers.get(LAST_MODIFIED) { + last_modified == if_range + } else { + false + }; + if !matches { + return RangeType::None; + } + } + + parse_range_header( + range_header.as_bytes(), + content_length, + max_multipart_ranges, + ) + } + + let mut range_type = request_range_type(req, resp, content_length, max_multipart_ranges); match &mut range_type { - RangeType::None => { /* nothing to do*/ } + RangeType::None => { + // At this point, the response is _eligible_ to be served in ranges + // in the future, so add Accept-Ranges, mirroring nginx behavior + resp.insert_header(&ACCEPT_RANGES, "bytes").unwrap(); + } RangeType::Single(r) => { // 206 response resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap(); + resp.remove_header(&ACCEPT_RANGES); resp.insert_header(&CONTENT_LENGTH, r.end - r.start) .unwrap(); resp.insert_header( @@ -1350,6 +1471,7 @@ pub mod range_filter { let total_length = multi_range_info.calculate_multipart_length(); resp.set_status(StatusCode::PARTIAL_CONTENT).unwrap(); + resp.remove_header(&ACCEPT_RANGES); resp.insert_header(CONTENT_LENGTH, total_length).unwrap(); resp.insert_header( CONTENT_TYPE, @@ -1367,8 +1489,10 @@ pub mod range_filter { // empty body for simplicity resp.insert_header(&CONTENT_LENGTH, HeaderValue::from_static("0")) .unwrap(); - // TODO: remove other headers like content-encoding + resp.remove_header(&ACCEPT_RANGES); resp.remove_header(&CONTENT_TYPE); + resp.remove_header(&CONTENT_ENCODING); + resp.remove_header(&TRANSFER_ENCODING); resp.insert_header(&CONTENT_RANGE, format!("bytes */{content_length}")) .unwrap() } @@ -1391,8 +1515,23 @@ pub mod range_filter { // no range let req = gen_req(); let mut resp = gen_resp(); - assert_eq!(RangeType::None, range_header_filter(&req, &mut 
resp)); + assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None)); assert_eq!(resp.status.as_u16(), 200); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); + + // no range, try HEAD + let mut req = gen_req(); + req.method = Method::HEAD; + let mut resp = gen_resp(); + assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None)); + assert_eq!(resp.status.as_u16(), 200); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); // regular range let mut req = gen_req(); @@ -1400,7 +1539,24 @@ pub mod range_filter { let mut resp = gen_resp(); assert_eq!( RangeType::new_single(0, 2), - range_header_filter(&req, &mut resp) + range_header_filter(&req, &mut resp, None) + ); + assert_eq!(resp.status.as_u16(), 206); + assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2"); + assert_eq!( + resp.headers.get("content-range").unwrap().as_bytes(), + b"bytes 0-1/10" + ); + assert!(resp.headers.get("accept-ranges").is_none()); + + // regular range, accept-ranges included + let mut req = gen_req(); + req.insert_header("Range", "bytes=0-1").unwrap(); + let mut resp = gen_resp(); + resp.insert_header("Accept-Ranges", "bytes").unwrap(); + assert_eq!( + RangeType::new_single(0, 2), + range_header_filter(&req, &mut resp, None) ); assert_eq!(resp.status.as_u16(), 206); assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"2"); @@ -1408,18 +1564,29 @@ pub mod range_filter { resp.headers.get("content-range").unwrap().as_bytes(), b"bytes 0-1/10" ); + // accept-ranges stripped + assert!(resp.headers.get("accept-ranges").is_none()); // bad range let mut req = gen_req(); req.insert_header("Range", "bytes=1-0").unwrap(); let mut resp = gen_resp(); - assert_eq!(RangeType::Invalid, range_header_filter(&req, &mut resp)); + resp.insert_header("Accept-Ranges", "bytes").unwrap(); + resp.insert_header("Content-Encoding", "gzip").unwrap(); + 
resp.insert_header("Transfer-Encoding", "chunked").unwrap(); + assert_eq!( + RangeType::Invalid, + range_header_filter(&req, &mut resp, None) + ); assert_eq!(resp.status.as_u16(), 416); assert_eq!(resp.headers.get("content-length").unwrap().as_bytes(), b"0"); assert_eq!( resp.headers.get("content-range").unwrap().as_bytes(), b"bytes */10" ); + assert!(resp.headers.get("accept-ranges").is_none()); + assert!(resp.headers.get("content-encoding").is_none()); + assert!(resp.headers.get("transfer-encoding").is_none()); } // Multipart Tests @@ -1446,7 +1613,7 @@ pub mod range_filter { // valid multipart range let req = gen_req(); let mut resp = gen_resp(); - let result = range_header_filter(&req, &mut resp); + let result = range_header_filter(&req, &mut resp, None); let mut boundary_str = String::new(); assert!(matches!(result, RangeType::Multi(_))); @@ -1468,24 +1635,34 @@ pub mod range_filter { format!("multipart/byteranges; boundary={boundary_str}") ); assert!(resp.headers.get("content_length").is_none()); + assert!(resp.headers.get("accept-ranges").is_none()); // overlapping range, multipart range is declined let req = gen_req_overlap_range(); let mut resp = gen_resp(); - let result = range_header_filter(&req, &mut resp); + let result = range_header_filter(&req, &mut resp, None); assert!(matches!(result, RangeType::None)); assert_eq!(resp.status.as_u16(), 200); assert!(resp.headers.get("content-type").is_none()); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); // bad multipart range let mut req = gen_req(); req.insert_header("Range", "bytes=1-0, 12-9, 50-40") .unwrap(); let mut resp = gen_resp(); - let result = range_header_filter(&req, &mut resp); + resp.insert_header("Content-Encoding", "br").unwrap(); + resp.insert_header("Transfer-Encoding", "chunked").unwrap(); + let result = range_header_filter(&req, &mut resp, None); assert!(matches!(result, RangeType::Invalid)); assert_eq!(resp.status.as_u16(), 416); + 
assert!(resp.headers.get("accept-ranges").is_none()); + assert!(resp.headers.get("content-encoding").is_none()); + assert!(resp.headers.get("transfer-encoding").is_none()); } #[test] @@ -1517,7 +1694,7 @@ pub mod range_filter { let mut resp = gen_resp(); assert_eq!( RangeType::new_single(0, 2), - range_header_filter(&req, &mut resp) + range_header_filter(&req, &mut resp, None) ); // non-matching date @@ -1525,7 +1702,12 @@ pub mod range_filter { req.insert_header("If-Range", "Fri, 07 Jul 2023 22:03:25 GMT") .unwrap(); let mut resp = gen_resp(); - assert_eq!(RangeType::None, range_header_filter(&req, &mut resp)); + assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None)); + assert_eq!(resp.status.as_u16(), 200); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); // match ETag let mut req = gen_req(); @@ -1533,33 +1715,46 @@ pub mod range_filter { let mut resp = gen_resp(); assert_eq!( RangeType::new_single(0, 2), - range_header_filter(&req, &mut resp) + range_header_filter(&req, &mut resp, None) ); + assert_eq!(resp.status.as_u16(), 206); + assert!(resp.headers.get("accept-ranges").is_none()); // non-matching ETags do not result in range let mut req = gen_req(); req.insert_header("If-Range", "\"4567\"").unwrap(); let mut resp = gen_resp(); - assert_eq!(RangeType::None, range_header_filter(&req, &mut resp)); + assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None)); + assert_eq!(resp.status.as_u16(), 200); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); let mut req = gen_req(); req.insert_header("If-Range", "1234").unwrap(); let mut resp = gen_resp(); - assert_eq!(RangeType::None, range_header_filter(&req, &mut resp)); + assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None)); + assert_eq!(resp.status.as_u16(), 200); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); // multipart range with If-Range let mut req 
= get_multipart_req(); req.insert_header("If-Range", DATE).unwrap(); let mut resp = gen_resp(); - let result = range_header_filter(&req, &mut resp); + let result = range_header_filter(&req, &mut resp, None); assert!(matches!(result, RangeType::Multi(_))); assert_eq!(resp.status.as_u16(), 206); + assert!(resp.headers.get("accept-ranges").is_none()); // multipart with matching ETag let req = get_multipart_req(); let mut resp = gen_resp(); assert!(matches!( - range_header_filter(&req, &mut resp), + range_header_filter(&req, &mut resp, None), RangeType::Multi(_) )); @@ -1567,14 +1762,19 @@ pub mod range_filter { let mut req = get_multipart_req(); req.insert_header("If-Range", "\"wrong\"").unwrap(); let mut resp = gen_resp(); - assert_eq!(RangeType::None, range_header_filter(&req, &mut resp)); + assert_eq!(RangeType::None, range_header_filter(&req, &mut resp, None)); assert_eq!(resp.status.as_u16(), 200); + assert_eq!( + resp.headers.get("accept-ranges").unwrap().as_bytes(), + b"bytes" + ); } pub struct RangeBodyFilter { pub range: RangeType, current: usize, multipart_idx: Option, + cache_multipart_idx: Option, } impl Default for RangeBodyFilter { @@ -1589,16 +1789,62 @@ pub mod range_filter { range: RangeType::None, current: 0, multipart_idx: None, + cache_multipart_idx: None, } } - pub fn set(&mut self, range: RangeType) { - self.range = range.clone(); - if let RangeType::Multi(_) = self.range { - self.multipart_idx = Some(0); + pub fn new_range(range: RangeType) -> Self { + RangeBodyFilter { + multipart_idx: matches!(range, RangeType::Multi(_)).then_some(0), + range, + ..Default::default() + } + } + + pub fn is_multipart_range(&self) -> bool { + matches!(self.range, RangeType::Multi(_)) + } + + /// Whether we should expect the cache body reader to seek again + /// for a different range. 
+ pub fn should_cache_seek_again(&self) -> bool { + match &self.range { + RangeType::Multi(multipart_info) => self + .cache_multipart_idx + .is_some_and(|idx| idx != multipart_info.ranges.len() - 1), + _ => false, } } + /// Returns the next multipart range to seek for the cache body reader. + pub fn next_cache_multipart_range(&mut self) -> Range { + match &self.range { + RangeType::Multi(multipart_info) => { + match self.cache_multipart_idx.as_mut() { + Some(v) => *v += 1, + None => self.cache_multipart_idx = Some(0), + } + let cache_multipart_idx = self.cache_multipart_idx.expect("set above"); + let multipart_idx = self.multipart_idx.expect("must be set on multirange"); + // NOTE: currently this assumes once we start seeking multipart from the hit + // handler, it will continue to return can_seek_multipart true. + assert_eq!(multipart_idx, cache_multipart_idx, + "cache multipart idx should match multipart idx, or there is a hit handler bug"); + multipart_info.ranges[cache_multipart_idx].clone() + } + _ => panic!("tried to advance multipart idx on non-multipart range"), + } + } + + pub fn set_current_cursor(&mut self, current: usize) { + self.current = current; + } + + pub fn set(&mut self, range: RangeType) { + self.multipart_idx = matches!(range, RangeType::Multi(_)).then_some(0); + self.range = range; + } + // Emit final boundary footer for multipart requests pub fn finalize(&self, boundary: &String) -> Option { if let RangeType::Multi(_) = self.range { @@ -1746,26 +1992,22 @@ pub mod range_filter { #[test] fn test_range_body_filter_single() { - let mut body_filter = RangeBodyFilter::new(); + let mut body_filter = RangeBodyFilter::new_range(RangeType::None); assert_eq!(body_filter.filter_body(Some("123".into())).unwrap(), "123"); - let mut body_filter = RangeBodyFilter::new(); - body_filter.set(RangeType::Invalid); + let mut body_filter = RangeBodyFilter::new_range(RangeType::Invalid); assert!(body_filter.filter_body(Some("123".into())).is_none()); - let mut 
body_filter = RangeBodyFilter::new(); - body_filter.set(RangeType::new_single(0, 1)); + let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(0, 1)); assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "0"); assert!(body_filter.filter_body(Some("345".into())).is_none()); - let mut body_filter = RangeBodyFilter::new(); - body_filter.set(RangeType::new_single(4, 6)); + let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(4, 6)); assert!(body_filter.filter_body(Some("012".into())).is_none()); assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "45"); assert!(body_filter.filter_body(Some("678".into())).is_none()); - let mut body_filter = RangeBodyFilter::new(); - body_filter.set(RangeType::new_single(1, 7)); + let mut body_filter = RangeBodyFilter::new_range(RangeType::new_single(1, 7)); assert_eq!(body_filter.filter_body(Some("012".into())).unwrap(), "12"); assert_eq!(body_filter.filter_body(Some("345".into())).unwrap(), "345"); assert_eq!(body_filter.filter_body(Some("678".into())).unwrap(), "6"); @@ -2059,7 +2301,16 @@ impl ServeFromCache { &mut self, cache: &mut HttpCache, range: &mut RangeBodyFilter, + upgraded: bool, ) -> Result { + fn body_task(data: Bytes, upgraded: bool) -> HttpTask { + if upgraded { + HttpTask::UpgradedBody(Some(data), false) + } else { + HttpTask::Body(Some(data), false) + } + } + if !cache.enabled() { // Cache is disabled due to internal error // TODO: if nothing is sent to eyeball yet, figure out a way to recovery by @@ -2085,26 +2336,42 @@ impl ServeFromCache { Ok(HttpTask::Header(cache_hit_header(cache), true)) } Self::CacheBody(should_seek) => { + log::trace!("cache body should seek: {should_seek}"); if *should_seek { self.maybe_seek_hit_handler(cache, range)?; } - if let Some(b) = cache.hit_handler().read_body().await? 
{ - Ok(HttpTask::Body(Some(b), false)) // false for now - } else { - *self = Self::Done; - Ok(HttpTask::Done) + loop { + if let Some(b) = cache.hit_handler().read_body().await? { + return Ok(body_task(b, upgraded)); + } + // EOF from hit handler for body requested + // if multipart, then seek again + if range.should_cache_seek_again() { + self.maybe_seek_hit_handler(cache, range)?; + } else { + *self = Self::Done; + return Ok(HttpTask::Done); + } } } Self::CacheBodyMiss(should_seek) => { if *should_seek { self.maybe_seek_miss_handler(cache, range)?; } - // safety: called of enable_miss() call it only if the async_body_reader exist - if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? { - Ok(HttpTask::Body(Some(b), false)) // false for now - } else { - *self = Self::DoneMiss; - Ok(HttpTask::Done) + // safety: caller of enable_miss() call it only if the async_body_reader exist + loop { + if let Some(b) = cache.miss_body_reader().unwrap().read_body().await? { + return Ok(body_task(b, upgraded)); + } else { + // EOF from hit handler for body requested + // if multipart, then seek again + if range.should_cache_seek_again() { + self.maybe_seek_miss_handler(cache, range)?; + } else { + *self = Self::DoneMiss; + return Ok(HttpTask::Done); + } + } } } Self::Done => Ok(HttpTask::Done), @@ -2117,20 +2384,38 @@ impl ServeFromCache { cache: &mut HttpCache, range_filter: &mut RangeBodyFilter, ) -> Result<()> { - if let RangeType::Single(range) = &range_filter.range { - // safety: called only if the async_body_reader exists - if cache.miss_body_reader().unwrap().can_seek() { - cache - .miss_body_reader() - // safety: called only if the async_body_reader exists - .unwrap() - .seek(range.start, Some(range.end)) - .or_err(InternalError, "cannot seek miss handler")?; - // Because the miss body reader is seeking, we no longer need the - // RangeBodyFilter's help to return the requested byte range. 
- range_filter.range = RangeType::None; + match &range_filter.range { + RangeType::Single(range) => { + // safety: called only if the async_body_reader exists + if cache.miss_body_reader().unwrap().can_seek() { + cache + .miss_body_reader() + // safety: called only if the async_body_reader exists + .unwrap() + .seek(range.start, Some(range.end)) + .or_err(InternalError, "cannot seek miss handler")?; + // Because the miss body reader is seeking, we no longer need the + // RangeBodyFilter's help to return the requested byte range. + range_filter.range = RangeType::None; + } + } + RangeType::Multi(_info) => { + // safety: called only if the async_body_reader exists + if cache.miss_body_reader().unwrap().can_seek_multipart() { + let range = range_filter.next_cache_multipart_range(); + cache + .miss_body_reader() + .unwrap() + .seek_multipart(range.start, Some(range.end)) + .or_err(InternalError, "cannot seek hit handler for multirange")?; + // we still need RangeBodyFilter's help to transform the byte + // range into a multipart response. + range_filter.set_current_cursor(range.start); + } } + _ => {} } + *self = Self::CacheBodyMiss(false); Ok(()) } @@ -2152,10 +2437,17 @@ impl ServeFromCache { range_filter.range = RangeType::None; } } - RangeType::Multi(_) => { - // For multipart ranges, we will handle the seeking in - // the body filter per part for now. - // TODO: implement seek for multipart range + RangeType::Multi(_info) => { + if cache.hit_handler().can_seek_multipart() { + let range = range_filter.next_cache_multipart_range(); + cache + .hit_handler() + .seek_multipart(range.start, Some(range.end)) + .or_err(InternalError, "cannot seek hit handler for multirange")?; + // we still need RangeBodyFilter's help to transform the byte + // range into a multipart response. 
+ range_filter.set_current_cursor(range.start); + } } _ => {} } diff --git a/pingora-proxy/src/proxy_common.rs b/pingora-proxy/src/proxy_common.rs index d7d97b34..e1d36f69 100644 --- a/pingora-proxy/src/proxy_common.rs +++ b/pingora-proxy/src/proxy_common.rs @@ -43,6 +43,12 @@ impl DownstreamStateMachine { } } + /// Reset if we should continue reading from the downstream again. + /// Only used with upgraded connections when body mode changes. + pub fn reset(&mut self) { + *self = Self::Reading; + } + pub fn to_errored(&mut self) { *self = Self::Errored } diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs new file mode 100644 index 00000000..63079111 --- /dev/null +++ b/pingora-proxy/src/proxy_custom.rs @@ -0,0 +1,942 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures::StreamExt; +use pingora_core::{ + protocols::http::custom::{ + client::Session as CustomSession, is_informational_except_101, BodyWrite, + CustomMessageWrite, CUSTOM_MESSAGE_QUEUE_SIZE, + }, + ImmutStr, +}; +use proxy_cache::{range_filter::RangeBodyFilter, ServeFromCache}; +use proxy_common::{DownstreamStateMachine, ResponseStateMachine}; +use tokio::sync::oneshot; + +use super::*; + +impl HttpProxy +where + C: custom::Connector, +{ + /// Proxy to a custom protocol upstream. 
+ /// Returns (reuse_server, error) + pub(crate) async fn proxy_to_custom_upstream( + &self, + session: &mut Session, + client_session: &mut C::Session, + reused: bool, + peer: &HttpPeer, + ctx: &mut SV::CTX, + ) -> (bool, Option>) + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + #[cfg(windows)] + let raw = client_session.fd() as std::os::windows::io::RawSocket; + #[cfg(unix)] + let raw = client_session.fd(); + + if let Err(e) = self + .inner + .connected_to_upstream(session, reused, peer, raw, client_session.digest(), ctx) + .await + { + return (false, Some(e)); + } + + let (server_session_reuse, error) = self + .custom_proxy_down_to_up(session, client_session, peer, ctx) + .await; + + // Parity with H1/H2: custom upstreams don't report payload bytes; record 0. + session.set_upstream_body_bytes_received(0); + + (server_session_reuse, error) + } + + /// Handle custom protocol proxying from downstream to upstream. + /// Returns (reuse_server, error) + async fn custom_proxy_down_to_up( + &self, + session: &mut Session, + client_session: &mut C::Session, + peer: &HttpPeer, + ctx: &mut SV::CTX, + ) -> (bool, Option>) + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + let mut req = session.req_header().clone(); + + if session.cache.enabled() { + pingora_cache::filters::upstream::request_filter( + &mut req, + session.cache.maybe_cache_meta(), + ); + session.mark_upstream_headers_mutated_for_cache(); + } + + match self + .inner + .upstream_request_filter(session, &mut req, ctx) + .await + { + Ok(_) => { /* continue */ } + Err(e) => { + return (false, Some(e)); + } + } + + session.upstream_compression.request_filter(&req); + let body_empty = session.as_mut().is_body_empty(); + + debug!("Request to custom: {req:?}"); + + let req = Box::new(req); + if let Err(e) = client_session.write_request_header(req, body_empty).await { + return (false, Some(e.into_up())); + } + + client_session.set_read_timeout(peer.options.read_timeout); + 
client_session.set_write_timeout(peer.options.write_timeout); + + // take the body writer out of the client for easy duplex + let mut client_body = client_session + .take_request_body_writer() + .expect("already send request header"); + + let (tx, rx) = mpsc::channel::(TASK_BUFFER_SIZE); + + session.as_mut().enable_retry_buffering(); + + // Custom message logic + + let Some(mut upstream_custom_message_reader) = client_session.take_custom_message_reader() + else { + return ( + false, + Some(Error::explain( + ReadError, + "can't extract custom reader from upstream", + )), + ); + }; + + let Some(mut upstream_custom_message_writer) = client_session.take_custom_message_writer() + else { + return ( + false, + Some(Error::explain( + WriteError, + "custom upstream must have a custom message writer", + )), + ); + }; + + // A channel to inject custom messages to upstream from server logic. + let (upstream_custom_message_inject_tx, upstream_custom_message_inject_rx) = + mpsc::channel(CUSTOM_MESSAGE_QUEUE_SIZE); + + // Downstream reader + let mut downstream_custom_message_reader = match session.downstream_custom_message() { + Ok(Some(rx)) => rx, + Ok(None) => Box::new(futures::stream::empty::>()), + Err(err) => return (false, Some(err)), + }; + + // Downstream writer + let (mut downstream_custom_message_writer, downstream_custom_final_hop): ( + Box, + bool, // if this hop is final + ) = if let Some(custom_session) = session.downstream_session.as_custom_mut() { + ( + custom_session + .take_custom_message_writer() + .expect("custom downstream must have a custom message writer"), + false, + ) + } else { + (Box::new(()), true) + }; + + // A channel to inject custom messages to downstream from server logic. 
+ let (downstream_custom_message_inject_tx, downstream_custom_message_inject_rx) = + mpsc::channel(CUSTOM_MESSAGE_QUEUE_SIZE); + + // Filters for ProxyHttp trait + let (upstream_custom_message_filter_tx, upstream_custom_message_filter_rx) = + mpsc::channel(CUSTOM_MESSAGE_QUEUE_SIZE); + let (downstream_custom_message_filter_tx, downstream_custom_message_filter_rx) = + mpsc::channel(CUSTOM_MESSAGE_QUEUE_SIZE); + + // Cancellation channels for custom coroutines + // The transmitters act as guards: when dropped, they signal the receivers to cancel. + // `cancel_downstream_reader_tx` is held and later used to explicitly cancel. + // `_cancel_upstream_reader_tx` is unused (prefixed with _) - it will be dropped at the + // end of this scope, which automatically signals cancellation to the upstream reader. + let (cancel_downstream_reader_tx, cancel_downstream_reader_rx) = oneshot::channel(); + let (_cancel_upstream_reader_tx, cancel_upstream_reader_rx) = oneshot::channel(); + + let upstream_custom_message_forwarder = CustomMessageForwarder { + ctx: "down_to_up".into(), + reader: &mut downstream_custom_message_reader, + writer: &mut upstream_custom_message_writer, + filter: upstream_custom_message_filter_tx, + inject: upstream_custom_message_inject_rx, + cancel: cancel_downstream_reader_rx, + }; + + let downstream_custom_message_forwarder = CustomMessageForwarder { + ctx: "up_to_down".into(), + reader: &mut upstream_custom_message_reader, + writer: &mut downstream_custom_message_writer, + filter: downstream_custom_message_filter_tx, + inject: downstream_custom_message_inject_rx, + cancel: cancel_upstream_reader_rx, + }; + + if let Err(e) = self + .inner + .custom_forwarding( + session, + ctx, + Some(upstream_custom_message_inject_tx), + downstream_custom_message_inject_tx, + ) + .await + { + return (false, Some(e)); + } + + /* read downstream body and upstream response at the same time */ + let ret = tokio::try_join!( + self.custom_bidirection_down_to_up( + session, + &mut 
client_body, + rx, + ctx, + upstream_custom_message_filter_rx, + downstream_custom_message_filter_rx, + downstream_custom_final_hop, + cancel_downstream_reader_tx, + ), + custom_pipe_up_to_down_response(client_session, tx), + upstream_custom_message_forwarder.proxy(), + downstream_custom_message_forwarder.proxy(), + ); + + if let Some(custom_session) = session.downstream_session.as_custom_mut() { + custom_session + .restore_custom_message_writer(downstream_custom_message_writer) + .expect("downstream restore_custom_message_writer should be empty"); + + custom_session + .restore_custom_message_reader(downstream_custom_message_reader) + .expect("downstream restore_custom_message_reader should be empty"); + } + + match ret { + Ok((downstream_can_reuse, _upstream, _custom_up_down, _custom_down_up)) => { + (downstream_can_reuse, None) + } + Err(e) => (false, Some(e)), + } + } + + // returns whether server (downstream) session can be reused + #[allow(clippy::too_many_arguments)] + async fn custom_bidirection_down_to_up( + &self, + session: &mut Session, + client_body: &mut Box, + mut rx: mpsc::Receiver, + ctx: &mut SV::CTX, + mut upstream_custom_message_filter_rx: mpsc::Receiver<( + Bytes, + oneshot::Sender>, + )>, + mut downstream_custom_message_filter_rx: mpsc::Receiver<( + Bytes, + oneshot::Sender>, + )>, + downstream_custom_final_hop: bool, + cancel_downstream_reader_tx: oneshot::Sender<()>, + ) -> Result + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + let mut cancel_downstream_reader_tx = Some(cancel_downstream_reader_tx); + + let mut downstream_state = DownstreamStateMachine::new(session.as_mut().is_body_done()); + + // retry, send buffer if it exists + if let Some(buffer) = session.as_mut().get_retry_buffer() { + self.send_body_to_custom( + session, + Some(buffer), + downstream_state.is_done(), + client_body, + ctx, + ) + .await?; + } + + let mut response_state = ResponseStateMachine::new(); + + // these two below can be wrapped into an 
internal ctx + // use cache when upstream revalidates (or TODO: error) + let mut serve_from_cache = ServeFromCache::new(); + let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + + let mut upstream_custom = true; + let mut downstream_custom = true; + + /* duplex mode + * see the Same function for h1 for more comments + */ + while !downstream_state.is_done() + || !response_state.is_done() + || upstream_custom + || downstream_custom + { + // partial read support, this check will also be false if cache is disabled. + let support_cache_partial_read = + session.cache.support_streaming_partial_write() == Some(true); + let upgraded = session.was_upgraded(); + + tokio::select! { + body = session.downstream_session.read_body_or_idle(downstream_state.is_done()), if downstream_state.can_poll() => { + let body = match body { + Ok(b) => b, + Err(e) => { + let wait_for_cache_fill = (!serve_from_cache.is_on() && support_cache_partial_read) + || serve_from_cache.is_miss(); + if wait_for_cache_fill { + // ignore downstream error so that upstream can continue to write cache + downstream_state.to_errored(); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + continue; + } else { + return Err(e.into_down()); + } + } + }; + let is_body_done = session.is_body_done(); + + match self.send_body_to_custom(session, body, is_body_done, client_body, ctx).await { + Ok(request_done) => { + downstream_state.maybe_finished(request_done); + }, + Err(e) => { + // mark request done, attempt to drain receive + warn!("body send error: {e}"); + + // upstream is what actually errored but we don't want to continue + // polling the downstream body + downstream_state.to_errored(); + + // downstream still trying to send something, but the upstream is already stooped + // cancel the custom downstream to upstream coroutine, because the proxy will not see EOS. 
+ let _ = cancel_downstream_reader_tx.take().expect("cancel must be set and called once").send(()); + } + }; + }, + + task = rx.recv(), if !response_state.upstream_done() => { + debug!("upstream event"); + + if let Some(t) = task { + debug!("upstream event custom: {:?}", t); + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? + continue; + } + // pull as many tasks as we can + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(t); + while let Ok(task) = rx.try_recv() { + tasks.push(task); + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.custom_response_filter(session, t, ctx, + &mut serve_from_cache, + &mut range_body_filter, false).await?); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + continue; + } + + let upgraded = session.was_upgraded(); + let response_done = session.write_response_tasks(filtered_tasks).await?; + if !upgraded && session.was_upgraded() && downstream_state.can_poll() { + // just upgraded, the downstream state should be reset to continue to + // poll body + trace!("reset downstream state on upgrade"); + downstream_state.reset(); + } + + response_state.maybe_set_upstream_done(response_done); + } else { + 
debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + } + + task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), + if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + let task = self.custom_response_filter(session, task?, ctx, + &mut serve_from_cache, + &mut range_body_filter, true).await?; + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + + ret = upstream_custom_message_filter_rx.recv(), if upstream_custom => { + let Some(msg) = ret else { + debug!("upstream_custom_message_filter_rx: custom downstream to upstream exited on reading"); + upstream_custom = false; + continue; + }; + + let (data, callback) = msg; + + let new_msg = self.inner + .downstream_custom_message_proxy_filter(session, data, ctx, false) // false because the upstream is custom + .await?; + + if callback.send(new_msg).is_err() { + debug!("upstream_custom_message_incoming_rx: custom downstream to upstream exited on callback"); + upstream_custom = false; + continue; + }; + }, + + ret = downstream_custom_message_filter_rx.recv(), if downstream_custom => { + let Some(msg) = ret else { + debug!("downstream_custom_message_filter_rx: custom upstream to downstream exited on reading"); + downstream_custom = false; + continue; + }; + + let (data, callback) = msg; + + let new_msg = self.inner + 
.upstream_custom_message_proxy_filter(session, data, ctx, downstream_custom_final_hop) + .await?; + + if callback.send(new_msg).is_err() { + debug!("downstream_custom_message_filter_rx: custom upstream to downstream exited on callback"); + downstream_custom = false; + continue + }; + }, + + else => { + break; + } + } + } + + // Re-raise the error then the loop is finished. + if downstream_state.is_errored() { + let err = Error::e_explain(WriteError, "downstream_state is_errored"); + error!("custom_bidirection_down_to_up: downstream_state.is_errored",); + return err; + } + + client_body.cleanup().await?; + + let mut reuse_downstream = !downstream_state.is_errored(); + if reuse_downstream { + match session.as_mut().finish_body().await { + Ok(_) => { + debug!("finished sending body to downstream"); + } + Err(e) => { + error!("Error finish sending body to downstream: {}", e); + reuse_downstream = false; + } + } + } + Ok(reuse_downstream) + } + + async fn custom_response_filter( + &self, + session: &mut Session, + mut task: HttpTask, + ctx: &mut SV::CTX, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut RangeBodyFilter, + from_cache: bool, // are the task from cache already + ) -> Result + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if !from_cache { + self.upstream_filter(session, &mut task, ctx).await?; + + // cache the original response before any downstream transformation + // requests that bypassed cache still need to run filters to see if the response has become cacheable + if session.cache.enabled() || session.cache.bypassing() { + if let Err(e) = self + .cache_http_task(session, &task, ctx, serve_from_cache) + .await + { + session.cache.disable(NoCacheReason::StorageError); + if serve_from_cache.is_miss_body() { + // if the response stream cache body during miss but write fails, it has to + // give up the entire request + return Err(e); + } else { + // otherwise, continue processing the response + warn!( + "Fail to cache 
response: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + } + } + } + // skip the downstream filtering if these tasks are just for cache admission + if !serve_from_cache.should_send_to_downstream() { + return Ok(task); + } + } // else: cached/local response, no need to trigger upstream filters and caching + + match task { + HttpTask::Header(mut header, eos) => { + /* Downstream revalidation, only needed when cache is on because otherwise origin + * will handle it */ + // TODO: if cache is disabled during response phase, we should still do the filter + if session.cache.enabled() { + self.downstream_response_conditional_filter( + serve_from_cache, + session, + &mut header, + ctx, + ); + if !session.ignore_downstream_range { + let range_type = self.inner.range_header_filter(session, &mut header, ctx); + range_body_filter.set(range_type); + } + } + + self.inner + .response_filter(session, &mut header, ctx) + .await?; + /* Downgrade the version so that write_response_header won't panic */ + header.set_version(Version::HTTP_11); + + // these status codes / method cannot have body, so no need to add chunked encoding + let no_body = session.req_header().method == "HEAD" + || matches!(header.status.as_u16(), 204 | 304); + + /* Add chunked header to tell downstream to use chunked encoding + * during the absent of content-length */ + if !no_body + && !header.status.is_informational() + && header.headers.get(http::header::CONTENT_LENGTH).is_none() + { + header.insert_header(http::header::TRANSFER_ENCODING, "chunked")?; + } + Ok(HttpTask::Header(header, eos)) + } + HttpTask::Body(data, eos) => { + let mut data = range_body_filter.filter_body(data); + if let Some(duration) = self + .inner + .response_body_filter(session, &mut data, eos, ctx)? 
+ { + trace!("delaying response for {duration:?}"); + time::sleep(duration).await; + } + Ok(HttpTask::Body(data, eos)) + } + HttpTask::UpgradedBody(mut data, eos) => { + // range body filter doesn't apply to upgraded body + if let Some(duration) = self + .inner + .response_body_filter(session, &mut data, eos, ctx)? + { + trace!("delaying upgraded response for {duration:?}"); + time::sleep(duration).await; + } + Ok(HttpTask::UpgradedBody(data, eos)) + } + HttpTask::Trailer(mut trailers) => { + let trailer_buffer = match trailers.as_mut() { + Some(trailers) => { + debug!("Parsing response trailers.."); + match self + .inner + .response_trailer_filter(session, trailers, ctx) + .await + { + Ok(buf) => buf, + Err(e) => { + error!( + "Encountered error while filtering upstream trailers {:?}", + e + ); + None + } + } + } + _ => None, + }; + // if we have a trailer buffer write it to the downstream response body + if let Some(buffer) = trailer_buffer { + // write_body will not write additional bytes after reaching the content-length + // for gRPC H2 -> H1 this is not a problem but may be a problem for non gRPC code + // https://http2.github.io/http2-spec/#malformed + Ok(HttpTask::Body(Some(buffer), true)) + } else { + Ok(HttpTask::Trailer(trailers)) + } + } + HttpTask::Done => Ok(task), + HttpTask::Failed(_) => Ok(task), // Do nothing just pass the error down + } + } + + async fn send_body_to_custom( + &self, + session: &mut Session, + mut data: Option, + end_of_body: bool, + client_body: &mut Box, + ctx: &mut SV::CTX, + ) -> Result + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + session + .downstream_modules_ctx + .request_body_filter(&mut data, end_of_body) + .await?; + + self.inner + .request_body_filter(session, &mut data, end_of_body, ctx) + .await?; + + if session.was_upgraded() { + client_body.upgrade_body_writer(); + } + + /* it is normal to get 0 bytes because of multi-chunk parsing or request_body_filter. 
+ * Although there is no harm writing empty byte to custom, unlike h1, we ignore it + * for consistency */ + if !end_of_body && data.as_ref().is_some_and(|d| d.is_empty()) { + return Ok(false); + } + + if let Some(mut data) = data { + client_body + .write_all_buf(&mut data) + .await + .map_err(|e| e.into_up())?; + if end_of_body { + client_body.finish().await.map_err(|e| e.into_up())?; + } + } else { + debug!("Read downstream body done"); + client_body + .finish() + .await + .map_err(|e| { + Error::because(WriteError, "while shutdown send data stream on no data", e) + }) + .map_err(|e| e.into_up())?; + } + + Ok(end_of_body) + } +} + +/* Read response header, body and trailer from custom upstream and send them to tx */ +async fn custom_pipe_up_to_down_response( + client: &mut S, + tx: mpsc::Sender, +) -> Result<()> { + let mut is_informational = true; + while is_informational { + client + .read_response_header() + .await + .map_err(|e| e.into_up())?; + let resp_header = Box::new(client.response_header().expect("just read").clone()); + // `101 Switching Protocols` is a response to the http1 Upgrade header and it's a final response. 
+ // The WebSocket Protocol https://datatracker.ietf.org/doc/html/rfc6455 + is_informational = is_informational_except_101(resp_header.status.as_u16() as u32); + + match client.check_response_end_or_error(true).await { + Ok(eos) => { + tx.send(HttpTask::Header(resp_header, eos)) + .await + .or_err(InternalError, "sending custom headers to pipe")?; + } + Err(e) => { + // If upstream errored, then push error to downstream and then quit + // Don't care if send fails (which means downstream already gone) + // we were still able to retrieve the headers, so try sending + let _ = tx.send(HttpTask::Header(resp_header, false)).await; + let _ = tx.send(HttpTask::Failed(e.into_up())).await; + return Ok(()); + } + } + } + + while let Some(chunk) = client + .read_response_body() + .await + .map_err(|e| e.into_up()) + .transpose() + { + let data = match chunk { + Ok(d) => d, + Err(e) => { + // Push the error to downstream and then quit + let _ = tx.send(HttpTask::Failed(e.into_up())).await; + // Downstream should consume all remaining data and handle the error + return Ok(()); + } + }; + + match client.check_response_end_or_error(false).await { + Ok(eos) => { + let empty = data.is_empty(); + if empty && !eos { + /* it is normal to get 0 bytes because of multi-chunk + * don't write 0 bytes to downstream since it will be + * misread as the terminating chunk */ + continue; + } + let body_task = if client.was_upgraded() { + HttpTask::UpgradedBody(Some(data), eos) + } else { + HttpTask::Body(Some(data), eos) + }; + let sent = tx + .send(body_task) + .await + .or_err(InternalError, "sending custom body to pipe"); + // If the response with content-length is sent to an HTTP1 downstream, + // custom_bidirection_down_to_up() could decide that the body has finished and exit without + // waiting for this function to signal the eos. In this case tx being closed is not + // a sign of error. 
It should happen if the only thing left for the custom to send is + // an empty data frame with eos set. + if sent.is_err() && eos && empty { + return Ok(()); + } + sent?; + } + Err(e) => { + // Similar to above, push the error to downstream and then quit + let _ = tx.send(HttpTask::Failed(e.into_up())).await; + return Ok(()); + } + } + } + + // attempt to get trailers + let trailers = match client.read_trailers().await { + Ok(t) => t, + Err(e) => { + // Similar to above, push the error to downstream and then quit + let _ = tx.send(HttpTask::Failed(e.into_up())).await; + return Ok(()); + } + }; + + let trailers = trailers.map(Box::new); + + if trailers.is_some() { + tx.send(HttpTask::Trailer(trailers)) + .await + .or_err(InternalError, "sending custom trailer to pipe")?; + } + + tx.send(HttpTask::Done) + .await + .unwrap_or_else(|_| debug!("custom channel closed!")); + + Ok(()) +} + +struct CustomMessageForwarder<'a> { + ctx: ImmutStr, + writer: &'a mut Box, + reader: + &'a mut Box>> + Send + Sync + Unpin>, + inject: mpsc::Receiver, + filter: mpsc::Sender<(Bytes, oneshot::Sender>)>, + cancel: oneshot::Receiver<()>, +} + +impl CustomMessageForwarder<'_> { + async fn proxy(mut self) -> Result<()> { + let forwarder = async { + let mut injector_status = true; + let mut reader_status = true; + + debug!("{}: CustomMessageForwarder: start", self.ctx); + + while injector_status || reader_status { + let (data, proxied) = tokio::select! 
{ + ret = self.inject.recv(), if injector_status => { + let Some(data) = ret else { + injector_status = false; + continue + }; + (data, false) + }, + + ret = self.reader.next(), if reader_status => { + let Some(data) = ret else { + reader_status = false; + continue + }; + + let data = match data { + Ok(data) => data, + Err(err) => { + reader_status = false; + warn!("{}: CustomMessageForwarder: reader returned err: {err:?}", self.ctx); + continue; + }, + }; + (data, true) + }, + }; + + let (callback_tx, callback_rx) = oneshot::channel(); + + // If data received from proxy send it to filter + if proxied { + if self.filter.send((data, callback_tx)).await.is_err() { + debug!( + "{}: CustomMessageForwarder: filter receiver dropped", + self.ctx + ); + return Error::e_explain( + WriteError, + "CustomMessageForwarder: main proxy thread exited on filter send", + ); + }; + } else { + callback_tx + .send(Some(data)) + .expect("sending from the same thread"); + } + + match callback_rx.await { + Ok(None) => continue, // message was filtered + Ok(Some(msg)) => { + self.writer.write_custom_message(msg).await?; + } + Err(err) => { + debug!( + "{}: CustomMessageForwarder: callback_rx return error: {err}", + self.ctx + ); + return Error::e_because( + WriteError, + "CustomMessageForwarder: main proxy thread exited on callback_rx await", + err, + ); + } + }; + } + + debug!("{}: CustomMessageForwarder: exit loop", self.ctx); + + let ret = self.writer.finish_custom().await; + if let Err(ref err) = ret { + debug!( + "{}: CustomMessageForwarder: finish_custom return error: {err}", + self.ctx + ); + }; + ret?; + + debug!( + "{}: CustomMessageForwarder: exit loop successfully", + self.ctx + ); + + Ok(()) + }; + + tokio::select! 
{ + ret = &mut self.cancel => { + debug!("{}: CustomMessageForwarder: canceled while waiting for new messages: {ret:?}", self.ctx); + Ok(()) + }, + ret = forwarder => ret + } + } +} diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index c0446742..9f04289c 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +use futures::future::OptionFuture; +use futures::StreamExt; + use super::*; use crate::proxy_cache::{range_filter::RangeBodyFilter, ServeFromCache}; use crate::proxy_common::*; use pingora_cache::CachePhase; +use pingora_core::protocols::http::custom::CUSTOM_MESSAGE_QUEUE_SIZE; -impl HttpProxy { +impl HttpProxy +where + C: custom::Connector, +{ pub(crate) async fn proxy_1to1( &self, session: &mut Session, @@ -85,6 +92,11 @@ impl HttpProxy { } } + let mut downstream_custom_message_writer = session + .downstream_session + .as_custom_mut() + .and_then(|c| c.take_custom_message_writer()); + let (tx_upstream, rx_upstream) = mpsc::channel::(TASK_BUFFER_SIZE); let (tx_downstream, rx_downstream) = mpsc::channel::(TASK_BUFFER_SIZE); @@ -92,10 +104,28 @@ impl HttpProxy { // start bi-directional streaming let ret = tokio::try_join!( - self.proxy_handle_downstream(session, tx_downstream, rx_upstream, ctx), + self.proxy_handle_downstream( + session, + tx_downstream, + rx_upstream, + ctx, + &mut downstream_custom_message_writer + ), self.proxy_handle_upstream(client_session, tx_upstream, rx_downstream), ); + if let Some(custom_session) = session.downstream_session.as_custom_mut() { + if let Some(downstream_custom_message_writer) = downstream_custom_message_writer { + match 
custom_session.restore_custom_message_writer(downstream_custom_message_writer) + { + Ok(_) => { /* continue */ } + Err(e) => { + return (false, false, Some(e)); + } + } + } + } + match ret { Ok((downstream_can_reuse, _upstream)) => (downstream_can_reuse, true, None), Err(e) => (false, false, Some(e)), @@ -120,6 +150,8 @@ impl HttpProxy { #[cfg(unix)] let raw = client_session.id(); + let initial_write_pending = client_session.stream().get_write_pending_time(); + if let Err(e) = self .inner .connected_to_upstream( @@ -138,6 +170,15 @@ impl HttpProxy { let (server_session_reuse, client_session_reuse, error) = self.proxy_1to1(session, client_session, peer, ctx).await; + // Record upstream response body bytes received (payload only) for logging consumers. + let upstream_bytes_total = client_session.body_bytes_received(); + session.set_upstream_body_bytes_received(upstream_bytes_total); + + // Record upstream write pending time for this session only (delta from baseline). + let current_write_pending = client_session.stream().get_write_pending_time(); + let upstream_write_pending = current_write_pending.saturating_sub(initial_write_pending); + session.set_upstream_write_pending_time(upstream_write_pending); + (server_session_reuse, client_session_reuse, error) } @@ -154,6 +195,7 @@ impl HttpProxy { let mut request_done = false; let mut response_done = false; let mut send_error = None; + let mut upgraded = false; /* duplex mode, wait for either to complete */ while !request_done || !response_done { @@ -162,6 +204,14 @@ impl HttpProxy { match res { Ok(task) => { response_done = task.is_end(); + if !upgraded && client_session.was_upgraded() { + // upgrade can only happen once + upgraded = true; + if send_error.is_none() { + // continue receiving from downstream after body mode change + request_done = false; + } + } let type_str = task.type_str(); let result = tx.send(task) .await.or_err_with( @@ -174,7 +224,7 @@ impl HttpProxy { // In that case, this function should ignore 
that the pipe is closed. // So that this function could read the rest events from rx including // the closure, then exit. - if result.is_err() && !client_session.is_upgrade_req() { + if result.is_err() && !client_session.was_upgraded() { return result; } }, @@ -193,13 +243,14 @@ impl HttpProxy { Ok(send_done) => { request_done = send_done; // An upgraded request is terminated when either side is done - if request_done && client_session.is_upgrade_req() { + if request_done && client_session.was_upgraded() { response_done = true; } }, Err(e) => { - debug!("send error, draining read buf: {e}"); + warn!("send error, draining read buf: {e}"); request_done = true; + send_error = Some(e); continue } @@ -224,11 +275,33 @@ impl HttpProxy { tx: mpsc::Sender, mut rx: mpsc::Receiver, ctx: &mut SV::CTX, + downstream_custom_message_writer: &mut Option>, ) -> Result where SV: ProxyHttp + Send + Sync, SV::CTX: Send + Sync, { + // setup custom message forwarding, if downstream supports it + let ( + mut downstream_custom_read, + mut downstream_custom_write, + downstream_custom_message_custom_forwarding, + mut downstream_custom_message_inject_rx, + mut downstream_custom_message_reader, + ) = if downstream_custom_message_writer.is_some() { + let reader = session.downstream_custom_message()?; + let (inject_tx, inject_rx) = mpsc::channel::(CUSTOM_MESSAGE_QUEUE_SIZE); + (true, true, Some(inject_tx), Some(inject_rx), reader) + } else { + (false, false, None, None, None) + }; + + if let Some(custom_forwarding) = downstream_custom_message_custom_forwarding { + self.inner + .custom_forwarding(session, ctx, None, custom_forwarding) + .await?; + } + let mut downstream_state = DownstreamStateMachine::new(session.as_mut().is_body_done()); let buffer = session.as_ref().get_retry_buffer(); @@ -273,13 +346,32 @@ impl HttpProxy { * If both are done, quit the loop * Usually there is no request body to read for cacheable request */ - while !downstream_state.is_done() || !response_state.is_done() { + 
while !downstream_state.is_done() + || !response_state.is_done() + || downstream_custom_read && !downstream_state.is_errored() + || downstream_custom_write + { // reserve tx capacity ahead to avoid deadlock, see below let send_permit = tx .try_reserve() .or_err(InternalError, "try_reserve() body pipe for upstream"); + // Use optional futures to allow using optional channels in select branches + let custom_inject_rx_recv: OptionFuture<_> = downstream_custom_message_inject_rx + .as_mut() + .map(|rx| rx.recv()) + .into(); + let custom_reader_next: OptionFuture<_> = downstream_custom_message_reader + .as_mut() + .map(|reader| reader.next()) + .into(); + + // partial read support, this check will also be false if cache is disabled. + let support_cache_partial_read = + session.cache.support_streaming_partial_write() == Some(true); + let upgraded = session.was_upgraded(); + tokio::select! { // only try to send to pipe if there is capacity to avoid deadlock // Otherwise deadlock could happen if both upstream and downstream are blocked @@ -291,7 +383,9 @@ impl HttpProxy { let body = match body { Ok(b) => b, Err(e) => { - if serve_from_cache.is_miss() { + let wait_for_cache_fill = (!serve_from_cache.is_on() && support_cache_partial_read) + || serve_from_cache.is_miss(); + if wait_for_cache_fill { // ignore downstream error so that upstream can continue to write cache downstream_state.to_errored(); warn!( @@ -299,6 +393,9 @@ impl HttpProxy { e, self.inner.request_summary(session, ctx) ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); continue; } else { return Err(e.into_down()); @@ -307,7 +404,7 @@ impl HttpProxy { }; // If the request is websocket, `None` body means the request is closed. // Set the response to be done as well so that the request completes normally. 
- if body.is_none() && session.is_upgrade_req() { + if body.is_none() && session.was_upgraded() { response_state.maybe_set_upstream_done(true); } // TODO: consider just drain this if serve_from_cache is set @@ -386,9 +483,17 @@ impl HttpProxy { } // set to downstream + let upgraded = session.was_upgraded(); let response_done = session.write_response_tasks(filtered_tasks).await?; + if !upgraded && session.was_upgraded() && downstream_state.can_poll() { + // just upgraded, the downstream state should be reset to continue to + // poll body + trace!("reset downstream state on upgrade"); + downstream_state.reset(); + } response_state.maybe_set_upstream_done(response_done); - // unsuccessful upgrade response may force the request done + // unsuccessful upgrade response (or end of upstream upgraded conn, + // which forces the body reader to complete) may force the request done downstream_state.maybe_finished(session.is_body_done()); } else { debug!("empty upstream event"); @@ -396,7 +501,7 @@ impl HttpProxy { } }, - task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter), + task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { let task = self.h1_response_filter(session, task?, ctx, @@ -415,6 +520,9 @@ impl HttpProxy { e, self.inner.request_summary(session, ctx) ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); continue; } else { return Err(e); @@ -427,12 +535,56 @@ impl HttpProxy { } } + data = custom_reader_next, if downstream_custom_read && !downstream_state.is_errored() => { + let Some(data) = data.flatten() else { + downstream_custom_read = false; + continue; + }; + + let data = match data { + Ok(data) => data, + Err(err) => { + warn!("downstream_custom_message_reader got error: {err}"); + 
downstream_custom_read = false; + continue; + }, + }; + + self.inner + .downstream_custom_message_proxy_filter(session, data, ctx, true) // true, because it's the last hop for downstream proxying + .await?; + }, + + data = custom_inject_rx_recv, if downstream_custom_write => { + match data.flatten() { + Some(data) => { + if let Some(ref mut custom_writer) = downstream_custom_message_writer { + custom_writer.write_custom_message(data).await? + } + }, + None => { + downstream_custom_write = false; + if let Some(ref mut custom_writer) = downstream_custom_message_writer { + custom_writer.finish_custom().await?; + } + }, + } + }, + else => { break; } } } + if let Some(custom_session) = session.downstream_session.as_custom_mut() { + if let Some(downstream_custom_message_reader) = downstream_custom_message_reader { + custom_session + .restore_custom_message_reader(downstream_custom_message_reader) + .expect("downstream restore_custom_message_reader should be empty"); + } + } + let mut reuse_downstream = !downstream_state.is_errored(); if reuse_downstream { match session.as_mut().finish_body().await { @@ -521,6 +673,10 @@ impl HttpProxy { } } + // TODO: just set version to Version::HTTP_11 unconditionally here, + // (with another todo being an option to faithfully proxy the <1.1 responses) + // as we are already trying to mutate this for HTTP/1.1 downstream reuse + /* Convert HTTP 1.0 style response to chunked encoding so that we don't * have to close the downstream connection */ // these status codes / method cannot have body, so no need to add chunked encoding @@ -564,6 +720,24 @@ impl HttpProxy { Ok(HttpTask::Body(data, end)) } + HttpTask::UpgradedBody(mut data, end) => { + if track_max_cache_size { + session + .cache + .track_body_bytes_for_max_file_size(data.as_ref().map_or(0, |d| d.len())); + } + + // range doesn't apply to upgraded body + if let Some(duration) = self + .inner + .response_body_filter(session, &mut data, end, ctx)? 
+ { + trace!("delaying downstream upgraded response for {:?}", duration); + time::sleep(duration).await; + } + + Ok(HttpTask::UpgradedBody(data, end)) + } HttpTask::Trailer(h) => Ok(HttpTask::Trailer(h)), // TODO: support trailers for h1 HttpTask::Done => Ok(task), HttpTask::Failed(_) => Ok(task), // Do nothing just pass the error down @@ -604,6 +778,8 @@ impl HttpProxy { .request_body_filter(&mut data, end_of_body) .await?; + // TODO: request body filter to have info about upgraded status? + // (can also check session.was_upgraded()) self.inner .request_body_filter(session, &mut data, end_of_body, ctx) .await?; @@ -624,7 +800,12 @@ impl HttpProxy { data.as_ref().map_or(-1, |d| d.len() as isize) ); - tx.send(HttpTask::Body(data, upstream_end_of_body)); + // upgraded body needs to be marked + if session.was_upgraded() { + tx.send(HttpTask::UpgradedBody(data, upstream_end_of_body)); + } else { + tx.send(HttpTask::Body(data, upstream_end_of_body)); + } Ok(end_of_body) } @@ -657,15 +838,49 @@ pub(crate) async fn send_body_to1( } } } + HttpTask::UpgradedBody(data, end) => { + client_session.maybe_upgrade_body_writer(); + + body_done = end; + if let Some(d) = data { + let m = client_session.write_body(&d).await; + match m { + Ok(m) => { + match m { + Some(n) => { + debug!("Write {} bytes upgraded body to upstream", n); + } + None => { + warn!("Upstream upgraded body is already finished. 
Nothing to write"); + } + } + } + Err(e) => { + return e.into_up().into_err(); + } + } + } + } _ => { // should never happen, sender only sends body warn!("Unexpected task sent to upstream"); body_done = true; + // error here, + // for client sessions that received upgrade but didn't + // receive any UpgradedBody, + // no more data is arriving so we should consider this + // as downstream finalizing its upgrade payload + client_session.maybe_upgrade_body_writer(); } } } else { // sender dropped body_done = true; + // for client sessions that received upgrade but didn't + // receive any UpgradedBody, + // no more data is arriving so we should consider this + // as downstream finalizing its upgrade payload + client_session.maybe_upgrade_body_writer(); } if body_done { diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index ae406e64..808da5bc 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,11 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use futures::future::OptionFuture; +use futures::StreamExt; + use super::*; use crate::proxy_cache::{range_filter::RangeBodyFilter, ServeFromCache}; use crate::proxy_common::*; use http::{header::CONTENT_LENGTH, Method, StatusCode}; use pingora_cache::CachePhase; +use pingora_core::protocols::http::custom::CUSTOM_MESSAGE_QUEUE_SIZE; use pingora_core::protocols::http::v2::{client::Http2Session, write_body}; // add scheme and authority as required by h2 lib @@ -67,7 +71,10 @@ fn update_h2_scheme_authority( } } -impl HttpProxy { +impl HttpProxy +where + C: custom::Connector, +{ pub(crate) async fn proxy_down_to_up( &self, session: &mut Session, @@ -159,6 +166,11 @@ impl HttpProxy { client_session.read_timeout = peer.options.read_timeout; + let mut downstream_custom_message_writer = session + .downstream_session + .as_custom_mut() + .and_then(|c| c.take_custom_message_writer()); + // take the body writer out of the client for easy duplex let mut client_body = client_session .take_request_body_writer() @@ -175,13 +187,40 @@ impl HttpProxy { /* read downstream body and upstream response at the same time */ let ret = tokio::try_join!( - self.bidirection_down_to_up(session, &mut client_body, rx, ctx, write_timeout), + self.bidirection_down_to_up( + session, + &mut client_body, + rx, + ctx, + write_timeout, + &mut downstream_custom_message_writer + ), pipe_up_to_down_response(client_session, tx) ); + if let Some(custom_session) = session.downstream_session.as_custom_mut() { + if let Some(downstream_custom_message_writer) = downstream_custom_message_writer { + match custom_session.restore_custom_message_writer(downstream_custom_message_writer) + { + Ok(_) => { /* continue */ } + Err(e) => { + return (false, Some(e)); + } + } + } + } + match ret { Ok((downstream_can_reuse, _upstream)) => (downstream_can_reuse, None), - Err(e) => (false, Some(e)), + Err(e) => { + // On application level upstream read timeouts, send RST_STREAM CANCEL, + // we know we have not received 
END_STREAM at this point since we read timed out + // TODO: implement for write timeouts? + if e.esource == ErrorSource::Upstream && matches!(e.etype, ReadTimedout) { + client_body.send_reset(h2::Reason::CANCEL); + } + (false, Some(e)) + } } } @@ -214,6 +253,12 @@ impl HttpProxy { .proxy_down_to_up(session, client_session, peer, ctx) .await; + // Record upstream response body bytes received (HTTP/2 DATA payload). + let upstream_bytes_total = client_session.body_bytes_received(); + session.set_upstream_body_bytes_received(upstream_bytes_total); + + // Note: upstream_write_pending_time is not tracked for HTTP/2 (multiplexed streams). + (server_session_reuse, error) } @@ -225,11 +270,33 @@ impl HttpProxy { mut rx: mpsc::Receiver, ctx: &mut SV::CTX, write_timeout: Option, + downstream_custom_message_writer: &mut Option>, ) -> Result where SV: ProxyHttp + Send + Sync, SV::CTX: Send + Sync, { + // setup custom message forwarding, if downstream supports it + let ( + mut downstream_custom_read, + mut downstream_custom_write, + downstream_custom_message_custom_forwarding, + mut downstream_custom_message_inject_rx, + mut downstream_custom_message_reader, + ) = if downstream_custom_message_writer.is_some() { + let reader = session.downstream_custom_message()?; + let (inject_tx, inject_rx) = mpsc::channel::(CUSTOM_MESSAGE_QUEUE_SIZE); + (true, true, Some(inject_tx), Some(inject_rx), reader) + } else { + (false, false, None, None, None) + }; + + if let Some(custom_forwarding) = downstream_custom_message_custom_forwarding { + self.inner + .custom_forwarding(session, ctx, None, custom_forwarding) + .await?; + } + let mut downstream_state = DownstreamStateMachine::new(session.as_mut().is_body_done()); // retry, send buffer if it exists @@ -255,7 +322,26 @@ impl HttpProxy { /* duplex mode * see the Same function for h1 for more comments */ - while !downstream_state.is_done() || !response_state.is_done() { + while !downstream_state.is_done() + || !response_state.is_done() + || 
downstream_custom_read && !downstream_state.is_errored() + || downstream_custom_write + { + // Use optional futures to allow using optional channels in select branches + let custom_inject_rx_recv: OptionFuture<_> = downstream_custom_message_inject_rx + .as_mut() + .map(|rx| rx.recv()) + .into(); + let custom_reader_next: OptionFuture<_> = downstream_custom_message_reader + .as_mut() + .map(|reader| reader.next()) + .into(); + + // partial read support, this check will also be false if cache is disabled. + let support_cache_partial_read = + session.cache.support_streaming_partial_write() == Some(true); + let upgraded = session.was_upgraded(); + // Similar logic in h1 need to reserve capacity first to avoid deadlock // But we don't need to do the same because the h2 client_body pipe is unbounded (never block) tokio::select! { @@ -265,7 +351,9 @@ impl HttpProxy { let body = match body { Ok(b) => b, Err(e) => { - if serve_from_cache.is_miss() { + let wait_for_cache_fill = (!serve_from_cache.is_on() && support_cache_partial_read) + || serve_from_cache.is_miss(); + if wait_for_cache_fill { // ignore downstream error so that upstream can continue to write cache downstream_state.to_errored(); warn!( @@ -273,6 +361,9 @@ impl HttpProxy { e, self.inner.request_summary(session, ctx) ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); continue; } else { return Err(e.into_down()); @@ -345,6 +436,11 @@ impl HttpProxy { } let response_done = session.write_response_tasks(filtered_tasks).await?; + if session.was_upgraded() { + // it is very weird if the downstream session decides to upgrade + // since the client h2 session cannot, return an error on this case + return Error::e_explain(H2Error, "upgraded while proxying to h2 session"); + } response_state.maybe_set_upstream_done(response_done); } else { debug!("empty upstream event"); @@ -352,7 +448,7 @@ impl HttpProxy { } } - 
task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter), + task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { let task = self.h2_response_filter(session, task?, ctx, &mut serve_from_cache, @@ -370,6 +466,9 @@ impl HttpProxy { e, self.inner.request_summary(session, ctx) ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); continue; } else { return Err(e); @@ -381,6 +480,42 @@ impl HttpProxy { } } } + data = custom_reader_next, if downstream_custom_read && !downstream_state.is_errored() => { + let Some(data) = data.flatten() else { + + downstream_custom_read = false; + continue; + }; + + let data = match data { + Ok(data) => data, + Err(err) => { + warn!("downstream_custom_message_reader got error: {err}"); + downstream_custom_read = false; + continue; + }, + }; + + self.inner + .downstream_custom_message_proxy_filter(session, data, ctx, true) // true, because it's the last hop for downstream proxying + .await?; + }, + + data = custom_inject_rx_recv, if downstream_custom_write => { + match data.flatten() { + Some(data) => { + if let Some(ref mut custom_writer) = downstream_custom_message_writer { + custom_writer.write_custom_message(data).await? 
+ } + }, + None => { + downstream_custom_write = false; + if let Some(ref mut custom_writer) = downstream_custom_message_writer { + custom_writer.finish_custom().await?; + } + }, + } + }, else => { break; @@ -388,6 +523,14 @@ impl HttpProxy { } } + if let Some(custom_session) = session.downstream_session.as_custom_mut() { + if let Some(downstream_custom_message_reader) = downstream_custom_message_reader { + custom_session + .restore_custom_message_reader(downstream_custom_message_reader) + .expect("downstream restore_custom_message_reader should be empty"); + } + } + let mut reuse_downstream = !downstream_state.is_errored(); if reuse_downstream { match session.as_mut().finish_body().await { @@ -512,6 +655,11 @@ impl HttpProxy { } Ok(HttpTask::Body(data, eos)) } + HttpTask::UpgradedBody(..) => { + // An h2 session should not be able to send an h2 upgraded response body, + // and logically that is impossible unless there is a bug in the client v2 session + panic!("Unexpected UpgradedBody task while proxy h2"); + } HttpTask::Trailer(mut trailers) => { let trailer_buffer = match trailers.as_mut() { Some(trailers) => { diff --git a/pingora-proxy/src/proxy_purge.rs b/pingora-proxy/src/proxy_purge.rs index 1464aa15..cfdb9078 100644 --- a/pingora-proxy/src/proxy_purge.rs +++ b/pingora-proxy/src/proxy_purge.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -59,7 +59,10 @@ static NOT_PURGEABLE: Lazy = Lazy::new(|| gen_purge_response(405 // on cache storage or proxy error static INTERNAL_ERROR: Lazy = Lazy::new(|| error_resp::gen_error_response(500)); -impl HttpProxy { +impl HttpProxy +where + C: custom::Connector, +{ pub(crate) async fn proxy_purge( &self, session: &mut Session, diff --git a/pingora-proxy/src/proxy_trait.rs b/pingora-proxy/src/proxy_trait.rs index 85d61baa..d5a3efde 100644 --- a/pingora-proxy/src/proxy_trait.rs +++ b/pingora-proxy/src/proxy_trait.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -90,7 +90,7 @@ pub trait ProxyHttp { /// Returns whether this session is allowed to spawn subrequests. /// - /// This function is checked after [`early_request_filter`] to allow that filter to configure + /// This function is checked after [Self::early_request_filter] to allow that filter to configure /// this if required. This will also run for subrequests themselves, which may allowed to spawn /// their own subrequests. /// @@ -130,18 +130,31 @@ pub trait ProxyHttp { /// /// By default this filter does nothing which effectively disables caching. // Ideally only session.cache should be modified, TODO: reflect that in this interface - fn request_cache_filter(&self, _session: &mut Session, _ctx: &mut Self::CTX) -> Result<()> { + fn request_cache_filter(&self, _session: &mut Session, _ctx: &mut Self::CTX) -> Result<()> + where + Self::CTX: Send + Sync, + { Ok(()) } - /// This callback generates the cache key + /// This callback generates the cache key. + /// + /// This callback is called only when cache is enabled for this request. + /// + /// There is no sensible default cache key for all proxy applications. The + /// correct key depends on which request properties affect upstream responses + /// (e.g. 
`Vary` headers, custom request filters that modify the origin host). + /// Getting this wrong leads to cache poisoning. /// - /// This callback is called only when cache is enabled for this request + /// See `pingora-proxy/tests/utils/server_utils.rs` for a minimal (not + /// production-ready) reference implementation. /// - /// By default this callback returns a default cache key generated from the request. - fn cache_key_callback(&self, session: &Session, _ctx: &mut Self::CTX) -> Result { - let req_header = session.req_header(); - Ok(CacheKey::default(req_header)) + /// # Panics + /// + /// The default implementation panics. You **must** override this method when + /// caching is enabled. + fn cache_key_callback(&self, _session: &Session, _ctx: &mut Self::CTX) -> Result { + unimplemented!("cache_key_callback must be implemented when caching is enabled") } /// This callback is invoked when a cacheable response is ready to be admitted to cache. @@ -255,7 +268,12 @@ pub trait ProxyHttp { resp: &mut ResponseHeader, _ctx: &mut Self::CTX, ) -> range_filter::RangeType { - proxy_cache::range_filter::range_header_filter(session.req_header(), resp) + const DEFAULT_MAX_RANGES: Option = Some(200); + proxy_cache::range_filter::range_header_filter( + session.req_header(), + resp, + DEFAULT_MAX_RANGES, + ) } /// Modify the request before it is sent to the upstream @@ -309,6 +327,51 @@ pub trait ProxyHttp { Ok(()) } + // custom_forwarding is called when downstream and upstream connections are successfully established. + #[doc(hidden)] + async fn custom_forwarding( + &self, + _session: &mut Session, + _ctx: &mut Self::CTX, + _custom_message_to_upstream: Option>, + _custom_message_to_downstream: mpsc::Sender, + ) -> Result<()> + where + Self::CTX: Send + Sync, + { + Ok(()) + } + + // received a custom message from the downstream before sending it to the upstream. 
+ #[doc(hidden)] + async fn downstream_custom_message_proxy_filter( + &self, + _session: &mut Session, + custom_message: Bytes, + _ctx: &mut Self::CTX, + _final_hop: bool, + ) -> Result> + where + Self::CTX: Send + Sync, + { + Ok(Some(custom_message)) + } + + // received a custom message from the upstream before sending it to the downstream. + #[doc(hidden)] + async fn upstream_custom_message_proxy_filter( + &self, + _session: &mut Session, + custom_message: Bytes, + _ctx: &mut Self::CTX, + _final_hop: bool, + ) -> Result> + where + Self::CTX: Send + Sync, + { + Ok(Some(custom_message)) + } + /// Similar to [Self::upstream_response_filter()] but for response body /// /// This function will be called every time a piece of response body is received. The `body` is diff --git a/pingora-proxy/src/subrequest/mod.rs b/pingora-proxy/src/subrequest/mod.rs index 8cfd6215..8141f8c4 100644 --- a/pingora-proxy/src/subrequest/mod.rs +++ b/pingora-proxy/src/subrequest/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use bytes::Bytes; use pingora_cache::lock::{CacheKeyLockImpl, LockStatus, WritePermit}; use pingora_cache::CacheKey; use pingora_core::protocols::http::subrequest::server::{ @@ -19,14 +20,25 @@ use pingora_core::protocols::http::subrequest::server::{ }; use std::any::Any; +pub mod pipe; + struct LockCtx { write_permit: WritePermit, cache_lock: &'static CacheKeyLockImpl, key: CacheKey, } +// Thin wrapper to allow iterating over InputBody Vec. +pub(crate) struct InputBodyReader(std::vec::IntoIter); + +impl InputBodyReader { + pub fn read_body(&mut self) -> Option { + self.0.next() + } +} + /// Optional user-defined subrequest context. 
-pub type UserCtx = Box<(dyn Any + Sync + Send)>; +pub type UserCtx = Box; #[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] pub enum BodyMode { diff --git a/pingora-proxy/src/subrequest/pipe.rs b/pingora-proxy/src/subrequest/pipe.rs new file mode 100644 index 00000000..6dd4a57e --- /dev/null +++ b/pingora-proxy/src/subrequest/pipe.rs @@ -0,0 +1,399 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Subrequest piping. +//! +//! Along with subrequests themselves, subrequest piping as a feature is in +//! alpha stages, APIs are highly unstable and subject to change at any point. +//! +//! Unlike proxy_*, it is not a "true" proxy mode; the functions here help +//! establish a pipe between the main downstream session and the subrequest (which +//! in most cases will be used as a downstream session itself). +//! +//! Furthermore, only downstream modules are invoked on the main downstream session, +//! and the ProxyHttp trait filters are not run on the HttpTasks from the main session +//! (the only relevant one being the request body filter). 
+ +use crate::proxy_common::{DownstreamStateMachine, ResponseStateMachine}; +use crate::subrequest::*; +use crate::{PreparedSubrequest, Session}; +use bytes::Bytes; +use futures::FutureExt; +use log::{debug, warn}; +use pingora_core::protocols::http::{subrequest::server::SubrequestHandle, HttpTask}; +use pingora_error::{Error, ErrorType::*, OrErr, Result}; +use tokio::sync::mpsc; + +pub enum InputBodyType { + /// Preset body + Preset(InputBody), + /// Body should be saved (up to limit) + SaveBody(usize), +} + +/// Context struct as a result of subrequest piping. +#[derive(Clone)] +pub struct PipeSubrequestState { + /// The saved (captured) body from the main session. + pub saved_body: Option, +} + +impl PipeSubrequestState { + fn new() -> PipeSubrequestState { + PipeSubrequestState { saved_body: None } + } +} + +pub struct PipeSubrequestError { + pub state: PipeSubrequestState, + /// Whether error originated (and was propagated from) subrequest itself + /// (vs. an error that occurred while sending task) + pub from_subreq: bool, + pub error: Box, +} +impl PipeSubrequestError { + pub fn new( + error: impl Into>, + from_subreq: bool, + state: PipeSubrequestState, + ) -> Self { + PipeSubrequestError { + error: error.into(), + from_subreq, + state, + } + } +} + +fn map_pipe_err>>( + result: Result, + from_subreq: bool, + state: &PipeSubrequestState, +) -> Result { + result.map_err(|e| PipeSubrequestError::new(e, from_subreq, state.clone())) +} + +#[derive(Debug, Clone)] +pub struct SavedBody { + body: Vec, + complete: bool, + truncated: bool, + length: usize, + max_length: usize, +} + +impl SavedBody { + pub fn new(max_length: usize) -> Self { + SavedBody { + body: vec![], + complete: false, + truncated: false, + length: 0, + max_length, + } + } + + pub fn save_body_bytes(&mut self, body_bytes: Bytes) -> bool { + let len = body_bytes.len(); + if self.length + len > self.max_length { + self.truncated = true; + return false; + } + self.length += len; + 
self.body.push(body_bytes); + true + } + + pub fn is_body_complete(&self) -> bool { + self.complete && !self.truncated + } + + pub fn set_body_complete(&mut self) { + self.complete = true; + } +} + +#[derive(Debug, Clone)] +pub enum InputBody { + NoBody, + Bytes(Vec), + // TODO: stream +} + +impl InputBody { + pub(crate) fn into_reader(self) -> InputBodyReader { + InputBodyReader(match self { + InputBody::NoBody => vec![].into_iter(), + InputBody::Bytes(v) => v.into_iter(), + }) + } + + pub fn is_body_empty(&self) -> bool { + match self { + InputBody::NoBody => true, + InputBody::Bytes(v) => v.is_empty(), + } + } +} + +impl std::convert::From for InputBody { + fn from(body: SavedBody) -> Self { + if body.body.is_empty() { + InputBody::NoBody + } else { + InputBody::Bytes(body.body) + } + } +} + +pub async fn pipe_subrequest( + session: &mut Session, + mut subrequest: PreparedSubrequest, + subrequest_handle: SubrequestHandle, + mut task_filter: F, + input_body: InputBodyType, +) -> std::result::Result +where + F: FnMut(HttpTask) -> Result>, +{ + let (maybe_preset_body, saved_body) = match input_body { + InputBodyType::Preset(body) => (Some(body), None), + InputBodyType::SaveBody(limit) => (None, Some(SavedBody::new(limit))), + }; + let use_preset_body = maybe_preset_body.is_some(); + + let mut response_state = ResponseStateMachine::new(); + let (no_body_input, mut maybe_preset_reader) = if use_preset_body { + let preset_body = maybe_preset_body.expect("checked above"); + (preset_body.is_body_empty(), Some(preset_body.into_reader())) + } else { + (session.as_mut().is_body_done(), None) + }; + let mut downstream_state = DownstreamStateMachine::new(no_body_input); + + let mut state = PipeSubrequestState::new(); + state.saved_body = saved_body; + + // Have the subrequest remove all body-related headers if no body will be sent + // TODO: we could also await the join handle, but subrequest may be running logging phase + // also the full run() may also await cache fill if 
downstream fails + let _join_handle = tokio::spawn(async move { + if no_body_input { + subrequest + .session_mut() + .as_subrequest_mut() + .expect("PreparedSubrequest must be subrequest") + .clear_request_body_headers(); + } + subrequest.run().await + }); + let tx = subrequest_handle.tx; + let mut rx = subrequest_handle.rx; + + let mut wants_body = false; + let mut wants_body_rx_err = false; + let mut wants_body_rx = subrequest_handle.subreq_wants_body; + + let mut proxy_error_rx_err = false; + let mut proxy_error_rx = subrequest_handle.subreq_proxy_error; + + // Note: "upstream" here refers to subrequest session tasks, + // downstream refers to main session + while !downstream_state.is_done() || !response_state.is_done() { + let send_permit = tx + .try_reserve() + .or_err(InternalError, "try_reserve() body pipe for subrequest"); + + tokio::select! { + task = rx.recv(), if !response_state.upstream_done() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { + // pull as many tasks as we can + const TASK_BUFFER_SIZE: usize = 4; + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + let task = map_pipe_err(task_filter(t), false, &state)?; + if let Some(filtered) = task { + tasks.push(filtered); + } + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + if let Some(t) = maybe_task { + let task = map_pipe_err(task_filter(t), false, &state)?; + if let Some(filtered) = task { + tasks.push(filtered); + } + } else { + break + } + } + // FIXME: if one of these tasks is Failed(e), the session will return that + // error; in this case, the error is actually from the subreq + let response_done = map_pipe_err(session.write_response_tasks(tasks).await, false, &state)?; + + // NOTE: technically it is the downstream whose response state has finished here + // we consider the subrequest's work done however + 
response_state.maybe_set_upstream_done(response_done); + // unsuccessful upgrade response may force the request done + // (can only happen with a real session, TODO to allow with preset body) + downstream_state.maybe_finished(!use_preset_body && session.is_body_done()); + } else { + // quite possible that the subrequest may be finished, though the main session + // is not - we still must exit in this case + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, + + res = &mut wants_body_rx, if !wants_body && !wants_body_rx_err => { + // subrequest may need time before it needs body, or it may not actually require it + // TODO: tx send permit may not be necessary if no oneshot exists + if res.is_err() { + wants_body_rx_err = true; + } else { + wants_body = true; + } + } + + res = &mut proxy_error_rx, if !proxy_error_rx_err => { + if let Ok(e) = res { + // propagate proxy error to caller + return Err(PipeSubrequestError::new(e, true, state)); + } else { + // subrequest dropped, let select loop finish + proxy_error_rx_err = true; + } + } + + _ = tx.reserve(), if downstream_state.is_reading() && send_permit.is_err() => { + // If tx is closed, the upstream has already finished its job. + downstream_state.maybe_finished(tx.is_closed()); + debug!("waiting for permit {send_permit:?}, upstream closed {}", tx.is_closed()); + /* No permit, wait on more capacity to avoid starving. + * Otherwise this select only blocks on rx, which might send no data + * before the entire body is uploaded. 
+ * once more capacity arrives we just loop back + */ + }, + + body = session.downstream_session.read_body_or_idle(downstream_state.is_done()), + if wants_body && !use_preset_body && downstream_state.can_poll() && send_permit.is_ok() => { + // this is the first subrequest + // send the body + debug!("downstream event: main body for subrequest"); + let body = map_pipe_err(body.map_err(|e| e.into_down()), false, &state)?; + + // If the request is websocket, `None` body means the request is closed. + // Set the response to be done as well so that the request completes normally. + if body.is_none() && session.is_upgrade_req() { + response_state.maybe_set_upstream_done(true); + } + + let is_body_done = session.is_body_done(); + let request_done = map_pipe_err(send_body_to_pipe( + session, + body, + is_body_done, + state.saved_body.as_mut(), + send_permit.expect("checked is_ok()"), + ) + .await, false, &state)?; + + downstream_state.maybe_finished(request_done); + + }, + + // lazily evaluated async block allows us to expect() inside the select! branch + body = async { maybe_preset_reader.as_mut().expect("preset body set").read_body() }, + if wants_body && use_preset_body && !downstream_state.is_done() && downstream_state.can_poll() && send_permit.is_ok() => { + debug!("downstream event: preset body for subrequest"); + + // TODO: WebSocket handling to set upstream done? 
+ + // preset None body indicates we are done + let is_body_done = body.is_none(); + // Don't run downstream modules on preset input body + let request_done = map_pipe_err(do_send_body_to_pipe( + body, + is_body_done, + None, + send_permit.expect("checked is_ok()"), + ), false, &state)?; + downstream_state.maybe_finished(request_done); + + }, + + else => break, + } + } + Ok(state) +} + +// Mostly the same as proxy_common, but does not run proxy request_body_filter +async fn send_body_to_pipe( + session: &mut Session, + mut data: Option, + end_of_body: bool, + saved_body: Option<&mut SavedBody>, + tx: mpsc::Permit<'_, HttpTask>, +) -> Result { + // None: end of body + // this var is to signal if downstream finish sending the body, which shouldn't be + // affected by the request_body_filter + let end_of_body = end_of_body || data.is_none(); + + session + .downstream_modules_ctx + .request_body_filter(&mut data, end_of_body) + .await?; + + do_send_body_to_pipe(data, end_of_body, saved_body, tx) +} + +fn do_send_body_to_pipe( + data: Option, + end_of_body: bool, + mut saved_body: Option<&mut SavedBody>, + tx: mpsc::Permit<'_, HttpTask>, +) -> Result { + // the flag to signal to upstream + let upstream_end_of_body = end_of_body || data.is_none(); + + /* It is normal to get 0 bytes because of multi-chunk or request_body_filter decides not to + * output anything yet. 
+ * Don't write 0 bytes to the network since it will be + * treated as the terminating chunk */ + if !upstream_end_of_body && data.as_ref().is_some_and(|d| d.is_empty()) { + return Ok(false); + } + + debug!( + "Read {} bytes body from downstream", + data.as_ref().map_or(-1, |d| d.len() as isize) + ); + + if let Some(capture) = saved_body.as_mut() { + if capture.is_body_complete() { + warn!("subrequest trying to save body after body is complete"); + } else if let Some(d) = data.as_ref() { + capture.save_body_bytes(d.clone()); + } + if end_of_body { + capture.set_body_complete(); + } + } + + tx.send(HttpTask::Body(data, upstream_end_of_body)); + + Ok(end_of_body) +} diff --git a/pingora-proxy/tests/test_basic.rs b/pingora-proxy/tests/test_basic.rs index 7b093dfe..77303fc3 100644 --- a/pingora-proxy/tests/test_basic.rs +++ b/pingora-proxy/tests/test_basic.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -21,7 +21,8 @@ use hyper::{body::HttpBody, header::HeaderValue, Body, Client}; #[cfg(unix)] use hyperlocal::{UnixClientExt, Uri}; use reqwest::{header, StatusCode}; -use tokio::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; use utils::server_utils::init; @@ -745,6 +746,77 @@ async fn test_connect_close() { assert_eq!(body, "Hello World!\n"); } +#[tokio::test] +async fn test_connect_proxying_disallowed_h1() { + init(); + + let mut stream = TcpStream::connect("127.0.0.1:6147").await.unwrap(); + let request = b"CONNECT pingora.org:443 HTTP/1.1\r\nHost: pingora.org:443\r\n\r\n"; + stream.write_all(request).await.unwrap(); + + let mut buf = [0u8; 1024]; + let read = stream.read(&mut buf).await.unwrap(); + let resp = std::str::from_utf8(&buf[..read]).unwrap(); + let status_line = resp.lines().next().unwrap_or(""); + assert!(status_line.contains(" 405 ")); +} + +#[tokio::test] +async fn test_connect_proxying_disallowed_h2() { + init(); + + let tcp = TcpStream::connect("127.0.0.1:6146").await.unwrap(); + let (mut h2, connection) = client::handshake(tcp).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + + let request = Request::builder() + .method("CONNECT") + .uri("http://pingora.org:443/") + .body(()) + .unwrap(); + let (response, _body) = h2.send_request(request, true).unwrap(); + let (head, mut body) = response.await.unwrap().into_parts(); + assert_eq!(head.status.as_u16(), 405); + while let Some(chunk) = body.data().await { + assert!(chunk.unwrap().is_empty()); + } +} + +#[tokio::test] +async fn test_connect_proxying_allowed_h1() { + init(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + + // Note per RFC CONNECT 2xx responses are not allowed to have response + // bodies, so this is non-standard behavior. 
+ tokio::spawn(async move { + let (mut socket, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 1024]; + let _ = socket.read(&mut buf).await.unwrap(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"; + socket.write_all(response).await.unwrap(); + let _ = socket.shutdown().await; + }); + + let mut stream = TcpStream::connect("127.0.0.1:6160").await.unwrap(); + let request = format!( + "CONNECT pingora.org:443 HTTP/1.1\r\nHost: pingora.org:443\r\nX-Port: {}\r\n\r\n", + upstream_addr.port() + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let read = stream.read(&mut buf).await.unwrap(); + let resp = std::str::from_utf8(&buf[..read]).unwrap(); + let status_line = resp.lines().next().unwrap_or(""); + assert!(status_line.contains(" 200 ")); + assert!(resp.ends_with("ok")); +} + #[tokio::test] #[cfg(feature = "any_tls")] async fn test_mtls_no_client_cert() { diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index 3b1333c1..49e5a6fa 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -15,12 +15,16 @@ mod utils; use utils::server_utils::init; -use utils::websocket::WS_ECHO; +use utils::websocket::{WS_ECHO, WS_ECHO_RAW}; use futures::{SinkExt, StreamExt}; +use pingora_http::ResponseHeader; use reqwest::header::{HeaderName, HeaderValue}; -use reqwest::StatusCode; +use reqwest::{StatusCode, Version}; use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::time::timeout; use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message}; #[tokio::test] @@ -184,6 +188,178 @@ async fn test_ws_server_ends_conn() { assert!(ws_stream.next().await.is_none()); } +fn parse_response_header(buf: &[u8]) -> ResponseHeader { + let mut headers = vec![httparse::EMPTY_HEADER; 256]; + let mut parsed = httparse::Response::new(&mut headers); + match parsed.parse(buf).unwrap() { + httparse::Status::Complete(_) => { + let mut resp = + ResponseHeader::build(parsed.code.unwrap(), Some(parsed.headers.len())).unwrap(); + for header in parsed.headers.iter() { + resp.append_header(header.name.to_string(), header.value) + .unwrap(); + } + resp + } + _ => panic!("expects a whole response header"), + } +} + +/// Read response header and return it along with any preread body data +async fn read_response_header(stream: &mut tokio::net::TcpStream) -> (ResponseHeader, Vec) { + let mut response = vec![]; + let mut header_end = 0; + let mut buf = [0; 1024]; + loop { + let n = stream.read(&mut buf).await.unwrap(); + response.extend_from_slice(&buf[..n]); + let mut end_of_response = false; + for (i, w) in response.windows(4).enumerate() { + if w == b"\r\n\r\n" { + end_of_response = true; + header_end = i + 4; + break; + } + } + if end_of_response { + break; + } + } + let response_header = parse_response_header(&response[..header_end]); + let preread_body = response[header_end..].to_vec(); + (response_header, preread_body) +} + +/// Read remaining body bytes from stream until expected_body_len is reached +async fn 
read_response_body( + stream: &mut tokio::net::TcpStream, + mut body: Vec, + expected_body_len: usize, +) -> Vec { + let mut buf = [0; 1024]; + while body.len() < expected_body_len { + let n = stream.read(&mut buf).await.unwrap(); + body.extend_from_slice(&buf[..n]); + } + if body.len() > expected_body_len { + panic!("more body bytes than expected"); + } + body +} + +async fn read_response( + stream: &mut tokio::net::TcpStream, + expected_body_len: usize, +) -> (ResponseHeader, Vec) { + let (response_header, body) = read_response_header(stream).await; + let body = read_response_body(stream, body, expected_body_len).await; + (response_header, body) +} + +#[tokio::test] +async fn test_upgrade_smoke() { + init(); + + let mut stream = TcpStream::connect("127.0.0.1:6147").await.unwrap(); + + let req = concat!( + "GET /upgrade HTTP/1.1\r\n", + "Host: 127.0.0.1\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + "\r\n" + ); + stream.write_all(req.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + + let expected_payload = b"hello\n"; + let fut = read_response(&mut stream, expected_payload.len()); + let (resp_header, resp_body) = timeout(Duration::from_secs(5), fut).await.unwrap(); + + assert_eq!(resp_header.status, 101); + assert_eq!(resp_header.headers["Upgrade"], "websocket"); + assert_eq!(resp_header.headers["Connection"], "upgrade"); + assert_eq!(resp_body, expected_payload); +} + +#[tokio::test] +async fn test_upgrade_body() { + init(); + + let mut stream = TcpStream::connect("127.0.0.1:6147").await.unwrap(); + + let req = concat!( + "POST /upgrade_echo_body HTTP/1.1\r\n", + "Host: 127.0.0.1\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + "Content-Length: 1024\r\n", + "\r\n" + ); + stream.write_all(req.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + stream.write_all("b".repeat(1024).as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + + let fut = read_response(&mut stream, 1024); + let (resp_header, 
resp_body) = timeout(Duration::from_secs(5), fut).await.unwrap(); + assert_eq!(resp_header.status, 101); + assert_eq!(resp_header.headers["Upgrade"], "websocket"); + assert_eq!(resp_header.headers["Connection"], "upgrade"); + + let body = "b".repeat(1024); + assert_eq!(resp_body, body.as_bytes()); +} + +#[tokio::test] +async fn test_upgrade_body_after_101() { + // test content-length body is passed through after 101, + // and that ws payload is passed through afterwards + // use websocket server that flushes 101 after reading header + init(); + let _ = *WS_ECHO_RAW; + + let mut stream = TcpStream::connect("127.0.0.1:6147").await.unwrap(); + + let req = concat!( + "POST /upgrade_echo_body HTTP/1.1\r\n", + "Host: 127.0.0.1\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + "X-Port: 9284\r\n", + "Content-Length: 5120\r\n", + "X-Expected-Body-Len: 5125\r\n", // include ws payload + "\r\n" + ); + stream.write_all(req.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + stream + .write_all("b".repeat(5 * 1024).as_bytes()) + .await + .unwrap(); + stream.flush().await.unwrap(); + + // Read response header and any preread body first (before sending ws_payload) + let fut = read_response_header(&mut stream); + let (resp_header, resp_body) = timeout(Duration::from_secs(5), fut).await.unwrap(); + assert_eq!(resp_header.status, 101); + assert_eq!(resp_header.headers["Upgrade"], "websocket"); + assert_eq!(resp_header.headers["Connection"], "upgrade"); + + // Now send the websocket payload after receiving 101 + let ws_payload = "hello"; + stream.write_all(ws_payload.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + + // Read the rest of the bytes (body + ws payload), subtracting preread body length + let expected_total_len = 5 * 1024 + ws_payload.len(); + let fut = read_response_body(&mut stream, resp_body, expected_total_len); + let resp_body = timeout(Duration::from_secs(5), fut).await.unwrap(); + + let body = "b".repeat(5 * 1024) + 
ws_payload; + assert_eq!(resp_body, body.as_bytes()); +} + #[tokio::test] async fn test_download_timeout() { init(); @@ -191,7 +367,7 @@ async fn test_download_timeout() { use tokio::time::sleep; let client = hyper::Client::new(); - let uri: hyper::Uri = "http://127.0.0.1:6147/download/".parse().unwrap(); + let uri: hyper::Uri = "http://127.0.0.1:6147/download_large/".parse().unwrap(); let req = hyper::Request::builder() .uri(uri) .header("x-write-timeout", "1") @@ -363,6 +539,61 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "no if headers detected\n"); } + #[tokio::test] + async fn test_cache_http10() { + // allow caching http1.0 from origin, but proxy as h1.1 downstream + init(); + let url = "http://127.0.0.1:6148/unique/test_cache_http10/now"; + + let res = reqwest::Client::new() + .get(url) + .header("x-upstream-fake-http10", "1") // fake http1.0 in upstream response filter + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.version(), Version::HTTP_11); + let headers = res.headers(); + let cache_miss_epoch = headers["x-epoch"].to_str().unwrap().parse::().unwrap(); + assert_eq!(headers["transfer-encoding"], "chunked"); + assert_eq!(headers["x-cache-status"], "miss"); + assert_eq!(res.text().await.unwrap(), "hello world"); + + let res = reqwest::Client::new() + .get(url) + .header("x-upstream-fake-http10", "1") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.version(), Version::HTTP_11); + let headers = res.headers(); + let cache_hit_epoch = headers["x-epoch"].to_str().unwrap().parse::().unwrap(); + assert_eq!(headers["transfer-encoding"], "chunked"); + assert_eq!(headers["x-cache-status"], "hit"); + assert_eq!(res.text().await.unwrap(), "hello world"); + + assert_eq!(cache_miss_epoch, cache_hit_epoch); + + sleep(Duration::from_millis(1100)).await; // ttl is 1 + + let res = reqwest::Client::new() + .get(url) + .header("x-upstream-fake-http10", "1") + .send() + 
.await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.version(), Version::HTTP_11); + let headers = res.headers(); + let cache_expired_epoch = headers["x-epoch"].to_str().unwrap().parse::().unwrap(); + assert_eq!(headers["transfer-encoding"], "chunked"); + assert_eq!(headers["x-cache-status"], "expired"); + assert_eq!(res.text().await.unwrap(), "hello world"); + + assert!(cache_expired_epoch > cache_hit_epoch); + } + #[tokio::test] async fn test_cache_downstream_compression() { init(); @@ -907,6 +1138,55 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), ""); } + #[tokio::test] + async fn test_cache_websocket_101() { + // Test the unlikely scenario in which users may want to cache WS + init(); + + // First request - should be a miss + let mut stream = TcpStream::connect("127.0.0.1:6148").await.unwrap(); + let req = concat!( + "GET /unique/test_cache_websocket_101/upgrade HTTP/1.1\r\n", + "Host: 127.0.0.1\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + "X-Cache-Websocket: 1\r\n", + "\r\n" + ); + stream.write_all(req.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + + let expected_payload = b"hello\n"; + let fut = read_response(&mut stream, expected_payload.len()); + let (resp_header, resp_body) = timeout(Duration::from_secs(5), fut).await.unwrap(); + + assert_eq!(resp_header.status, 101); + assert_eq!(resp_header.headers["Upgrade"], "websocket"); + assert_eq!(resp_header.headers["x-cache-status"], "miss"); + assert_eq!(resp_body, expected_payload); + + // Second request - should be a cache hit + let mut stream = TcpStream::connect("127.0.0.1:6148").await.unwrap(); + let req = concat!( + "GET /unique/test_cache_websocket_101/upgrade HTTP/1.1\r\n", + "Host: 127.0.0.1\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + "X-Cache-Websocket: 1\r\n", + "\r\n" + ); + stream.write_all(req.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + + let fut = read_response(&mut stream, 
expected_payload.len()); + let (resp_header, resp_body) = timeout(Duration::from_secs(5), fut).await.unwrap(); + + assert_eq!(resp_header.status, 101); + assert_eq!(resp_header.headers["Upgrade"], "websocket"); + assert_eq!(resp_header.headers["x-cache-status"], "hit"); + assert_eq!(resp_body, expected_payload); + } + #[tokio::test] async fn test_1xx_caching() { // 1xx shouldn't interfere with HTTP caching @@ -1581,8 +1861,8 @@ mod test_cache { .unwrap(); assert_eq!(res.status(), StatusCode::OK); let headers = res.headers(); - // cache lock timeout, disable cache - assert_eq!(headers["x-cache-status"], "no-cache"); + // cache lock timeout, try to replace lock + assert_eq!(headers["x-cache-status"], "miss"); assert_eq!(res.text().await.unwrap(), "hello world"); }); @@ -1599,26 +1879,16 @@ mod test_cache { .unwrap(); assert_eq!(res.status(), StatusCode::OK); let headers = res.headers(); - // this is now a miss because we will not timeout on cache lock + // this is now a hit because the second task cached from origin + // successfully // and will fetch from origin successfully - assert_eq!(headers["x-cache-status"], "miss"); + assert_eq!(headers["x-cache-status"], "hit"); assert_eq!(res.text().await.unwrap(), "hello world"); }); task1.await.unwrap(); task2.await.unwrap(); task3.await.unwrap(); - - let res = reqwest::Client::new() - .get(url) - .header("x-lock", "true") - .send() - .await - .unwrap(); - assert_eq!(res.status(), 200); - let headers = res.headers(); - assert_eq!(headers["x-cache-status"], "hit"); // the first request cached it - assert_eq!(res.text().await.unwrap(), "hello world"); } #[tokio::test] @@ -2348,6 +2618,108 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "hello world!"); } + #[tokio::test] + async fn test_caching_when_downstream_bails_uncacheable() { + init(); + let url = "http://127.0.0.1:6148/slow_body/test_caching_when_downstream_bails_uncacheable/"; + + tokio::spawn(async move { + let res = reqwest::Client::new() + .get(url) + 
.header("x-lock", "true") + .header("x-no-store", "1") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "no-cache"); + // exit without res.text().await so that we bail early + }); + // sleep just a little to make sure the req above gets the cache lock + sleep(Duration::from_millis(50)).await; + + let res = reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + // entirely new request made to upstream, since the response was uncacheable + assert_eq!(headers["x-cache-status"], "no-cache"); // due to cache lock give up + assert_eq!(res.text().await.unwrap(), "hello world!"); + } + + #[tokio::test] + async fn test_caching_when_downstream_bails_header() { + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_when_downstream_bails_header/sleep"; + + tokio::spawn(async move { + // this should always time out + reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .header("x-set-sleep", "2") + .timeout(Duration::from_secs(1)) + .send() + .await + .unwrap_err() + }); + // sleep after cache fill + sleep(Duration::from_millis(2500)).await; + + // next request should be a cache hit + let res = reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + assert_eq!(res.text().await.unwrap(), "hello world"); + } + + #[tokio::test] + async fn test_caching_when_downstream_bails_header_uncacheable() { + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_when_downstream_bails_header_uncacheable/sleep"; + + tokio::spawn(async move { + // this should always time out + reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .header("x-set-sleep", "2") + .header("x-no-store", 
"1") + .timeout(Duration::from_secs(1)) + .send() + .await + .unwrap_err() + // note that while the downstream error is ignored, + // once the response is uncacheable we will still attempt to write + // downstream and find a broken connection that terminates the request + }); + // sleep after cache fill + sleep(Duration::from_millis(2500)).await; + + // next request should be a cache miss, as the previous fill was uncacheable + let res = reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "miss"); + assert_eq!(res.text().await.unwrap(), "hello world"); + } + async fn send_vary_req_with_headers_with_dups( url: &str, vary_field: &str, diff --git a/pingora-proxy/tests/utils/cert.rs b/pingora-proxy/tests/utils/cert.rs index 7594afa4..5428f71b 100644 --- a/pingora-proxy/tests/utils/cert.rs +++ b/pingora-proxy/tests/utils/cert.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf index 6914344b..f19c974c 100644 --- a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf +++ b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf @@ -99,7 +99,6 @@ http { # increase max body size for /upload/ test client_max_body_size 128m; - #charset koi8-r; #access_log logs/host.access.log main; @@ -319,6 +318,19 @@ http { } } + location /download_large/ { + content_by_lua_block { + ngx.req.read_body() + local chunk = string.rep("A", 1048576) -- 1MB chunk + local total_size = 128 * 1048576 -- 128MB total + ngx.header["Content-Length"] = total_size + for i = 1, 128 do + ngx.print(chunk) + ngx.flush() + end + } + } + location /tls_verify { keepalive_timeout 0; return 200; @@ -490,6 +502,11 @@ http { ngx.sleep(sleep_sec) ngx.print("!") } + header_filter_by_lua_block { + if ngx.var.http_x_no_store then + ngx.header["Cache-control"] = "no-store" + end + } } location /content_type { @@ -499,6 +516,29 @@ http { return 200 "hello world"; } + location /upgrade { + content_by_lua_block { + ngx.status = 101 + ngx.header['Upgrade'] = 'websocket' + ngx.header['Connection'] = 'Upgrade' + ngx.say('hello') + } + } + + location /upgrade_echo_body { + rewrite_by_lua_block { + ngx.req.read_body() + local data = ngx.req.get_body_data() + ngx.status = 101 + ngx.header['Upgrade'] = 'websocket' + ngx.header['Connection'] = 'Upgrade' + + if data then + ngx.print(data) + end + } + } + #error_page 404 /404.html; # redirect server error pages to the static page /50x.html diff --git a/pingora-proxy/tests/utils/mock_origin.rs b/pingora-proxy/tests/utils/mock_origin.rs index f3564dbe..74840e19 100644 --- a/pingora-proxy/tests/utils/mock_origin.rs +++ b/pingora-proxy/tests/utils/mock_origin.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-proxy/tests/utils/mod.rs b/pingora-proxy/tests/utils/mod.rs index 7a70ae4f..3ec2fa28 100644 --- a/pingora-proxy/tests/utils/mod.rs +++ b/pingora-proxy/tests/utils/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index 73629c8d..0df71336 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ use super::cert; use async_trait::async_trait; use clap::Parser; -use http::header::{ACCEPT_ENCODING, VARY}; +use http::header::{ACCEPT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING, VARY}; use http::HeaderValue; use log::error; use once_cell::sync::Lazy; @@ -26,8 +26,8 @@ use pingora_cache::key::HashBinary; use pingora_cache::lock::CacheKeyLockImpl; use pingora_cache::{ eviction::simple_lru::Manager, filters::resp_cacheable, lock::CacheLock, predictor::Predictor, - set_compression_dict_path, CacheMeta, CacheMetaDefaults, CachePhase, MemCache, NoCacheReason, - RespCacheable, + set_compression_dict_path, CacheKey, CacheMeta, CacheMetaDefaults, CachePhase, MemCache, + NoCacheReason, RespCacheable, }; use pingora_cache::{ CacheOptionOverrides, ForcedFreshness, HitHandler, PurgeType, VarianceBuilder, @@ -38,7 +38,7 @@ use pingora_core::protocols::{ http::error_resp::gen_error_response, l4::socket::SocketAddr, Digest, }; use pingora_core::server::configuration::Opt; -use pingora_core::services::Service; +use pingora_core::services::{Service, ServiceWithDependents}; use pingora_core::upstreams::peer::HttpPeer; use pingora_core::utils::tls::CertKey; use pingora_error::{Error, ErrorSource, ErrorType::*, Result}; @@ -47,7 +47,7 @@ use pingora_proxy::{FailToProxy, ProxyHttp, Session}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::thread; -use std::time::Duration; +use std::time::{Duration, SystemTime}; pub struct ExampleProxyHttps {} @@ -489,6 +489,35 @@ impl ProxyHttp for ExampleProxyCache { Ok(()) } + /// Reference `cache_key_callback` implementation for integration tests. + /// + /// Builds the primary key as `{host}{path_and_query}` from the request. + /// This is **not production ready**: it does not account for `Vary`, custom + /// request filters, or scheme differences. See the rustdoc on + /// [`ProxyHttp::cache_key_callback`] for details. 
+ fn cache_key_callback(&self, session: &Session, _ctx: &mut Self::CTX) -> Result { + let req_header = session.req_header(); + + let host = req_header + .headers + .get(http::header::HOST) + .and_then(|v| v.to_str().ok()) + .or_else(|| req_header.uri.authority().map(|a| a.as_str())) + .unwrap_or(""); + + let path_and_query = req_header + .uri + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or("/"); + + Ok(CacheKey::new( + String::new(), + format!("{host}{path_and_query}"), + String::new(), + )) + } + async fn cache_hit_filter( &self, session: &mut Session, @@ -574,10 +603,26 @@ impl ProxyHttp for ExampleProxyCache { fn response_cache_filter( &self, - _session: &Session, + session: &Session, resp: &ResponseHeader, _ctx: &mut Self::CTX, ) -> Result { + // Allow testing the unlikely case of caching a 101 response + if resp.status == 101 + && session + .req_header() + .headers + .contains_key("x-cache-websocket") + { + return Ok(RespCacheable::Cacheable(CacheMeta::new( + SystemTime::now() + Duration::from_secs(5), + SystemTime::now(), + 0, + 0, + resp.clone(), + ))); + } + let cc = CacheControl::from_resp_headers(resp); Ok(resp_cacheable( cc.as_ref(), @@ -589,11 +634,21 @@ impl ProxyHttp for ExampleProxyCache { async fn upstream_response_filter( &self, - _session: &mut Session, + session: &mut Session, upstream_response: &mut ResponseHeader, ctx: &mut Self::CTX, ) -> Result<()> { ctx.upstream_status = Some(upstream_response.status.into()); + if session + .req_header() + .headers + .contains_key("x-upstream-fake-http10") + { + // TODO to simulate an actual http1.0 origin + upstream_response.set_version(http::Version::HTTP_10); + upstream_response.remove_header(&CONTENT_LENGTH); + upstream_response.remove_header(&TRANSFER_ENCODING); + } Ok(()) } @@ -696,7 +751,7 @@ impl ProxyHttp for ExampleProxyCache { error: Option<&Error>, // None when it is called during stale while revalidate ) -> bool { // enable serve stale while updating - error.map_or(true, |e| 
e.esource() == &ErrorSource::Upstream) + error.is_none_or(|e| e.esource() == &ErrorSource::Upstream) } fn is_purge(&self, session: &Session, _ctx: &Self::CTX) -> bool { @@ -712,7 +767,8 @@ fn test_main() { "-c".into(), "tests/pingora_conf.yaml".into(), ]; - let mut my_server = pingora_core::server::Server::new(Some(Opt::parse_from(opts))).unwrap(); + let mut my_server = + pingora_core::server::Server::new(Some(Opt::parse_from_args(opts))).unwrap(); my_server.bootstrap(); let mut proxy_service_http = @@ -721,6 +777,14 @@ fn test_main() { #[cfg(unix)] proxy_service_http.add_uds("/tmp/pingora_proxy.sock", None); + let mut proxy_service_http_connect = + pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyHttp {}); + let http_logic = proxy_service_http_connect.app_logic_mut().unwrap(); + let mut http_server_options = HttpServerOptions::default(); + http_server_options.allow_connect_method_proxying = true; + http_logic.server_options = Some(http_server_options); + proxy_service_http_connect.add_tcp("0.0.0.0:6160"); + let mut proxy_service_h2c = pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyHttp {}); @@ -730,7 +794,7 @@ fn test_main() { http_logic.server_options = Some(http_server_options); proxy_service_h2c.add_tcp("0.0.0.0:6146"); - let mut proxy_service_https_opt: Option> = None; + let mut proxy_service_https_opt: Option> = None; #[cfg(feature = "any_tls")] { @@ -761,9 +825,10 @@ fn test_main() { proxy_service_cache.add_tls_with_settings("0.0.0.0:6153", None, tls_settings); } - let mut services: Vec> = vec![ + let mut services: Vec> = vec![ Box::new(proxy_service_h2c), Box::new(proxy_service_http), + Box::new(proxy_service_http_connect), Box::new(proxy_service_cache), ]; diff --git a/pingora-proxy/tests/utils/websocket/mod.rs b/pingora-proxy/tests/utils/websocket/mod.rs new file mode 100644 index 00000000..f416b702 --- /dev/null +++ b/pingora-proxy/tests/utils/websocket/mod.rs @@ -0,0 +1,5 @@ +mod ws_echo; +mod ws_echo_raw; 
+ +pub use ws_echo::WS_ECHO; +pub use ws_echo_raw::WS_ECHO_RAW; diff --git a/pingora-proxy/tests/utils/websocket.rs b/pingora-proxy/tests/utils/websocket/ws_echo.rs similarity index 65% rename from pingora-proxy/tests/utils/websocket.rs rename to pingora-proxy/tests/utils/websocket/ws_echo.rs index 92b35e95..5c610320 100644 --- a/pingora-proxy/tests/utils/websocket.rs +++ b/pingora-proxy/tests/utils/websocket/ws_echo.rs @@ -1,14 +1,29 @@ +// Copyright 2025 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::{io::Error, thread, time::Duration}; use futures_util::{SinkExt, StreamExt}; use log::debug; -use once_cell::sync::Lazy; +use std::sync::LazyLock; use tokio::{ net::{TcpListener, TcpStream}, runtime::Builder, }; -pub static WS_ECHO: Lazy = Lazy::new(init); +pub static WS_ECHO: LazyLock = LazyLock::new(init); +pub const WS_ECHO_ORIGIN_PORT: u16 = 9283; fn init() -> bool { thread::spawn(move || { @@ -18,7 +33,9 @@ fn init() -> bool { .build() .unwrap(); runtime.block_on(async move { - server("127.0.0.1:9283").await.unwrap(); + server(&format!("127.0.0.1:{WS_ECHO_ORIGIN_PORT}")) + .await + .unwrap(); }) }); thread::sleep(Duration::from_millis(200)); diff --git a/pingora-proxy/tests/utils/websocket/ws_echo_raw.rs b/pingora-proxy/tests/utils/websocket/ws_echo_raw.rs new file mode 100644 index 00000000..89b186f2 --- /dev/null +++ b/pingora-proxy/tests/utils/websocket/ws_echo_raw.rs @@ -0,0 +1,176 @@ +// Copyright 2025 Cloudflare, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{thread, time::Duration}; + +use futures_util::{SinkExt, StreamExt}; +use log::debug; +use pingora_error::{Error, ErrorType::*, OrErr, Result}; +use pingora_http::RequestHeader; +use std::sync::LazyLock; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpListener, TcpStream, + }, + runtime::Builder, +}; + +pub static WS_ECHO_RAW: LazyLock = LazyLock::new(init); +pub const WS_ECHO_RAW_ORIGIN_PORT: u16 = 9284; + +fn init() -> bool { + thread::spawn(move || { + let runtime = Builder::new_current_thread() + .thread_name("websocket raw echo") + .enable_all() + .build() + .unwrap(); + runtime.block_on(async move { + server(&format!("127.0.0.1:{WS_ECHO_RAW_ORIGIN_PORT}")) + .await + .unwrap(); + }) + }); + thread::sleep(Duration::from_millis(200)); + true +} + +async fn server(addr: &str) -> Result<(), Error> { + let listener = TcpListener::bind(&addr).await.unwrap(); + while let Ok((stream, _)) = listener.accept().await { + tokio::spawn(handle_connection(stream)); + } + Ok(()) +} + +async fn read_request_header(stream: &mut TcpStream) -> Result<(RequestHeader, Vec)> { + fn parse_request_header(buf: &[u8]) -> Result { + let mut headers = vec![httparse::EMPTY_HEADER; 256]; + let mut parsed = httparse::Request::new(&mut headers); + match parsed + .parse(buf) + .or_err(ReadError, "request header parse error")? 
+ { + httparse::Status::Complete(_) => { + let mut req = RequestHeader::build( + parsed.method.unwrap_or(""), + parsed.path.unwrap_or("").as_bytes(), + Some(parsed.headers.len()), + )?; + for header in parsed.headers.iter() { + req.append_header(header.name.to_string(), header.value) + .unwrap(); + } + Ok(req) + } + _ => Error::e_explain(ReadError, "should have full request header"), + } + } + + let mut request = vec![]; + let mut header_end = 0; + let mut buf = [0; 1024]; + loop { + let n = stream + .read(&mut buf) + .await + .or_err(ReadError, "while reading request header")?; + request.extend_from_slice(&buf[..n]); + let mut end_of_header = false; + for (i, w) in request.windows(4).enumerate() { + if w == b"\r\n\r\n" { + end_of_header = true; + header_end = i + 4; + break; + } + } + if end_of_header { + break; + } + } + Ok(( + parse_request_header(&request[..header_end])?, + request[header_end..].to_vec(), + )) +} + +async fn read_body_until_close( + stream: &mut OwnedReadHalf, +) -> Result>, std::io::Error> { + let mut buf = [0; 1024]; + let n = stream.read(&mut buf).await?; + if n == 0 { + return Ok(None); + } + Ok(Some(buf[..n].to_vec())) +} + +async fn write_body_until_close( + stream: &mut OwnedWriteHalf, + body: &[u8], +) -> Result, std::io::Error> { + let n = stream.write(body).await?; + Ok((n != 0).then_some(n)) +} + +async fn handle_connection(mut stream: TcpStream) -> Result<()> { + let (header, preread_body) = read_request_header(&mut stream).await?; + + // if x-expected-body-len unset, continue to read until stream is closed + let expected_body_len = header + .headers + .get("x-expected-body-len") + .and_then(|v| std::str::from_utf8(v.as_bytes()).ok()) + .and_then(|s| s.parse().ok()); + + let resp_raw = + b"HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n"; + stream + .write_all(resp_raw) + .await + .or_err(WriteError, "while writing 101")?; + + let (mut stream_read, mut stream_write) = stream.into_split(); + let 
mut request_body = preread_body; + let mut body_read = request_body.len(); + let mut body_read_done = false; + + loop { + tokio::select! { + res = read_body_until_close(&mut stream_read), if !body_read_done => { + let Some(buf) = res.or_err(ReadError, "while reading body")? else { + return Ok(()); + }; + body_read += buf.len(); + body_read_done = expected_body_len.is_some_and(|len| body_read >= len); + request_body.extend_from_slice(&buf[..]); + } + res = write_body_until_close(&mut stream_write, &request_body[..]), if !request_body.is_empty() => { + let Some(n) = res.or_err(WriteError, "while writing body")? else { + return Ok(()); + }; + request_body = request_body[n..].to_vec(); + } + else => break, + } + } + if let Some(expected) = expected_body_len { + if body_read > expected { + return Error::e_explain(ReadError, "read {body_read} bytes, expected {expected}"); + } + } + Ok(()) +} diff --git a/pingora-runtime/Cargo.toml b/pingora-runtime/Cargo.toml index de419400..5de4f26b 100644 --- a/pingora-runtime/Cargo.toml +++ b/pingora-runtime/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-runtime" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" diff --git a/pingora-runtime/benches/hello.rs b/pingora-runtime/benches/hello.rs index 3460efb1..271447e5 100644 --- a/pingora-runtime/benches/hello.rs +++ b/pingora-runtime/benches/hello.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-runtime/src/lib.rs b/pingora-runtime/src/lib.rs index 07883400..a0468f4f 100644 --- a/pingora-runtime/src/lib.rs +++ b/pingora-runtime/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-rustls/Cargo.toml b/pingora-rustls/Cargo.toml index f2540349..efa377bf 100644 --- a/pingora-rustls/Cargo.toml +++ b/pingora-rustls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-rustls" -version = "0.6.0" +version = "0.8.0" license = "Apache-2.0" edition = "2021" repository = "https://github.com/cloudflare/pingora" @@ -16,7 +16,7 @@ path = "src/lib.rs" [dependencies] log = "0.4.21" -pingora-error = { version = "0.6.0", path = "../pingora-error"} +pingora-error = { version = "0.8.0", path = "../pingora-error"} ring = "0.17.12" rustls = "0.23.12" rustls-native-certs = "0.7.1" diff --git a/pingora-rustls/src/lib.rs b/pingora-rustls/src/lib.rs index 51672c42..097a8da5 100644 --- a/pingora-rustls/src/lib.rs +++ b/pingora-rustls/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -24,14 +24,23 @@ use std::path::Path; use log::warn; pub use no_debug::{Ellipses, NoDebug, WithTypeInfo}; use pingora_error::{Error, ErrorType, OrErr, Result}; -pub use rustls::{version, ClientConfig, RootCertStore, ServerConfig, Stream}; + +pub use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; +pub use rustls::server::{ClientCertVerifierBuilder, WebPkiClientVerifier}; +pub use rustls::{ + client::WebPkiServerVerifier, version, CertificateError, ClientConfig, DigitallySignedStruct, + Error as RusTlsError, KeyLogFile, RootCertStore, ServerConfig, SignatureScheme, Stream, +}; pub use rustls_native_certs::load_native_certs; use rustls_pemfile::Item; -pub use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +pub use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; pub use tokio_rustls::client::TlsStream as ClientTlsStream; pub use tokio_rustls::server::TlsStream as ServerTlsStream; pub use tokio_rustls::{Accept, Connect, TlsAcceptor, TlsConnector, TlsStream}; +// This allows to skip certificate verification. Be highly cautious. +pub use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; + /// Load the given file from disk as a buffered reader and use the pingora Error /// type instead of the std::io version fn load_file

(path: P) -> Result> diff --git a/pingora-s2n/Cargo.toml b/pingora-s2n/Cargo.toml index 9ecf1087..0dbd1103 100644 --- a/pingora-s2n/Cargo.toml +++ b/pingora-s2n/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-s2n" -version = "0.6.0" +version = "0.8.0" license = "Apache-2.0" edition = "2021" repository = "https://github.com/cloudflare/pingora" @@ -15,7 +15,7 @@ name = "pingora_s2n" path = "src/lib.rs" [dependencies] -pingora-error = {version = "0.6.0", path = "../pingora-error"} +pingora-error = { version = "0.8.0", path = "../pingora-error"} ring = "0.17.12" s2n-tls = "0.3" s2n-tls-tokio = "0.3" diff --git a/pingora-s2n/src/lib.rs b/pingora-s2n/src/lib.rs index 2a7a476e..aef1cef3 100644 --- a/pingora-s2n/src/lib.rs +++ b/pingora-s2n/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-timeout/Cargo.toml b/pingora-timeout/Cargo.toml index ff14283c..c8d615c7 100644 --- a/pingora-timeout/Cargo.toml +++ b/pingora-timeout/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora-timeout" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" diff --git a/pingora-timeout/benches/benchmark.rs b/pingora-timeout/benches/benchmark.rs index ae32556c..64fd053d 100644 --- a/pingora-timeout/benches/benchmark.rs +++ b/pingora-timeout/benches/benchmark.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-timeout/src/fast_timeout.rs b/pingora-timeout/src/fast_timeout.rs index 8fd22908..27535e11 100644 --- a/pingora-timeout/src/fast_timeout.rs +++ b/pingora-timeout/src/fast_timeout.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. 
+// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-timeout/src/lib.rs b/pingora-timeout/src/lib.rs index c0498c3e..707f7be8 100644 --- a/pingora-timeout/src/lib.rs +++ b/pingora-timeout/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora-timeout/src/timer.rs b/pingora-timeout/src/timer.rs index e0f631a7..c6c587e0 100644 --- a/pingora-timeout/src/timer.rs +++ b/pingora-timeout/src/timer.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index 864ae306..cb16664e 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingora" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu "] license = "Apache-2.0" edition = "2021" @@ -22,21 +22,21 @@ features = ["document-features"] rustdoc-args = ["--cfg", "docsrs"] [dependencies] -pingora-core = { version = "0.6.0", path = "../pingora-core", default-features = false } -pingora-http = { version = "0.6.0", path = "../pingora-http" } -pingora-timeout = { version = "0.6.0", path = "../pingora-timeout" } -pingora-load-balancing = { version = "0.6.0", path = "../pingora-load-balancing", optional = true, default-features = false } -pingora-proxy = { version = "0.6.0", path = "../pingora-proxy", optional = true, default-features = false } -pingora-cache = { version = "0.6.0", path = "../pingora-cache", optional = true, default-features = false } +pingora-core = { version = "0.8.0", path = "../pingora-core", 
default-features = false } +pingora-http = { version = "0.8.0", path = "../pingora-http" } +pingora-timeout = { version = "0.8.0", path = "../pingora-timeout" } +pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing", optional = true, default-features = false } +pingora-proxy = { version = "0.8.0", path = "../pingora-proxy", optional = true, default-features = false } +pingora-cache = { version = "0.8.0", path = "../pingora-cache", optional = true, default-features = false } # Only used for documenting features, but doesn't work in any other dependency # group :( document-features = { version = "0.2.10", optional = true } [dev-dependencies] -clap = { version = "3.2.25", features = ["derive"] } +clap = { version = "4.5", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread", "signal"] } -env_logger = "0.9" +env_logger = "0.11" reqwest = { version = "0.11", features = ["rustls"], default-features = false } hyper = "0.14" async-trait = { workspace = true } @@ -126,8 +126,23 @@ time = [] ## Enable sentry for error notifications sentry = ["pingora-core/sentry"] +## Enable pre-TLS connection filtering +connection_filter = [ + "pingora-core/connection_filter", + "pingora-proxy?/connection_filter", +] + + # These features are intentionally not documented openssl_derived = ["any_tls"] any_tls = [] patched_http1 = ["pingora-core/patched_http1"] -document-features = ["dep:document-features", "proxy", "lb", "cache", "time", "sentry"] +document-features = [ + "dep:document-features", + "proxy", + "lb", + "cache", + "time", + "sentry", + "connection_filter" +] diff --git a/pingora/examples/app/echo.rs b/pingora/examples/app/echo.rs index 97e449df..fd1daeb4 100644 --- a/pingora/examples/app/echo.rs +++ b/pingora/examples/app/echo.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora/examples/app/mod.rs b/pingora/examples/app/mod.rs index a9fa06e8..1f6c3e61 100644 --- a/pingora/examples/app/mod.rs +++ b/pingora/examples/app/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora/examples/app/proxy.rs b/pingora/examples/app/proxy.rs index 042b5112..4760957a 100644 --- a/pingora/examples/app/proxy.rs +++ b/pingora/examples/app/proxy.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora/examples/client.rs b/pingora/examples/client.rs index 6eb29648..30be7b2f 100644 --- a/pingora/examples/client.rs +++ b/pingora/examples/client.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ // limitations under the License. use pingora::{connectors::http::Connector, prelude::*}; -use pingora_http::RequestHeader; use regex::Regex; #[tokio::main] diff --git a/pingora/examples/server.rs b/pingora/examples/server.rs index fffcb1cc..0a055acc 100644 --- a/pingora/examples/server.rs +++ b/pingora/examples/server.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,7 +20,8 @@ use pingora::protocols::TcpKeepalive; use pingora::server::configuration::Opt; use pingora::server::{Server, ShutdownWatch}; use pingora::services::background::{background_service, BackgroundService}; -use pingora::services::{listening::Service as ListeningService, Service}; +use pingora::services::listening::Service as ListeningService; +use pingora::services::ServiceWithDependents; use async_trait::async_trait; use clap::Parser; @@ -190,7 +191,7 @@ pub fn main() { let background_service = background_service("example", ExampleBackgroundService {}); - let services: Vec<Box<dyn Service>> = vec![ + let services: Vec<Box<dyn ServiceWithDependents>> = vec![ Box::new(echo_service), Box::new(echo_service_http), Box::new(proxy_service), diff --git a/pingora/examples/service/echo.rs b/pingora/examples/service/echo.rs index 83b46ed4..a2e0f32e 100644 --- a/pingora/examples/service/echo.rs +++ b/pingora/examples/service/echo.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora/examples/service/mod.rs b/pingora/examples/service/mod.rs index a9fa06e8..1f6c3e61 100644 --- a/pingora/examples/service/mod.rs +++ b/pingora/examples/service/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pingora/examples/service/proxy.rs b/pingora/examples/service/proxy.rs index 39de498d..1c6a1df9 100644 --- a/pingora/examples/service/proxy.rs +++ b/pingora/examples/service/proxy.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
diff --git a/pingora/src/lib.rs b/pingora/src/lib.rs index a102050e..e72cb28c 100644 --- a/pingora/src/lib.rs +++ b/pingora/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tinyufo/Cargo.toml b/tinyufo/Cargo.toml index 08a4c18b..cbfedbba 100644 --- a/tinyufo/Cargo.toml +++ b/tinyufo/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "TinyUFO" -version = "0.6.0" +version = "0.8.0" authors = ["Yuchen Wu <yuchen@cloudflare.com>"] edition = "2021" license = "Apache-2.0" diff --git a/tinyufo/benches/bench_hit_ratio.rs b/tinyufo/benches/bench_hit_ratio.rs index dcd666c5..4c162fbe 100644 --- a/tinyufo/benches/bench_hit_ratio.rs +++ b/tinyufo/benches/bench_hit_ratio.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tinyufo/benches/bench_memory.rs b/tinyufo/benches/bench_memory.rs index cb8f3605..2f770027 100644 --- a/tinyufo/benches/bench_memory.rs +++ b/tinyufo/benches/bench_memory.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tinyufo/benches/bench_perf.rs b/tinyufo/benches/bench_perf.rs index 5d05b8b9..cb0638d6 100644 --- a/tinyufo/benches/bench_perf.rs +++ b/tinyufo/benches/bench_perf.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
diff --git a/tinyufo/src/buckets.rs b/tinyufo/src/buckets.rs index 182123cb..d74ab6bf 100644 --- a/tinyufo/src/buckets.rs +++ b/tinyufo/src/buckets.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ impl Compact { Self(shard_array.into_boxed_slice()) } - pub fn get(&self, key: &Key) -> Option>> { + pub fn get(&self, key: &Key) -> Option>> { let shard = *key as usize % self.0.len(); self.0[shard].get(key) } diff --git a/tinyufo/src/estimation.rs b/tinyufo/src/estimation.rs index bd6c764a..8e187931 100644 --- a/tinyufo/src/estimation.rs +++ b/tinyufo/src/estimation.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/tinyufo/src/lib.rs b/tinyufo/src/lib.rs index a8509e21..4064a356 100644 --- a/tinyufo/src/lib.rs +++ b/tinyufo/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Cloudflare, Inc. +// Copyright 2026 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.