From 0f8259432f4f06851a361e18cbd5ecbd229f2f19 Mon Sep 17 00:00:00 2001 From: Iain McGinniss <309153+iainmcgin@users.noreply.github.com> Date: Tue, 5 May 2026 20:55:49 +0000 Subject: [PATCH 1/3] axum: add serve_tls helper for peer-identity passthrough axum::serve takes a plain TcpListener with no hook for terminating TLS, so axum + mTLS deployments must hand-roll the rustls accept loop and per-connection extension plumbing to give handlers the same PeerAddr/PeerCerts view the standalone Server provides automatically. connectrpc::axum::serve_tls is a drop-in axum::serve replacement that owns the accept loop, terminates TLS with a slowloris-bounded handshake timeout, captures the remote address and verified client cert chain once per connection, and stamps both into request extensions. Handler code that reads ctx.extensions.get::() is now portable between the standalone Server and an axum app. The returned ServeTls mirrors axum::serve::Serve: a configurable IntoFuture builder with .tls_handshake_timeout() and .with_graceful_shutdown() knobs. Also adds tokio/macros to the server feature set: the accept loop in both server.rs and axum.rs uses tokio::select!, but the macro feature was previously only reachable via dev-dep unification, so a downstream crate enabling server without tokio/macros would fail to compile. Refs #49. --- connectrpc/Cargo.toml | 4 + connectrpc/src/axum.rs | 599 +++++++++++++++++++++++++++++++++++++++ connectrpc/src/lib.rs | 9 + connectrpc/src/server.rs | 17 +- docs/guide.md | 19 +- 5 files changed, 638 insertions(+), 10 deletions(-) create mode 100644 connectrpc/src/axum.rs diff --git a/connectrpc/Cargo.toml b/connectrpc/Cargo.toml index 35bff22..c084756 100644 --- a/connectrpc/Cargo.toml +++ b/connectrpc/Cargo.toml @@ -102,6 +102,10 @@ server = [ "dep:libc", "dep:tower-http", "tokio/net", + # The accept loop uses `tokio::select!`. Surfaced only when downstream + # consumers enable `server` without also enabling `tokio/macros` + # themselves; `--all-targets` (CI) hides it via dev-dep unification. + "tokio/macros", ] # HTTP client with connection pooling diff --git a/connectrpc/src/axum.rs b/connectrpc/src/axum.rs new file mode 100644 index 0000000..ea040d5 --- /dev/null +++ b/connectrpc/src/axum.rs @@ -0,0 +1,599 @@ +//! TLS-aware `axum::serve` counterpart that exposes peer identity to handlers. +//! +//! [`Router::into_axum_service`](crate::Router::into_axum_service) and +//! [`Router::into_axum_router`](crate::Router::into_axum_router) cover the +//! plaintext path: mount your ConnectRPC routes on an `axum::Router` and +//! hand the result to `axum::serve`. This module fills the TLS gap. +//! +//! `axum::serve` accepts a plain [`TcpListener`] and has no hook for +//! terminating TLS. The standalone [`Server`](crate::Server), by contrast, +//! owns the rustls accept loop and so can capture [`PeerAddr`]/[`PeerCerts`] +//! once per connection and inject them into every request's extensions for +//! handlers to read via `ctx.extensions.get::()`. Without help, an axum + +//! mTLS deployment has to reimplement that accept loop and per-connection +//! plumbing by hand for handlers to get the same view. +//! +//! [`serve_tls`] is that help: it serves an `axum::Router`, terminates TLS, +//! captures peer identity, and stamps it into request extensions. Handler +//! code that reads `ctx.extensions.get::()` is then portable +//! between the standalone `Server` and an axum app — the hosting choice no +//! longer leaks into your authorization logic. +//! +//! ```rust,ignore +//! // Plaintext: axum's built-in serve. +//! axum::serve(listener, app).await?; +//! +//! // TLS with PeerAddr/PeerCerts passthrough. +//! connectrpc::axum::serve_tls(listener, app, tls_config).await?; +//! ``` +//! +//! # Differences from `axum::serve` +//! +//! `serve_tls` is the TLS counterpart to `axum::serve(listener, router)` for +//! the common `axum::Router` case. It is intentionally less generic: +//! +//! - **Service type.** `serve_tls` accepts a concrete `axum::Router`, not +//! the make-service forms `axum::serve` is generic over. There is no +//! `into_make_service_with_connect_info::()` equivalent because +//! `serve_tls` already injects [`PeerAddr`] (the same socket address) into +//! request extensions; read that instead of `ConnectInfo`. +//! A `Router` with state must have `.with_state(...)` applied first. +//! - **`PeerCerts` is conditional.** It is only inserted when the +//! [`rustls::ServerConfig`] requests client authentication *and* the peer +//! presents a chain rustls verifies. With `with_no_client_auth()` (or a +//! permissive verifier and a client that sends no cert), only [`PeerAddr`] +//! is present. Handlers must treat `ctx.extensions.get::()` as +//! optional. +//! - **ALPN.** The TLS terminator speaks the protocol ALPN selects. To allow +//! HTTP/2 (required for gRPC; preferred for Connect streaming), set +//! `server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]` +//! before passing it in. Without ALPN, hyper falls back to HTTP/1.1. +//! - **No automatic panic catching.** Unlike the standalone +//! [`Server`](crate::Server), `serve_tls` does not wrap your `axum::Router` +//! in `tower_http::catch_panic::CatchPanicLayer` (`axum::serve` doesn't +//! either). If you want a panicking handler to surface as a Connect error +//! rather than a dropped connection, add the layer yourself. +//! +//! Available only with both the `axum` and `server-tls` features enabled. + +use std::future::{Future, IntoFuture}; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder as AutoBuilder; +use hyper_util::server::graceful::GracefulShutdown; +use tokio::net::TcpListener; +use tower::ServiceExt; + +use crate::server::{ + DEFAULT_TLS_HANDSHAKE_TIMEOUT, PeerAddr, PeerCerts, is_transient_accept_error, +}; + +/// Serve an `axum::Router` over TLS, exposing peer identity to handlers. +/// +/// The TLS counterpart to `axum::serve(listener, router)` for when handlers +/// need [`PeerAddr`] and [`PeerCerts`] in request extensions — the same +/// convention the standalone [`Server::with_tls`](crate::Server::with_tls) +/// uses. The accept loop terminates TLS with `tokio-rustls`, captures the +/// remote address and any verified client certificate chain, then injects +/// both into every request before handing off to the axum service. +/// [`PeerCerts`] is only present when `tls_config` requests client +/// authentication and the peer presented a chain rustls verified. +/// +/// Like the standalone server, the TLS handshake is bounded by a +/// [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`] to prevent slowloris-style connection +/// exhaustion; tune it with [`ServeTls::tls_handshake_timeout`]. +/// +/// The returned [`ServeTls`] resolves once the listener stops accepting and +/// in-flight connections drain (after [`ServeTls::with_graceful_shutdown`]'s +/// signal fires) or when a non-transient accept error occurs. +/// +/// See the [module docs](self) for the differences from `axum::serve`, +/// including ALPN setup and panic-handling expectations. +/// +/// ```rust,no_run +/// # use std::sync::Arc; +/// # async fn demo(connect_router: connectrpc::Router, tls_config: Arc, +/// # shutdown_signal: tokio::sync::oneshot::Receiver<()>) +/// # -> Result<(), Box> { +/// let app = axum::Router::new() +/// .route("/health", axum::routing::get(|| async { "OK" })) +/// .fallback_service(connect_router.into_axum_service()); +/// +/// let listener = tokio::net::TcpListener::bind("0.0.0.0:8443").await?; +/// connectrpc::axum::serve_tls(listener, app, tls_config) +/// .with_graceful_shutdown(async { shutdown_signal.await.ok(); }) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// # Errors +/// +/// The future resolves to `Err` only for non-transient I/O errors from the +/// underlying `accept(2)` (for example, file-descriptor exhaustion that +/// persists past `EMFILE`/`ENFILE` retries, or a closed listener). Per-peer +/// failures — TLS handshake errors, handshake timeouts, and HTTP-layer errors +/// on a single connection — are logged at `debug`/`warn`/`trace` and never +/// abort the accept loop. +pub fn serve_tls( + listener: TcpListener, + router: axum::Router, + tls_config: Arc, +) -> ServeTls { + ServeTls { + listener, + router, + acceptor: tokio_rustls::TlsAcceptor::from(tls_config), + tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT, + shutdown: None, + } +} + +/// Configurable future returned by [`serve_tls`]. +/// +/// Mirrors the shape of `axum::serve::Serve`: tweak it with builder +/// methods, then `.await` it (or pass it anywhere an `IntoFuture` is +/// accepted). +#[must_use = "ServeTls does nothing unless `.await`ed"] +pub struct ServeTls { + listener: TcpListener, + router: axum::Router, + acceptor: tokio_rustls::TlsAcceptor, + tls_handshake_timeout: Duration, + shutdown: Option + Send>>>, +} + +impl ServeTls { + /// Override the TLS handshake timeout (default + /// [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). Set generously; clients on + /// high-latency links need a few round trips to complete the handshake. + #[must_use = "ServeTls does nothing unless `.await`ed"] + pub fn tls_handshake_timeout(mut self, timeout: Duration) -> Self { + self.tls_handshake_timeout = timeout; + self + } + + /// Stop accepting new connections when `signal` resolves and drain + /// in-flight connections before the future returned by [`serve_tls`] + /// resolves. Mirrors `axum::serve::Serve::with_graceful_shutdown`. + #[must_use = "ServeTls does nothing unless `.await`ed"] + pub fn with_graceful_shutdown(mut self, signal: F) -> Self + where + F: Future + Send + 'static, + { + self.shutdown = Some(Box::pin(signal)); + self + } +} + +impl std::fmt::Debug for ServeTls { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServeTls") + .field("listener", &self.listener) + .field("tls_handshake_timeout", &self.tls_handshake_timeout) + .field("shutdown", &self.shutdown.is_some()) + .finish_non_exhaustive() + } +} + +impl IntoFuture for ServeTls { + type Output = std::io::Result<()>; + type IntoFuture = Pin + Send>>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.run()) + } +} + +impl ServeTls { + async fn run(self) -> std::io::Result<()> { + let ServeTls { + listener, + router, + acceptor, + tls_handshake_timeout, + shutdown, + } = self; + + // `select!` needs a polled-in-place future for both arms; default to + // a never-resolving signal when no graceful shutdown is configured. + let mut shutdown = shutdown.unwrap_or_else(|| Box::pin(std::future::pending())); + let graceful = GracefulShutdown::new(); + + loop { + let (stream, remote_addr) = tokio::select! { + biased; // honor shutdown before another accept + + _ = &mut shutdown => { + tracing::info!("Shutdown signal received; draining connections"); + break; + } + accepted = listener.accept() => match accepted { + Ok(conn) => conn, + Err(err) if is_transient_accept_error(&err) => { + tracing::warn!("Transient accept error (continuing): {err}"); + continue; + } + Err(err) => return Err(err), + }, + }; + + // Same TCP_NODELAY rationale as the standalone Server: avoid + // Nagle/delayed-ACK interaction on small HTTP/2 control frames. + if let Err(e) = stream.set_nodelay(true) { + tracing::warn!("failed to set TCP_NODELAY: {e}"); + } + + let acceptor = acceptor.clone(); + let router = router.clone(); + let watcher = graceful.watcher(); + + tokio::spawn(async move { + let tls_stream = match tokio::time::timeout( + tls_handshake_timeout, + acceptor.accept(stream), + ) + .await + { + Ok(Ok(s)) => s, + Ok(Err(err)) => { + tracing::debug!(remote_addr = %remote_addr, error = ?err, "TLS handshake failed"); + return; + } + Err(_) => { + tracing::warn!( + remote_addr = %remote_addr, + "TLS handshake timed out after {tls_handshake_timeout:?}", + ); + return; + } + }; + + // Capture peer info now — once hyper owns the stream we can't + // borrow it again. `into_owned()` detaches the cert bytes from + // the session lifetime so the Arc can outlive the TlsStream. + let (_, conn) = tls_stream.get_ref(); + let peer_addr = PeerAddr(remote_addr); + let peer_certs = conn + .peer_certificates() + .map(|chain| PeerCerts(chain.iter().map(|c| c.clone().into_owned()).collect())); + + // Per-request: stamp peer info into extensions and forward to + // the axum service. `Router::clone()` is an Arc bump. + let svc = hyper::service::service_fn( + move |mut req: hyper::Request| { + req.extensions_mut().insert(peer_addr.clone()); + if let Some(c) = &peer_certs { + req.extensions_mut().insert(c.clone()); + } + router.clone().oneshot(req.map(axum::body::Body::new)) + }, + ); + + // `serve_connection_with_upgrades` (vs `serve_connection` on the + // standalone `Server`) so axum routes that need HTTP `Upgrade:` + // (WebSockets) work out of the box. ConnectRPC routes don't + // upgrade, so this is a no-op for them. Keep this divergence — + // it matches what `axum::serve` does internally. + let conn = AutoBuilder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(TokioIo::new(tls_stream), svc) + .into_owned(); + if let Err(err) = watcher.watch(conn).await { + tracing::trace!(remote_addr = %remote_addr, error = %err, "Connection ended with error"); + } + }); + } + + // Stop accepting before signalling the drain so no new connection + // sneaks in between the watcher snapshot and the listener close. + drop(listener); + graceful.shutdown().await; + tracing::info!("All connections drained; shutdown complete"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + use crate::{Response as ConnectResponse, Router as ConnectRouter, handler_fn}; + use rcgen::{CertificateParams, CertifiedIssuer, IsCa, KeyPair, SanType}; + use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + type Pki = ( + Arc, + Arc, + CertificateDer<'static>, + ); + + /// Minimal in-memory mTLS PKI: one CA, one server leaf, one client leaf. + /// Returns `(server_config, client_config, client_leaf_der)`. + fn pki() -> Pki { + // Idempotent; err == already installed (tests share process state). + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + + let ca_key = KeyPair::generate().unwrap(); + let mut ca_params = CertificateParams::default(); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + let ca = CertifiedIssuer::self_signed(ca_params, ca_key).unwrap(); + + let issue = |sans: &[SanType]| { + let k = KeyPair::generate().unwrap(); + let mut p = CertificateParams::default(); + p.subject_alt_names = sans.to_vec(); + let c = p.signed_by(&k, &ca).unwrap(); + ( + CertificateDer::from(c.der().to_vec()), + PrivatePkcs8KeyDer::from(k.serialized_der().to_vec()).into(), + ) + }; + let (srv_cert, srv_key) = issue(&[SanType::DnsName("localhost".try_into().unwrap())]); + let (cli_cert, cli_key) = issue(&[]); + + let mut roots = rustls::RootCertStore::empty(); + roots.add(CertificateDer::from(ca.der().to_vec())).unwrap(); + let roots = Arc::new(roots); + + let cv = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots)) + .build() + .unwrap(); + let server = rustls::ServerConfig::builder() + .with_client_cert_verifier(cv) + .with_single_cert(vec![srv_cert], srv_key) + .unwrap(); + let client = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_client_auth_cert(vec![cli_cert.clone()], cli_key) + .unwrap(); + (Arc::new(server), Arc::new(client), cli_cert) + } + + /// HTTP/1.1 Connect unary request matching `server.rs`'s fixture. + const ECHO_REQ: &[u8] = b"POST /svc/Echo HTTP/1.1\r\n\ + Host: localhost\r\n\ + Content-Type: application/proto\r\n\ + Content-Length: 0\r\n\ + Connection: close\r\n\ + \r\n"; + + #[tokio::test] + async fn serve_tls_injects_peer_identity() { + let (server_cfg, client_cfg, expected_client_der) = pki(); + + // The handler stashes whatever PeerAddr/PeerCerts it sees. + type Captured = Arc)>>>; + let captured: Captured = Arc::new(Mutex::new(None)); + let handler_captured = Arc::clone(&captured); + let connect = ConnectRouter::new().route( + "svc", + "Echo", + handler_fn( + move |ctx: crate::RequestContext, _req: buffa_types::Empty| { + let cap = Arc::clone(&handler_captured); + async move { + *cap.lock().unwrap() = Some(( + ctx.extensions.get::().cloned().unwrap(), + ctx.extensions.get::().cloned(), + )); + ConnectResponse::ok(buffa_types::Empty::default()) + } + }, + ), + ); + let app = axum::Router::new().fallback_service(connect.into_axum_service()); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .into_future(), + ); + + let resp = echo_over_tls(addr, client_cfg).await; + assert!( + resp.starts_with(b"HTTP/1.1 2"), + "expected 2xx, got: {}", + String::from_utf8_lossy(&resp[..resp.len().min(120)]) + ); + + // Graceful shutdown should drain and resolve the serve task. + tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .expect("serve should shut down within timeout") + .unwrap() + .unwrap(); + + let (peer_addr, peer_certs) = captured.lock().unwrap().take().expect("handler ran"); + assert_eq!(peer_addr.0.ip(), addr.ip()); + let certs = peer_certs.expect("mTLS client should present a cert chain"); + assert_eq!(certs.0.len(), 1); + assert_eq!(certs.0[0].as_ref(), expected_client_der.as_ref()); + } + + /// Open a TLS+HTTP/1.1 connection, send `ECHO_REQ`, and return the raw + /// HTTP response bytes. + async fn echo_over_tls( + addr: std::net::SocketAddr, + client_cfg: Arc, + ) -> Vec { + let tcp = tokio::net::TcpStream::connect(addr).await.unwrap(); + let connector = tokio_rustls::TlsConnector::from(client_cfg); + let sni = rustls::pki_types::ServerName::try_from("localhost").unwrap(); + let mut tls = connector.connect(sni, tcp).await.unwrap(); + tls.write_all(ECHO_REQ).await.unwrap(); + let mut resp = Vec::new(); + tls.read_to_end(&mut resp).await.unwrap(); + resp + } + + #[tokio::test] + async fn handshake_timeout_drops_stalled_connection() { + let (server_cfg, _, _) = pki(); + let app = axum::Router::new(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .tls_handshake_timeout(Duration::from_millis(100)) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .into_future(), + ); + + // Open TCP but never speak TLS, and keep it open through shutdown. + // If the handshake timeout doesn't release this connection's watcher, + // the graceful drain blocks until the outer timeout fails the test. + let _stalled = tokio::net::TcpStream::connect(addr).await.unwrap(); + // Generous margin so the accept loop spawns the per-connection task + // (and its watcher) before we signal shutdown — otherwise the test + // passes vacuously without exercising the timeout path. + tokio::time::sleep(Duration::from_millis(250)).await; + + tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .expect("handshake timeout must release the watcher so drain completes") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn handshake_error_does_not_kill_accept_loop() { + let (server_cfg, client_cfg, _) = pki(); + let calls = Arc::new(Mutex::new(0u32)); + let handler_calls = Arc::clone(&calls); + let connect = ConnectRouter::new().route( + "svc", + "Echo", + handler_fn( + move |_ctx: crate::RequestContext, _req: buffa_types::Empty| { + let calls = Arc::clone(&handler_calls); + async move { + *calls.lock().unwrap() += 1; + ConnectResponse::ok(buffa_types::Empty::default()) + } + }, + ), + ); + let app = axum::Router::new().fallback_service(connect.into_axum_service()); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .into_future(), + ); + + // Speak garbage instead of a ClientHello: the rustls handshake fails + // immediately. The accept loop must log-and-continue, not propagate. + let mut bad = tokio::net::TcpStream::connect(addr).await.unwrap(); + bad.write_all(b"GET / HTTP/1.1\r\n\r\n").await.unwrap(); + let mut buf = [0u8; 64]; + let _ = bad.read(&mut buf).await; // server closes / sends a TLS alert + drop(bad); + + // A valid client must still get through. + let resp = echo_over_tls(addr, client_cfg).await; + assert!( + resp.starts_with(b"HTTP/1.1 2"), + "valid client must succeed after a handshake error: {}", + String::from_utf8_lossy(&resp[..resp.len().min(120)]) + ); + + tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!( + *calls.lock().unwrap(), + 1, + "only the valid client reaches the handler" + ); + } + + #[tokio::test] + async fn graceful_shutdown_drains_in_flight_request() { + let (server_cfg, client_cfg, _) = pki(); + + // The handler blocks until the test releases it; this lets us pin a + // request as "in-flight" across the shutdown signal. + let (in_flight_tx, in_flight_rx) = tokio::sync::oneshot::channel::<()>(); + let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>(); + let in_flight_tx = Arc::new(Mutex::new(Some(in_flight_tx))); + let release_rx = Arc::new(Mutex::new(Some(release_rx))); + let connect = ConnectRouter::new().route( + "svc", + "Echo", + handler_fn( + move |_ctx: crate::RequestContext, _req: buffa_types::Empty| { + let in_flight = in_flight_tx.lock().unwrap().take(); + let release = release_rx.lock().unwrap().take(); + async move { + if let Some(tx) = in_flight { + tx.send(()).ok(); + } + if let Some(rx) = release { + rx.await.ok(); + } + ConnectResponse::ok(buffa_types::Empty::default()) + } + }, + ), + ); + let app = axum::Router::new().fallback_service(connect.into_axum_service()); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .with_graceful_shutdown(async { + shutdown_rx.await.ok(); + }) + .into_future(), + ); + + let client = tokio::spawn(echo_over_tls(addr, client_cfg)); + + // Once the request is in-flight, signal shutdown. The watcher held by + // the per-connection task must anchor it until the handler returns. + in_flight_rx.await.unwrap(); + shutdown_tx.send(()).unwrap(); + + // Release the handler: the in-flight request must complete cleanly + // (proving the connection wasn't torn down by the shutdown), and only + // then should the serve future drain. + release_tx.send(()).unwrap(); + let resp = tokio::time::timeout(Duration::from_secs(5), client) + .await + .expect("in-flight request should complete during drain") + .unwrap(); + assert!( + resp.starts_with(b"HTTP/1.1 2"), + "in-flight request must complete: {}", + String::from_utf8_lossy(&resp[..resp.len().min(120)]) + ); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .expect("serve should drain after the in-flight request completes") + .unwrap() + .unwrap(); + } +} diff --git a/connectrpc/src/lib.rs b/connectrpc/src/lib.rs index a667395..3207fed 100644 --- a/connectrpc/src/lib.rs +++ b/connectrpc/src/lib.rs @@ -171,6 +171,15 @@ pub mod client; #[cfg(feature = "server")] pub mod server; +// Optional: TLS-aware `axum::serve` counterpart with peer-identity passthrough. +// +// Note: this module shadows the extern-prelude `axum` crate within the crate +// root scope only. Don't add `use axum::...` here in `lib.rs`; use +// `::axum::...` if a root-level reference to the external crate is ever needed. +#[cfg(all(feature = "axum", feature = "server-tls"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "axum", feature = "server-tls"))))] +pub mod axum; + // ============================================================================ // Primary exports - Tower-first API // ============================================================================ diff --git a/connectrpc/src/server.rs b/connectrpc/src/server.rs index 9198346..6d0b7e0 100644 --- a/connectrpc/src/server.rs +++ b/connectrpc/src/server.rs @@ -62,8 +62,9 @@ use crate::service::ConnectRpcService; /// Remote socket address of the connected peer. /// -/// Inserted into every request's extensions by the built-in server's accept -/// loop. Handlers read it via `ctx.extensions.get::()`. +/// Inserted into every request's extensions by the built-in [`Server`]'s +/// accept loop and by `connectrpc::axum::serve_tls`. Handlers read it via +/// `ctx.extensions.get::()`. /// /// Callers using a different HTTP stack (axum, raw hyper) in front of /// [`ConnectRpcService`] can insert this same type @@ -73,11 +74,11 @@ pub struct PeerAddr(pub SocketAddr); /// TLS client certificate chain presented by the peer (leaf first). /// -/// Inserted by the built-in server's TLS accept loop when the -/// [`rustls::ServerConfig`] requests client authentication and the peer -/// presents a valid chain. Absent on plaintext connections or when the -/// client presents no certificate. Handlers read it via -/// `ctx.extensions.get::()`. +/// Inserted by the built-in [`Server`]'s TLS accept loop and by +/// `connectrpc::axum::serve_tls` when the [`rustls::ServerConfig`] requests +/// client authentication and the peer presents a valid chain. Absent on +/// plaintext connections or when the client presents no certificate. +/// Handlers read it via `ctx.extensions.get::()`. /// /// The `Arc` makes per-request insertion cheap: all requests on a /// connection share one chain, so this is a refcount bump, not a copy. @@ -681,7 +682,7 @@ fn panic_handler(err: Box) -> Response> { /// - `EMFILE` / `ENFILE`: Too many open files (file descriptor exhaustion) /// - `ECONNABORTED`: Connection was aborted before accept completed /// - `EINTR`: Interrupted system call -fn is_transient_accept_error(err: &std::io::Error) -> bool { +pub(crate) fn is_transient_accept_error(err: &std::io::Error) -> bool { use std::io::ErrorKind; matches!( diff --git a/docs/guide.md b/docs/guide.md index 3c0fda5..4dc94c4 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -676,8 +676,23 @@ Server::new(connect_router) .await?; ``` -For the axum path, wrap the listener with `tokio_rustls::TlsAcceptor` -yourself (this is what the eliza example does). +For the axum path, `connectrpc::axum::serve_tls` (requires both the +`axum` and `server-tls` features) is a drop-in replacement for +`axum::serve` that owns the rustls accept loop and stamps `PeerAddr` / +`PeerCerts` into request extensions exactly as the standalone `Server` +does, so handler code that reads `ctx.extensions.get::()` +is portable across both hosting paths: + +```rust +let app = axum::Router::new() + .route("/health", axum::routing::get(|| async { "OK" })) + .fallback_service(connect_router.into_axum_service()); + +let listener = tokio::net::TcpListener::bind("0.0.0.0:8443").await?; +connectrpc::axum::serve_tls(listener, app, server_config) + .with_graceful_shutdown(shutdown_signal) + .await?; +``` The eliza example ([`examples/eliza/README.md`](../examples/eliza/README.md)) walks From 1934affa8d1047b888e79780b91cc0d9d9741580 Mon Sep 17 00:00:00 2001 From: Iain McGinniss <309153+iainmcgin@users.noreply.github.com> Date: Tue, 5 May 2026 20:55:54 +0000 Subject: [PATCH 2/3] examples: add mtls-identity (cert-SAN identity behind serve_tls) The mTLS twin of examples/middleware: same secret-store-with-ACL shape, but identity comes from the verified client certificate instead of a Bearer token. Hosted on axum behind connectrpc::axum::serve_tls; the handler reads PeerCerts from request extensions, parses the leaf cert's DNS SAN with x509-parser, and enforces an ACL keyed on the cert-derived workload name. Single self-contained binary: in-memory rcgen PKI (CA, server leaf, two workload client leafs), serve, call with both identities, graceful shutdown. No PEM files touch disk. The shared lib.rs exposes the proto, handler, identity extraction, and PKI helpers so tests/e2e.rs reuses them rather than duplicating cert generation. x509-parser is already a transitive dep via rcgen, so no new crates land in the lockfile. Closes #49. --- Cargo.toml | 2 +- docs/guide.md | 7 +- examples/mtls-identity/Cargo.toml | 50 +++ examples/mtls-identity/README.md | 109 +++++ examples/mtls-identity/build.rs | 8 + .../mtls_identity/v1/identity.proto | 35 ++ examples/mtls-identity/src/lib.rs | 372 ++++++++++++++++++ examples/mtls-identity/src/main.rs | 99 +++++ examples/mtls-identity/tests/e2e.rs | 138 +++++++ 9 files changed, 818 insertions(+), 2 deletions(-) create mode 100644 examples/mtls-identity/Cargo.toml create mode 100644 examples/mtls-identity/README.md create mode 100644 examples/mtls-identity/build.rs create mode 100644 examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto create mode 100644 examples/mtls-identity/src/lib.rs create mode 100644 examples/mtls-identity/src/main.rs create mode 100644 examples/mtls-identity/tests/e2e.rs diff --git a/Cargo.toml b/Cargo.toml index bda9b33..febc8dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["connectrpc", "connectrpc-codegen", "connectrpc-build", "conformance", "examples/eliza", "examples/middleware", "examples/multiservice", "examples/streaming-tour", "examples/wasm-client", "tests/streaming", "benches/rpc", "benches/rpc-tonic"] +members = ["connectrpc", "connectrpc-codegen", "connectrpc-build", "conformance", "examples/eliza", "examples/middleware", "examples/mtls-identity", "examples/multiservice", "examples/streaming-tour", "examples/wasm-client", "tests/streaming", "benches/rpc", "benches/rpc-tonic"] resolver = "2" [workspace.package] diff --git a/docs/guide.md b/docs/guide.md index 4dc94c4..4ca4bee 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -698,7 +698,11 @@ The eliza example ([`examples/eliza/README.md`](../examples/eliza/README.md)) walks through generating self-signed certificates with openssl, configuring mTLS via `--client-ca`, and the rustls strict-PKI requirement that -your CA cert must be distinct from the server leaf cert. +your CA cert must be distinct from the server leaf cert. The +mtls-identity example +([`examples/mtls-identity/README.md`](../examples/mtls-identity/README.md)) +demonstrates `serve_tls` end-to-end with cert-SAN identity extraction +and an ACL keyed on it. ## Clients @@ -920,6 +924,7 @@ let service = ConnectRpcService::new(router).with_compression(registry); |---|---| | [`streaming-tour/`](../examples/streaming-tour) | All four RPC types (unary, server stream, client stream, bidi) on a trivial NumberService. Smallest demo of handler signatures and client invocation patterns. | | [`middleware/`](../examples/middleware) | Server-side tower middleware composition: an `axum::middleware::from_fn` bearer-token auth, identity passthrough via `RequestContext::extensions`, response trailers via `Response::with_trailer`. Client demos `ClientConfig::default_header` and `CallOptions::with_timeout`. | +| [`mtls-identity/`](../examples/mtls-identity) | mTLS twin of `middleware/`: axum hosted behind `connectrpc::axum::serve_tls`, identity from the client cert's DNS SAN via `PeerCerts` instead of a bearer token, ACL keyed on the cert-derived identity. In-memory `rcgen` PKI; no PEM files. | | [`eliza/`](../examples/eliza) | Production-shaped streaming app: a port of the `connectrpc/examples-go` ELIZA demo. Server-streaming Introduce + bidi-streaming Converse, TLS, mTLS, CORS, IPv6, both server and client binaries, interoperates with the hosted Go reference at `demo.connectrpc.com`. | | [`multiservice/`](../examples/multiservice) | Multiple proto packages compiled together with `buf generate`, multiple services on one server, well-known type usage. | | [`wasm-client/`](../examples/wasm-client) | Browser fetch transport: same generated client used from `wasm32-unknown-unknown` with a custom `ClientTransport` backed by `web-sys::fetch`. | diff --git a/examples/mtls-identity/Cargo.toml b/examples/mtls-identity/Cargo.toml new file mode 100644 index 0000000..07c6f6d --- /dev/null +++ b/examples/mtls-identity/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "mtls-identity-example" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[lib] +name = "mtls_identity_example" +path = "src/lib.rs" + +[[bin]] +name = "mtls-identity" +path = "src/main.rs" + +[dependencies] +connectrpc = { path = "../../connectrpc", features = ["axum", "client", "client-tls", "server-tls"] } + +# Protobuf +buffa = { workspace = true } +buffa-types = { workspace = true } + +# Serialization (for generated message types) +serde = { workspace = true } +serde_json = { workspace = true } + +# HTTP types (for generated client bounds) +http-body = { workspace = true } +http = { workspace = true } + +# Async +tokio = { workspace = true, features = ["full"] } +futures = { workspace = true } + +# Web framework +axum = { workspace = true } + +# TLS (test PKI generated at runtime; no PEM files in the repo) +rcgen = { workspace = true } +rustls-pki-types = { workspace = true } +# Parse the leaf cert's SubjectAlternativeName in the handler. Already in +# the lockfile transitively via rcgen, so this adds nothing to the build. +x509-parser = "0.18" + +[build-dependencies] +connectrpc-build = { path = "../../connectrpc-build" } + +[lints] +workspace = true diff --git a/examples/mtls-identity/README.md b/examples/mtls-identity/README.md new file mode 100644 index 0000000..5b2e7d0 --- /dev/null +++ b/examples/mtls-identity/README.md @@ -0,0 +1,109 @@ +# mTLS identity example + +Demonstrates cert-SAN-based identity for an axum-hosted ConnectRPC +service. The server is hosted on axum behind +`connectrpc::axum::serve_tls`, which terminates TLS, captures the +verified client certificate chain and remote address, and stamps them +into request extensions as `PeerCerts` / `PeerAddr` — the same +convention the standalone `connectrpc::Server::with_tls` uses. The +handler parses the leaf cert's DNS SAN to derive a workload identity +and enforces an ACL against it. + +This is the mTLS twin of [`examples/middleware/`](../middleware): same +secret-store-with-ACL shape, but the credential is a client certificate +instead of a `Bearer` token. The handler-side code that reads +`ctx.extensions.get::()` is unchanged; only what the layer/accept +loop puts into extensions differs. + +## Run it + +The demo is a single self-contained binary: it generates an in-memory +PKI (CA, server cert, two workload client certs) with `rcgen`, starts +the server, makes a few calls with each identity, and shuts down. No +PEM files touch disk. + +```bash +cargo run -p mtls-identity-example +``` + +Expected output (port and source ports vary): + +``` +IdentityService listening on https://127.0.0.1:PORT (mTLS required) + +[alice] WhoAmI -> identity="alice" san="alice.workloads.example.com" from="127.0.0.1:..." +[bob] WhoAmI -> identity="bob" san="bob.workloads.example.com" from="127.0.0.1:..." + +[alice] GetSecret( shared) -> "the value of teamwork" (x-served-by: alice) +[alice] GetSecret(alice-only) -> "alice's diary entry" (x-served-by: alice) +[bob] GetSecret( shared) -> "the value of teamwork" (x-served-by: bob) +[bob] GetSecret(alice-only) -> permission_denied: workload "bob" (bob.workloads.example.com) cannot read "alice-only" +``` + +## What to look at + +### `serve_tls` instead of `axum::serve` (`src/lib.rs::serve`) + +`axum::serve` accepts a plain `TcpListener` with no hook for +terminating TLS, so an axum + mTLS deployment normally has to write a +rustls accept loop by hand. `connectrpc::axum::serve_tls` is a drop-in +replacement that owns that loop and stamps `PeerAddr` / `PeerCerts` +into request extensions: + +```rust +let app = axum::Router::new().fallback_service(connect_router.into_axum_service()); +connectrpc::axum::serve_tls(listener, app, server_config) + .with_graceful_shutdown(shutdown) + .await?; +``` + +Handler code that reads `ctx.extensions.get::()` is then +portable between the standalone `Server::with_tls` and an axum app. + +### Cert-SAN identity (`src/lib.rs::extract_identity`) + +The handler reads the leaf cert from `PeerCerts`, parses its DNS SAN +with `x509-parser`, and derives a short workload name from a SAN under +`workloads.example.com`. A real deployment would typically match a +SPIFFE ID (a URI SAN) instead, or hand the whole step to an +authorization framework — the shape is the same: read `PeerCerts`, +parse the leaf, derive an identity. + +Two failure modes both surface as `Unauthenticated`: + +- No client cert presented: only reachable if the server's + `ClientCertVerifier` made client auth optional. This example uses + `WebPkiClientVerifier`, which *requires* a verified chain, so this + path is dead in practice — kept as defense in depth. +- A cert is presented but no SAN matches the workload domain. + +### In-memory PKI (`src/lib.rs::pki`) + +`pki::generate(&["alice", "bob"])` builds a CA, a server leaf +(`SAN = localhost`), and one client leaf per workload +(`SAN = .workloads.example.com`), all in memory via `rcgen`. A +deployment would load these from a secret store; the rustls types are +identical. + +The server config requires *and verifies* client certs against the +demo CA, so the chain that reaches the handler is always verified — +the SAN parsing only has to decide *which* trusted client this is, not +whether to trust it. + +## Integration test + +`tests/e2e.rs` exercises four paths: identity reflection (`WhoAmI`), +authorized read with response trailer, permission denied (bob reading +alice's secret), and a TLS client without a cert being rejected at the +handshake before the request reaches HTTP. + +```bash +cargo test -p mtls-identity-example +``` + +## Where to go next + +- See [`examples/middleware`](../middleware) for the bearer-token + equivalent of this example. +- See [`examples/eliza`](../eliza) for loading certs from PEM files + with `--cert`/`--key`/`--client-ca` CLI flags. diff --git a/examples/mtls-identity/build.rs b/examples/mtls-identity/build.rs new file mode 100644 index 0000000..cd58f25 --- /dev/null +++ b/examples/mtls-identity/build.rs @@ -0,0 +1,8 @@ +fn main() { + connectrpc_build::Config::new() + .files(&["proto/anthropic/connectrpc/mtls_identity/v1/identity.proto"]) + .includes(&["proto/"]) + .include_file("_connectrpc.rs") + .compile() + .unwrap(); +} diff --git a/examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto b/examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto new file mode 100644 index 0000000..b7f70c2 --- /dev/null +++ b/examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto @@ -0,0 +1,35 @@ +edition = "2023"; + +package anthropic.connectrpc.mtls_identity.v1; + +// IdentityService demonstrates mTLS cert-SAN-based identity. The server +// is hosted on axum behind connectrpc::axum::serve_tls, which terminates +// TLS and stamps PeerAddr/PeerCerts into request extensions. The handler +// parses the leaf cert's DNS SAN to derive a workload identity rather +// than reading a bearer token, then enforces an ACL against it. +service IdentityService { + // Echo back what the server knows about the caller from the TLS layer. + rpc WhoAmI(WhoAmIRequest) returns (WhoAmIResponse); + + // Read a secret, gated on the caller's cert-derived identity. + rpc GetSecret(GetSecretRequest) returns (GetSecretResponse); +} + +message WhoAmIRequest {} + +message WhoAmIResponse { + // Workload identity parsed from the leaf cert's DNS SAN, e.g. "alice". + string identity = 1; + // Full DNS SAN as presented, e.g. "alice.workloads.example.com". + string san = 2; + // Caller's remote socket address as observed by the accept loop. + string remote_addr = 3; +} + +message GetSecretRequest { + string name = 1; +} + +message GetSecretResponse { + string value = 1; +} diff --git a/examples/mtls-identity/src/lib.rs b/examples/mtls-identity/src/lib.rs new file mode 100644 index 0000000..c333cd3 --- /dev/null +++ b/examples/mtls-identity/src/lib.rs @@ -0,0 +1,372 @@ +//! mTLS cert-SAN identity for axum-hosted ConnectRPC services. +//! +//! Mirrors `examples/middleware/`, swapping bearer-token auth for mTLS: +//! instead of an `axum::middleware::from_fn` reading an `Authorization` +//! header, identity comes from the verified client certificate that +//! `connectrpc::axum::serve_tls` captures during the TLS handshake and +//! stamps into request extensions as [`PeerCerts`]. The handler parses +//! the leaf cert's DNS SAN to derive a workload identity, then enforces +//! an ACL against it. +//! +//! The same handler code works unchanged on the standalone +//! [`connectrpc::Server::with_tls`], which populates [`PeerCerts`] the +//! same way — the hosting choice doesn't leak into authorization logic. + +use std::collections::HashMap; +use std::sync::Arc; + +use buffa::view::OwnedView; +use connectrpc::{ + ConnectError, ErrorCode, PeerAddr, PeerCerts, RequestContext, Router, ServiceResult, +}; + +pub mod proto { + connectrpc::include_generated!(); +} + +pub use proto::anthropic::connectrpc::mtls_identity::v1::*; + +pub type BoxError = Box; + +// ============================================================================ +// Identity: derive a workload name from the leaf cert's DNS SAN. +// ============================================================================ + +/// All clients in this demo carry a SAN under this domain. Anything else +/// is rejected as `Unauthenticated`. +pub const WORKLOAD_DOMAIN: &str = "workloads.example.com"; + +/// Caller identity, parsed from the leaf certificate's DNS SAN. +#[derive(Debug, Clone)] +pub struct Identity { + /// Short workload name, e.g. `"alice"`. + pub name: String, + /// Full DNS SAN as presented, e.g. `"alice.workloads.example.com"`. + pub san: String, +} + +/// Parse a workload identity out of the leaf certificate's DNS SANs. +/// +/// Returns the first SAN under [`WORKLOAD_DOMAIN`]. Returns +/// `Unauthenticated` when no client cert was presented (a non-mTLS +/// connection) or no SAN matches the expected domain. +/// +/// In a real deployment you'd typically match a SPIFFE ID +/// (`spiffe://trust-domain/path`, a URI SAN) instead of a DNS SAN, or +/// delegate this whole step to an authorization framework. The shape is +/// the same: read [`PeerCerts`], parse the leaf, derive an identity. +pub fn extract_identity(certs: Option<&PeerCerts>) -> Result { + use x509_parser::extensions::GeneralName; + use x509_parser::prelude::{FromDer, X509Certificate}; + + let leaf = certs.and_then(|c| c.0.first()).ok_or_else(|| { + ConnectError::new(ErrorCode::Unauthenticated, "client certificate required") + })?; + + let (_, parsed) = X509Certificate::from_der(leaf.as_ref()).map_err(|e| { + ConnectError::new(ErrorCode::Unauthenticated, format!("bad client cert: {e}")) + })?; + + let suffix = format!(".{WORKLOAD_DOMAIN}"); + parsed + .subject_alternative_name() + .ok() + .flatten() + .into_iter() + .flat_map(|ext| ext.value.general_names.iter()) + .find_map(|gn| { + // Only `.workloads.example.com` is a workload SAN. + // Reject `.workloads.example.com` (empty label) and + // `a.b.workloads.example.com` (would alias as `a.b`), which an + // attacker-controlled CA could otherwise mint to spoof an ACL + // entry. + let GeneralName::DNSName(dns) = gn else { + return None; + }; + let prefix = dns.strip_suffix(&suffix)?; + if prefix.is_empty() || prefix.contains('.') { + return None; + } + Some(Identity { + name: prefix.to_owned(), + san: (*dns).to_owned(), + }) + }) + .ok_or_else(|| { + ConnectError::new( + ErrorCode::Unauthenticated, + format!("client cert has no workload SAN under {WORKLOAD_DOMAIN}"), + ) + }) +} + +// ============================================================================ +// IdentityService handler +// ============================================================================ + +/// Static secret store. Each secret declares which workloads may read it. +pub fn secret_store() -> HashMap)> { + HashMap::from([ + ( + "shared".into(), + ("the value of teamwork".into(), vec!["alice", "bob"]), + ), + ( + "alice-only".into(), + ("alice's diary entry".into(), vec!["alice"]), + ), + ]) +} + +pub struct IdentityServiceImpl { + pub store: HashMap)>, +} + +impl IdentityService for IdentityServiceImpl { + async fn who_am_i( + &self, + ctx: RequestContext, + _request: OwnedView>, + ) -> ServiceResult { + // Both PeerCerts and PeerAddr are stamped per connection by + // serve_tls; the dispatcher copies request extensions verbatim + // into ctx.extensions. + let id = extract_identity(ctx.extensions.get::())?; + let remote = ctx + .extensions + .get::() + .map(|a| a.0.to_string()) + .unwrap_or_default(); + connectrpc::Response::ok(WhoAmIResponse { + identity: Some(id.name), + san: Some(id.san), + remote_addr: Some(remote), + ..Default::default() + }) + } + + async fn get_secret( + &self, + ctx: RequestContext, + request: OwnedView>, + ) -> ServiceResult { + let id = extract_identity(ctx.extensions.get::())?; + let name = request.name.unwrap_or("").to_owned(); + let (value, allowed) = self.store.get(&name).ok_or_else(|| { + ConnectError::new(ErrorCode::NotFound, format!("no secret named {name:?}")) + })?; + if !allowed.iter().any(|w| *w == id.name) { + return Err(ConnectError::new( + ErrorCode::PermissionDenied, + format!("workload {:?} ({}) cannot read {name:?}", id.name, id.san), + )); + } + Ok(connectrpc::Response::new(GetSecretResponse { + value: Some(value.clone()), + ..Default::default() + }) + .with_trailer("x-served-by", id.name)) + } +} + +// ============================================================================ +// Server hosting: axum app behind connectrpc::axum::serve_tls. +// ============================================================================ + +/// Build the axum app and serve it over TLS until `shutdown` resolves. +/// +/// This is the only line that differs from a plaintext axum app: +/// `connectrpc::axum::serve_tls` instead of `axum::serve`. +pub async fn serve( + listener: tokio::net::TcpListener, + server_config: Arc, + shutdown: impl std::future::Future + Send + 'static, +) -> std::io::Result<()> { + let svc = Arc::new(IdentityServiceImpl { + store: secret_store(), + }); + let connect_router = svc.register(Router::new()); + let app = axum::Router::new().fallback_service(connect_router.into_axum_service()); + + connectrpc::axum::serve_tls(listener, app, server_config) + .with_graceful_shutdown(shutdown) + .await +} + +// ============================================================================ +// In-memory PKI: one CA, one server leaf, N client leafs (DNS-SAN identities). +// ============================================================================ + +pub mod pki { + use std::sync::Arc; + + use connectrpc::rustls; + use rcgen::{BasicConstraints, CertificateParams, CertifiedIssuer, IsCa, KeyPair, SanType}; + use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + + /// One client credential: a leaf cert (with a SAN under the workload + /// domain) and its private key. + pub struct ClientCredential { + pub cert: CertificateDer<'static>, + pub key: PrivateKeyDer<'static>, + } + + /// In-memory PKI for the demo. + pub struct Pki { + /// Server [`rustls::ServerConfig`] requiring client certs signed + /// by the demo CA. + pub server_config: Arc, + /// Trust roots containing only the demo CA, for building client + /// configs. + pub roots: Arc, + /// Per-workload client credentials, keyed by short name. + pub clients: std::collections::HashMap, + } + + impl Pki { + /// Build a [`rustls::ClientConfig`] that trusts the demo CA and + /// presents the named workload's client cert during the handshake. + /// + /// # Panics + /// + /// Panics if `workload` isn't one of the names passed to [`generate`]. + pub fn client_config(&self, workload: &str) -> Arc { + let cred = self + .clients + .get(workload) + .unwrap_or_else(|| panic!("no credential for workload {workload:?}")); + Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(Arc::clone(&self.roots)) + .with_client_auth_cert(vec![cred.cert.clone()], cred.key.clone_key()) + .expect("valid client cert"), + ) + } + } + + /// Generate a fresh CA, server leaf (`SAN = localhost`), and one client + /// leaf per name in `workloads` (`SAN = .workloads.example.com`). + /// + /// No PEM files touch disk: this is the same shape a deployment would + /// load from a secret store, but generated in-process for the demo. + pub fn generate(workloads: &[&str]) -> Pki { + // Idempotent; err == already installed (tests share process state). + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + + let ca_key = KeyPair::generate().expect("generate CA key"); + let mut ca_params = CertificateParams::default(); + ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + let ca = CertifiedIssuer::self_signed(ca_params, ca_key).expect("self-sign CA"); + + // Issue a leaf with the given DNS SANs, signed by the demo CA. + let issue = |sans: &[&str]| -> (CertificateDer<'static>, PrivateKeyDer<'static>) { + let key = KeyPair::generate().expect("generate leaf key"); + let mut params = CertificateParams::default(); + params.subject_alt_names = sans + .iter() + .map(|s| SanType::DnsName((*s).try_into().expect("valid DNS SAN"))) + .collect(); + let cert = params.signed_by(&key, &ca).expect("sign leaf"); + ( + CertificateDer::from(cert.der().to_vec()), + PrivatePkcs8KeyDer::from(key.serialized_der().to_vec()).into(), + ) + }; + + let (server_cert, server_key) = issue(&["localhost"]); + let clients = workloads + .iter() + .map(|name| { + let san = format!("{name}.{}", super::WORKLOAD_DOMAIN); + let (cert, key) = issue(&[&san]); + (name.to_string(), ClientCredential { cert, key }) + }) + .collect(); + + let mut roots = rustls::RootCertStore::empty(); + roots + .add(CertificateDer::from(ca.der().to_vec())) + .expect("add CA to roots"); + let roots = Arc::new(roots); + + // Require *and verify* client certs. WebPkiClientVerifier rejects + // anything not chained to the demo CA before the request reaches + // the handler, so PeerCerts is always a verified chain. + let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots)) + .build() + .expect("build client verifier"); + let server_config = Arc::new( + rustls::ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(vec![server_cert], server_key) + .expect("valid server cert"), + ); + + Pki { + server_config, + roots, + clients, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, KeyPair, SanType}; + use rustls_pki_types::CertificateDer; + + /// Self-signed leaf with arbitrary DNS SANs, wrapped as PeerCerts. + fn peer_certs_with_dns_sans(sans: &[&str]) -> PeerCerts { + let key = KeyPair::generate().unwrap(); + let mut params = CertificateParams::default(); + params.subject_alt_names = sans + .iter() + .map(|s| SanType::DnsName((*s).try_into().unwrap())) + .collect(); + let cert = params.self_signed(&key).unwrap(); + PeerCerts(Arc::from(vec![CertificateDer::from(cert.der().to_vec())])) + } + + #[test] + fn extract_identity_rejects_no_cert() { + assert_eq!( + extract_identity(None).unwrap_err().code, + ErrorCode::Unauthenticated + ); + } + + #[test] + fn extract_identity_parses_single_label_workload_san() { + let certs = + peer_certs_with_dns_sans(&["ignored.example.org", "alice.workloads.example.com"]); + let id = extract_identity(Some(&certs)).unwrap(); + assert_eq!(id.name, "alice"); + assert_eq!(id.san, "alice.workloads.example.com"); + } + + #[test] + fn extract_identity_rejects_empty_or_multi_label_prefix() { + // Empty prefix: ".workloads.example.com" must not yield name = "". + let empty = peer_certs_with_dns_sans(&[".workloads.example.com"]); + assert_eq!( + extract_identity(Some(&empty)).unwrap_err().code, + ErrorCode::Unauthenticated + ); + // Multi-label prefix: "a.b" would alias an ACL entry; reject it. + let multi = peer_certs_with_dns_sans(&["a.b.workloads.example.com"]); + assert_eq!( + extract_identity(Some(&multi)).unwrap_err().code, + ErrorCode::Unauthenticated + ); + } + + #[test] + fn extract_identity_rejects_unrelated_domain() { + let certs = peer_certs_with_dns_sans(&["service.elsewhere.example.org"]); + assert_eq!( + extract_identity(Some(&certs)).unwrap_err().code, + ErrorCode::Unauthenticated + ); + } +} diff --git a/examples/mtls-identity/src/main.rs b/examples/mtls-identity/src/main.rs new file mode 100644 index 0000000..a7feaec --- /dev/null +++ b/examples/mtls-identity/src/main.rs @@ -0,0 +1,99 @@ +//! Self-contained mTLS identity demo. +//! +//! Generates an in-memory PKI, serves IdentityService on axum behind +//! `connectrpc::axum::serve_tls`, then calls it with two client +//! certificates to show the cert-SAN identity flow: +//! +//! - `alice` reads `shared` and `alice-only` successfully +//! - `bob` reads `shared` but is denied `alice-only` +//! +//! Run with: +//! +//! ```sh +//! cargo run -p mtls-identity-example +//! ``` + +use std::sync::Arc; + +use connectrpc::ErrorCode; +use connectrpc::client::{ClientConfig, HttpClient}; +use mtls_identity_example::{ + BoxError, GetSecretRequest, IdentityServiceClient, WhoAmIRequest, pki, serve, +}; + +#[tokio::main] +async fn main() -> Result<(), BoxError> { + // 1. PKI: a fresh CA, server cert, and two workload client certs. + // No PEM files on disk — see `pki::generate`. + let pki = Arc::new(pki::generate(&["alice", "bob"])); + + // 2. Server: axum app behind connectrpc::axum::serve_tls. + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let server = tokio::spawn(serve(listener, Arc::clone(&pki.server_config), async { + shutdown_rx.await.ok(); + })); + println!("IdentityService listening on https://{addr} (mTLS required)\n"); + + // 3. Clients: one HttpClient per workload, each presenting its own + // client cert. The handler derives the identity from the cert SAN — + // the request bodies carry no credentials at all. + let client_for = |workload: &str| { + let http = HttpClient::with_tls(pki.client_config(workload)); + // The server cert's SAN is "localhost"; rustls verifies hostname. + let cfg = ClientConfig::new( + format!("https://localhost:{}", addr.port()) + .parse() + .unwrap(), + ); + IdentityServiceClient::new(http, cfg) + }; + let alice = client_for("alice"); + let bob = client_for("bob"); + + // WhoAmI: identity comes purely from the TLS layer. + for (label, client) in [("alice", &alice), ("bob", &bob)] { + let resp = client.who_am_i(WhoAmIRequest::default()).await?; + let v = resp.view(); + println!( + "[{label}] WhoAmI -> identity={:?} san={:?} from={:?}", + v.identity.unwrap_or(""), + v.san.unwrap_or(""), + v.remote_addr.unwrap_or(""), + ); + } + println!(); + + // GetSecret: the ACL is keyed on the cert-derived identity. + for (label, client) in [("alice", &alice), ("bob", &bob)] { + for name in ["shared", "alice-only"] { + let req = GetSecretRequest { + name: Some(name.into()), + ..Default::default() + }; + match client.get_secret(req).await { + Ok(resp) => { + let served_by = resp + .trailers() + .get("x-served-by") + .and_then(|v| v.to_str().ok()) + .unwrap_or("?"); + println!( + "[{label}] GetSecret({name:>10}) -> {:?} (x-served-by: {served_by})", + resp.view().value.unwrap_or("") + ); + } + Err(err) => { + debug_assert_eq!(err.code, ErrorCode::PermissionDenied); + println!("[{label}] GetSecret({name:>10}) -> {err}"); + } + } + } + } + + // 4. Graceful shutdown: stop accepting and drain in-flight connections. + shutdown_tx.send(()).ok(); + server.await??; + Ok(()) +} diff --git a/examples/mtls-identity/tests/e2e.rs b/examples/mtls-identity/tests/e2e.rs new file mode 100644 index 0000000..f905bf4 --- /dev/null +++ b/examples/mtls-identity/tests/e2e.rs @@ -0,0 +1,138 @@ +//! End-to-end test: spin up IdentityService behind +//! `connectrpc::axum::serve_tls`, make calls with two client identities, +//! and check that the cert-SAN identity flows through `PeerCerts` to +//! the handler and gates the ACL correctly. + +use std::sync::Arc; + +use connectrpc::ErrorCode; +use connectrpc::client::{ClientConfig, HttpClient}; +use mtls_identity_example::{GetSecretRequest, IdentityServiceClient, WhoAmIRequest, pki, serve}; + +struct Harness { + pki: Arc, + addr: std::net::SocketAddr, + shutdown: tokio::sync::oneshot::Sender<()>, + serve_task: tokio::task::JoinHandle>, +} + +impl Harness { + async fn start(workloads: &[&str]) -> Self { + let pki = Arc::new(pki::generate(workloads)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve_task = tokio::spawn(serve(listener, Arc::clone(&pki.server_config), async { + rx.await.ok(); + })); + Harness { + pki, + addr, + shutdown: tx, + serve_task, + } + } + + fn client(&self, workload: &str) -> IdentityServiceClient { + let http = HttpClient::with_tls(self.pki.client_config(workload)); + let cfg = ClientConfig::new( + format!("https://localhost:{}", self.addr.port()) + .parse() + .unwrap(), + ); + IdentityServiceClient::new(http, cfg) + } + + async fn shutdown(self) { + self.shutdown.send(()).ok(); + tokio::time::timeout(std::time::Duration::from_secs(5), self.serve_task) + .await + .expect("server should shut down within timeout") + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn whoami_reflects_cert_san_and_remote_addr() { + let h = Harness::start(&["alice"]).await; + let resp = h + .client("alice") + .who_am_i(WhoAmIRequest::default()) + .await + .unwrap(); + let v = resp.view(); + assert_eq!(v.identity, Some("alice")); + assert_eq!(v.san, Some("alice.workloads.example.com")); + // PeerAddr should be the actual TCP source address. + let remote = v.remote_addr.unwrap(); + assert!(remote.starts_with("127.0.0.1:"), "remote_addr={remote}"); + h.shutdown().await; +} + +#[tokio::test] +async fn authorized_call_returns_value_and_trailer() { + let h = Harness::start(&["alice", "bob"]).await; + let resp = h + .client("alice") + .get_secret(GetSecretRequest { + name: Some("shared".into()), + ..Default::default() + }) + .await + .unwrap(); + assert_eq!(resp.view().value, Some("the value of teamwork")); + assert_eq!( + resp.trailers() + .get("x-served-by") + .unwrap() + .to_str() + .unwrap(), + "alice" + ); + h.shutdown().await; +} + +#[tokio::test] +async fn permission_denied_for_other_workloads_secret() { + let h = Harness::start(&["alice", "bob"]).await; + let err = h + .client("bob") + .get_secret(GetSecretRequest { + name: Some("alice-only".into()), + ..Default::default() + }) + .await + .expect_err("bob cannot read alice's secret"); + assert_eq!(err.code, ErrorCode::PermissionDenied); + h.shutdown().await; +} + +#[tokio::test] +async fn client_without_cert_is_rejected_at_handshake() { + // WebPkiClientVerifier requires a client cert; a TLS client that + // presents none never reaches the handler. Distinct from the + // `extract_identity(None)` branch, which only fires for hosting + // setups that make client auth optional. + let h = Harness::start(&["alice"]).await; + let no_cert_cfg = Arc::new( + connectrpc::rustls::ClientConfig::builder() + .with_root_certificates(Arc::clone(&h.pki.roots)) + .with_no_client_auth(), + ); + let http = HttpClient::with_tls(no_cert_cfg); + let cfg = ClientConfig::new( + format!("https://localhost:{}", h.addr.port()) + .parse() + .unwrap(), + ); + let client = IdentityServiceClient::new(http, cfg); + let err = client + .who_am_i(WhoAmIRequest::default()) + .await + .expect_err("server must reject a client with no cert"); + // The handshake failure surfaces as a connection-level Unavailable, + // not a Connect-protocol error: the request never reaches HTTP. + assert_eq!(err.code, ErrorCode::Unavailable, "got: {err}"); + h.shutdown().await; +} From 516882a3f37010d64b592bf9767cd3117eef3198 Mon Sep 17 00:00:00 2001 From: Iain McGinniss <309153+iainmcgin@users.noreply.github.com> Date: Thu, 7 May 2026 16:39:19 +0000 Subject: [PATCH 3/3] examples(mtls-identity): tighten extract_identity comment State the intent (only direct subdomains of the workload domain are accepted) rather than describing one specific spoofing scenario, which understates what an attacker-controlled CA could do. --- examples/mtls-identity/src/lib.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/mtls-identity/src/lib.rs b/examples/mtls-identity/src/lib.rs index c333cd3..7cbebd1 100644 --- a/examples/mtls-identity/src/lib.rs +++ b/examples/mtls-identity/src/lib.rs @@ -76,10 +76,9 @@ pub fn extract_identity(certs: Option<&PeerCerts>) -> Result.workloads.example.com` is a workload SAN. - // Reject `.workloads.example.com` (empty label) and - // `a.b.workloads.example.com` (would alias as `a.b`), which an - // attacker-controlled CA could otherwise mint to spoof an ACL - // entry. + // Reject `.workloads.example.com` and `a.b.workloads.example.com`: + // we intend to accept only direct subdomains of the workload + // domain. let GeneralName::DNSName(dns) = gn else { return None; }; @@ -353,7 +352,7 @@ mod tests { extract_identity(Some(&empty)).unwrap_err().code, ErrorCode::Unauthenticated ); - // Multi-label prefix: "a.b" would alias an ACL entry; reject it. + // Multi-label prefix: not a direct subdomain; reject it. let multi = peer_certs_with_dns_sans(&["a.b.workloads.example.com"]); assert_eq!( extract_identity(Some(&multi)).unwrap_err().code,