diff --git a/Cargo.toml b/Cargo.toml index bda9b33..febc8dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["connectrpc", "connectrpc-codegen", "connectrpc-build", "conformance", "examples/eliza", "examples/middleware", "examples/multiservice", "examples/streaming-tour", "examples/wasm-client", "tests/streaming", "benches/rpc", "benches/rpc-tonic"] +members = ["connectrpc", "connectrpc-codegen", "connectrpc-build", "conformance", "examples/eliza", "examples/middleware", "examples/mtls-identity", "examples/multiservice", "examples/streaming-tour", "examples/wasm-client", "tests/streaming", "benches/rpc", "benches/rpc-tonic"] resolver = "2" [workspace.package] diff --git a/connectrpc/Cargo.toml b/connectrpc/Cargo.toml index 35bff22..c084756 100644 --- a/connectrpc/Cargo.toml +++ b/connectrpc/Cargo.toml @@ -102,6 +102,10 @@ server = [ "dep:libc", "dep:tower-http", "tokio/net", + # The accept loop uses `tokio::select!`. Surfaced only when downstream + # consumers enable `server` without also enabling `tokio/macros` + # themselves; `--all-targets` (CI) hides it via dev-dep unification. + "tokio/macros", ] # HTTP client with connection pooling diff --git a/connectrpc/src/axum.rs b/connectrpc/src/axum.rs new file mode 100644 index 0000000..ea040d5 --- /dev/null +++ b/connectrpc/src/axum.rs @@ -0,0 +1,599 @@ +//! TLS-aware `axum::serve` counterpart that exposes peer identity to handlers. +//! +//! [`Router::into_axum_service`](crate::Router::into_axum_service) and +//! [`Router::into_axum_router`](crate::Router::into_axum_router) cover the +//! plaintext path: mount your ConnectRPC routes on an `axum::Router` and +//! hand the result to `axum::serve`. This module fills the TLS gap. +//! +//! `axum::serve` accepts a plain [`TcpListener`] and has no hook for +//! terminating TLS. The standalone [`Server`](crate::Server), by contrast, +//! owns the rustls accept loop and so can capture [`PeerAddr`]/[`PeerCerts`] +//! once per connection and inject them into every request's extensions for +//! handlers to read via `ctx.extensions.get::()`. Without help, an axum + +//! mTLS deployment has to reimplement that accept loop and per-connection +//! plumbing by hand for handlers to get the same view. +//! +//! [`serve_tls`] is that help: it serves an `axum::Router`, terminates TLS, +//! captures peer identity, and stamps it into request extensions. Handler +//! code that reads `ctx.extensions.get::()` is then portable +//! between the standalone `Server` and an axum app — the hosting choice no +//! longer leaks into your authorization logic. +//! +//! ```rust,ignore +//! // Plaintext: axum's built-in serve. +//! axum::serve(listener, app).await?; +//! +//! // TLS with PeerAddr/PeerCerts passthrough. +//! connectrpc::axum::serve_tls(listener, app, tls_config).await?; +//! ``` +//! +//! # Differences from `axum::serve` +//! +//! `serve_tls` is the TLS counterpart to `axum::serve(listener, router)` for +//! the common `axum::Router` case. It is intentionally less generic: +//! +//! - **Service type.** `serve_tls` accepts a concrete `axum::Router`, not +//! the make-service forms `axum::serve` is generic over. There is no +//! `into_make_service_with_connect_info::()` equivalent because +//! `serve_tls` already injects [`PeerAddr`] (the same socket address) into +//! request extensions; read that instead of `ConnectInfo`. +//! A `Router` with state must have `.with_state(...)` applied first. +//! - **`PeerCerts` is conditional.** It is only inserted when the +//! [`rustls::ServerConfig`] requests client authentication *and* the peer +//! presents a chain rustls verifies. With `with_no_client_auth()` (or a +//! permissive verifier and a client that sends no cert), only [`PeerAddr`] +//! is present. Handlers must treat `ctx.extensions.get::()` as +//! optional. +//! - **ALPN.** The TLS terminator speaks the protocol ALPN selects. To allow +//! HTTP/2 (required for gRPC; preferred for Connect streaming), set +//! `server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]` +//! before passing it in. Without ALPN, hyper falls back to HTTP/1.1. +//! - **No automatic panic catching.** Unlike the standalone +//! [`Server`](crate::Server), `serve_tls` does not wrap your `axum::Router` +//! in `tower_http::catch_panic::CatchPanicLayer` (`axum::serve` doesn't +//! either). If you want a panicking handler to surface as a Connect error +//! rather than a dropped connection, add the layer yourself. +//! +//! Available only with both the `axum` and `server-tls` features enabled. + +use std::future::{Future, IntoFuture}; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder as AutoBuilder; +use hyper_util::server::graceful::GracefulShutdown; +use tokio::net::TcpListener; +use tower::ServiceExt; + +use crate::server::{ + DEFAULT_TLS_HANDSHAKE_TIMEOUT, PeerAddr, PeerCerts, is_transient_accept_error, +}; + +/// Serve an `axum::Router` over TLS, exposing peer identity to handlers. +/// +/// The TLS counterpart to `axum::serve(listener, router)` for when handlers +/// need [`PeerAddr`] and [`PeerCerts`] in request extensions — the same +/// convention the standalone [`Server::with_tls`](crate::Server::with_tls) +/// uses. The accept loop terminates TLS with `tokio-rustls`, captures the +/// remote address and any verified client certificate chain, then injects +/// both into every request before handing off to the axum service. +/// [`PeerCerts`] is only present when `tls_config` requests client +/// authentication and the peer presented a chain rustls verified. +/// +/// Like the standalone server, the TLS handshake is bounded by a +/// [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`] to prevent slowloris-style connection +/// exhaustion; tune it with [`ServeTls::tls_handshake_timeout`]. +/// +/// The returned [`ServeTls`] resolves once the listener stops accepting and +/// in-flight connections drain (after [`ServeTls::with_graceful_shutdown`]'s +/// signal fires) or when a non-transient accept error occurs. +/// +/// See the [module docs](self) for the differences from `axum::serve`, +/// including ALPN setup and panic-handling expectations. +/// +/// ```rust,no_run +/// # use std::sync::Arc; +/// # async fn demo(connect_router: connectrpc::Router, tls_config: Arc, +/// # shutdown_signal: tokio::sync::oneshot::Receiver<()>) +/// # -> Result<(), Box> { +/// let app = axum::Router::new() +/// .route("/health", axum::routing::get(|| async { "OK" })) +/// .fallback_service(connect_router.into_axum_service()); +/// +/// let listener = tokio::net::TcpListener::bind("0.0.0.0:8443").await?; +/// connectrpc::axum::serve_tls(listener, app, tls_config) +/// .with_graceful_shutdown(async { shutdown_signal.await.ok(); }) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// # Errors +/// +/// The future resolves to `Err` only for non-transient I/O errors from the +/// underlying `accept(2)` (for example, file-descriptor exhaustion that +/// persists past `EMFILE`/`ENFILE` retries, or a closed listener). Per-peer +/// failures — TLS handshake errors, handshake timeouts, and HTTP-layer errors +/// on a single connection — are logged at `debug`/`warn`/`trace` and never +/// abort the accept loop. +pub fn serve_tls( + listener: TcpListener, + router: axum::Router, + tls_config: Arc, +) -> ServeTls { + ServeTls { + listener, + router, + acceptor: tokio_rustls::TlsAcceptor::from(tls_config), + tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT, + shutdown: None, + } +} + +/// Configurable future returned by [`serve_tls`]. +/// +/// Mirrors the shape of `axum::serve::Serve`: tweak it with builder +/// methods, then `.await` it (or pass it anywhere an `IntoFuture` is +/// accepted). +#[must_use = "ServeTls does nothing unless `.await`ed"] +pub struct ServeTls { + listener: TcpListener, + router: axum::Router, + acceptor: tokio_rustls::TlsAcceptor, + tls_handshake_timeout: Duration, + shutdown: Option + Send>>>, +} + +impl ServeTls { + /// Override the TLS handshake timeout (default + /// [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). Set generously; clients on + /// high-latency links need a few round trips to complete the handshake. + #[must_use = "ServeTls does nothing unless `.await`ed"] + pub fn tls_handshake_timeout(mut self, timeout: Duration) -> Self { + self.tls_handshake_timeout = timeout; + self + } + + /// Stop accepting new connections when `signal` resolves and drain + /// in-flight connections before the future returned by [`serve_tls`] + /// resolves. Mirrors `axum::serve::Serve::with_graceful_shutdown`. + #[must_use = "ServeTls does nothing unless `.await`ed"] + pub fn with_graceful_shutdown(mut self, signal: F) -> Self + where + F: Future + Send + 'static, + { + self.shutdown = Some(Box::pin(signal)); + self + } +} + +impl std::fmt::Debug for ServeTls { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServeTls") + .field("listener", &self.listener) + .field("tls_handshake_timeout", &self.tls_handshake_timeout) + .field("shutdown", &self.shutdown.is_some()) + .finish_non_exhaustive() + } +} + +impl IntoFuture for ServeTls { + type Output = std::io::Result<()>; + type IntoFuture = Pin + Send>>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.run()) + } +} + +impl ServeTls { + async fn run(self) -> std::io::Result<()> { + let ServeTls { + listener, + router, + acceptor, + tls_handshake_timeout, + shutdown, + } = self; + + // `select!` needs a polled-in-place future for both arms; default to + // a never-resolving signal when no graceful shutdown is configured. + let mut shutdown = shutdown.unwrap_or_else(|| Box::pin(std::future::pending())); + let graceful = GracefulShutdown::new(); + + loop { + let (stream, remote_addr) = tokio::select! { + biased; // honor shutdown before another accept + + _ = &mut shutdown => { + tracing::info!("Shutdown signal received; draining connections"); + break; + } + accepted = listener.accept() => match accepted { + Ok(conn) => conn, + Err(err) if is_transient_accept_error(&err) => { + tracing::warn!("Transient accept error (continuing): {err}"); + continue; + } + Err(err) => return Err(err), + }, + }; + + // Same TCP_NODELAY rationale as the standalone Server: avoid + // Nagle/delayed-ACK interaction on small HTTP/2 control frames. + if let Err(e) = stream.set_nodelay(true) { + tracing::warn!("failed to set TCP_NODELAY: {e}"); + } + + let acceptor = acceptor.clone(); + let router = router.clone(); + let watcher = graceful.watcher(); + + tokio::spawn(async move { + let tls_stream = match tokio::time::timeout( + tls_handshake_timeout, + acceptor.accept(stream), + ) + .await + { + Ok(Ok(s)) => s, + Ok(Err(err)) => { + tracing::debug!(remote_addr = %remote_addr, error = ?err, "TLS handshake failed"); + return; + } + Err(_) => { + tracing::warn!( + remote_addr = %remote_addr, + "TLS handshake timed out after {tls_handshake_timeout:?}", + ); + return; + } + }; + + // Capture peer info now — once hyper owns the stream we can't + // borrow it again. `into_owned()` detaches the cert bytes from + // the session lifetime so the Arc can outlive the TlsStream. + let (_, conn) = tls_stream.get_ref(); + let peer_addr = PeerAddr(remote_addr); + let peer_certs = conn + .peer_certificates() + .map(|chain| PeerCerts(chain.iter().map(|c| c.clone().into_owned()).collect())); + + // Per-request: stamp peer info into extensions and forward to + // the axum service. `Router::clone()` is an Arc bump. + let svc = hyper::service::service_fn( + move |mut req: hyper::Request| { + req.extensions_mut().insert(peer_addr.clone()); + if let Some(c) = &peer_certs { + req.extensions_mut().insert(c.clone()); + } + router.clone().oneshot(req.map(axum::body::Body::new)) + }, + ); + + // `serve_connection_with_upgrades` (vs `serve_connection` on the + // standalone `Server`) so axum routes that need HTTP `Upgrade:` + // (WebSockets) work out of the box. ConnectRPC routes don't + // upgrade, so this is a no-op for them. Keep this divergence — + // it matches what `axum::serve` does internally. + let conn = AutoBuilder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(TokioIo::new(tls_stream), svc) + .into_owned(); + if let Err(err) = watcher.watch(conn).await { + tracing::trace!(remote_addr = %remote_addr, error = %err, "Connection ended with error"); + } + }); + } + + // Stop accepting before signalling the drain so no new connection + // sneaks in between the watcher snapshot and the listener close. + drop(listener); + graceful.shutdown().await; + tracing::info!("All connections drained; shutdown complete"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + use crate::{Response as ConnectResponse, Router as ConnectRouter, handler_fn}; + use rcgen::{CertificateParams, CertifiedIssuer, IsCa, KeyPair, SanType}; + use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + type Pki = ( + Arc, + Arc, + CertificateDer<'static>, + ); + + /// Minimal in-memory mTLS PKI: one CA, one server leaf, one client leaf. + /// Returns `(server_config, client_config, client_leaf_der)`. + fn pki() -> Pki { + // Idempotent; err == already installed (tests share process state). + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + + let ca_key = KeyPair::generate().unwrap(); + let mut ca_params = CertificateParams::default(); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + let ca = CertifiedIssuer::self_signed(ca_params, ca_key).unwrap(); + + let issue = |sans: &[SanType]| { + let k = KeyPair::generate().unwrap(); + let mut p = CertificateParams::default(); + p.subject_alt_names = sans.to_vec(); + let c = p.signed_by(&k, &ca).unwrap(); + ( + CertificateDer::from(c.der().to_vec()), + PrivatePkcs8KeyDer::from(k.serialized_der().to_vec()).into(), + ) + }; + let (srv_cert, srv_key) = issue(&[SanType::DnsName("localhost".try_into().unwrap())]); + let (cli_cert, cli_key) = issue(&[]); + + let mut roots = rustls::RootCertStore::empty(); + roots.add(CertificateDer::from(ca.der().to_vec())).unwrap(); + let roots = Arc::new(roots); + + let cv = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots)) + .build() + .unwrap(); + let server = rustls::ServerConfig::builder() + .with_client_cert_verifier(cv) + .with_single_cert(vec![srv_cert], srv_key) + .unwrap(); + let client = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_client_auth_cert(vec![cli_cert.clone()], cli_key) + .unwrap(); + (Arc::new(server), Arc::new(client), cli_cert) + } + + /// HTTP/1.1 Connect unary request matching `server.rs`'s fixture. + const ECHO_REQ: &[u8] = b"POST /svc/Echo HTTP/1.1\r\n\ + Host: localhost\r\n\ + Content-Type: application/proto\r\n\ + Content-Length: 0\r\n\ + Connection: close\r\n\ + \r\n"; + + #[tokio::test] + async fn serve_tls_injects_peer_identity() { + let (server_cfg, client_cfg, expected_client_der) = pki(); + + // The handler stashes whatever PeerAddr/PeerCerts it sees. + type Captured = Arc)>>>; + let captured: Captured = Arc::new(Mutex::new(None)); + let handler_captured = Arc::clone(&captured); + let connect = ConnectRouter::new().route( + "svc", + "Echo", + handler_fn( + move |ctx: crate::RequestContext, _req: buffa_types::Empty| { + let cap = Arc::clone(&handler_captured); + async move { + *cap.lock().unwrap() = Some(( + ctx.extensions.get::().cloned().unwrap(), + ctx.extensions.get::().cloned(), + )); + ConnectResponse::ok(buffa_types::Empty::default()) + } + }, + ), + ); + let app = axum::Router::new().fallback_service(connect.into_axum_service()); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .into_future(), + ); + + let resp = echo_over_tls(addr, client_cfg).await; + assert!( + resp.starts_with(b"HTTP/1.1 2"), + "expected 2xx, got: {}", + String::from_utf8_lossy(&resp[..resp.len().min(120)]) + ); + + // Graceful shutdown should drain and resolve the serve task. + tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .expect("serve should shut down within timeout") + .unwrap() + .unwrap(); + + let (peer_addr, peer_certs) = captured.lock().unwrap().take().expect("handler ran"); + assert_eq!(peer_addr.0.ip(), addr.ip()); + let certs = peer_certs.expect("mTLS client should present a cert chain"); + assert_eq!(certs.0.len(), 1); + assert_eq!(certs.0[0].as_ref(), expected_client_der.as_ref()); + } + + /// Open a TLS+HTTP/1.1 connection, send `ECHO_REQ`, and return the raw + /// HTTP response bytes. + async fn echo_over_tls( + addr: std::net::SocketAddr, + client_cfg: Arc, + ) -> Vec { + let tcp = tokio::net::TcpStream::connect(addr).await.unwrap(); + let connector = tokio_rustls::TlsConnector::from(client_cfg); + let sni = rustls::pki_types::ServerName::try_from("localhost").unwrap(); + let mut tls = connector.connect(sni, tcp).await.unwrap(); + tls.write_all(ECHO_REQ).await.unwrap(); + let mut resp = Vec::new(); + tls.read_to_end(&mut resp).await.unwrap(); + resp + } + + #[tokio::test] + async fn handshake_timeout_drops_stalled_connection() { + let (server_cfg, _, _) = pki(); + let app = axum::Router::new(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .tls_handshake_timeout(Duration::from_millis(100)) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .into_future(), + ); + + // Open TCP but never speak TLS, and keep it open through shutdown. + // If the handshake timeout doesn't release this connection's watcher, + // the graceful drain blocks until the outer timeout fails the test. + let _stalled = tokio::net::TcpStream::connect(addr).await.unwrap(); + // Generous margin so the accept loop spawns the per-connection task + // (and its watcher) before we signal shutdown — otherwise the test + // passes vacuously without exercising the timeout path. + tokio::time::sleep(Duration::from_millis(250)).await; + + tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .expect("handshake timeout must release the watcher so drain completes") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn handshake_error_does_not_kill_accept_loop() { + let (server_cfg, client_cfg, _) = pki(); + let calls = Arc::new(Mutex::new(0u32)); + let handler_calls = Arc::clone(&calls); + let connect = ConnectRouter::new().route( + "svc", + "Echo", + handler_fn( + move |_ctx: crate::RequestContext, _req: buffa_types::Empty| { + let calls = Arc::clone(&handler_calls); + async move { + *calls.lock().unwrap() += 1; + ConnectResponse::ok(buffa_types::Empty::default()) + } + }, + ), + ); + let app = axum::Router::new().fallback_service(connect.into_axum_service()); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .into_future(), + ); + + // Speak garbage instead of a ClientHello: the rustls handshake fails + // immediately. The accept loop must log-and-continue, not propagate. + let mut bad = tokio::net::TcpStream::connect(addr).await.unwrap(); + bad.write_all(b"GET / HTTP/1.1\r\n\r\n").await.unwrap(); + let mut buf = [0u8; 64]; + let _ = bad.read(&mut buf).await; // server closes / sends a TLS alert + drop(bad); + + // A valid client must still get through. + let resp = echo_over_tls(addr, client_cfg).await; + assert!( + resp.starts_with(b"HTTP/1.1 2"), + "valid client must succeed after a handshake error: {}", + String::from_utf8_lossy(&resp[..resp.len().min(120)]) + ); + + tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!( + *calls.lock().unwrap(), + 1, + "only the valid client reaches the handler" + ); + } + + #[tokio::test] + async fn graceful_shutdown_drains_in_flight_request() { + let (server_cfg, client_cfg, _) = pki(); + + // The handler blocks until the test releases it; this lets us pin a + // request as "in-flight" across the shutdown signal. + let (in_flight_tx, in_flight_rx) = tokio::sync::oneshot::channel::<()>(); + let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>(); + let in_flight_tx = Arc::new(Mutex::new(Some(in_flight_tx))); + let release_rx = Arc::new(Mutex::new(Some(release_rx))); + let connect = ConnectRouter::new().route( + "svc", + "Echo", + handler_fn( + move |_ctx: crate::RequestContext, _req: buffa_types::Empty| { + let in_flight = in_flight_tx.lock().unwrap().take(); + let release = release_rx.lock().unwrap().take(); + async move { + if let Some(tx) = in_flight { + tx.send(()).ok(); + } + if let Some(rx) = release { + rx.await.ok(); + } + ConnectResponse::ok(buffa_types::Empty::default()) + } + }, + ), + ); + let app = axum::Router::new().fallback_service(connect.into_axum_service()); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let serve = tokio::spawn( + serve_tls(listener, app, server_cfg) + .with_graceful_shutdown(async { + shutdown_rx.await.ok(); + }) + .into_future(), + ); + + let client = tokio::spawn(echo_over_tls(addr, client_cfg)); + + // Once the request is in-flight, signal shutdown. The watcher held by + // the per-connection task must anchor it until the handler returns. + in_flight_rx.await.unwrap(); + shutdown_tx.send(()).unwrap(); + + // Release the handler: the in-flight request must complete cleanly + // (proving the connection wasn't torn down by the shutdown), and only + // then should the serve future drain. + release_tx.send(()).unwrap(); + let resp = tokio::time::timeout(Duration::from_secs(5), client) + .await + .expect("in-flight request should complete during drain") + .unwrap(); + assert!( + resp.starts_with(b"HTTP/1.1 2"), + "in-flight request must complete: {}", + String::from_utf8_lossy(&resp[..resp.len().min(120)]) + ); + tokio::time::timeout(Duration::from_secs(5), serve) + .await + .expect("serve should drain after the in-flight request completes") + .unwrap() + .unwrap(); + } +} diff --git a/connectrpc/src/lib.rs b/connectrpc/src/lib.rs index a667395..3207fed 100644 --- a/connectrpc/src/lib.rs +++ b/connectrpc/src/lib.rs @@ -171,6 +171,15 @@ pub mod client; #[cfg(feature = "server")] pub mod server; +// Optional: TLS-aware `axum::serve` counterpart with peer-identity passthrough. +// +// Note: this module shadows the extern-prelude `axum` crate within the crate +// root scope only. Don't add `use axum::...` here in `lib.rs`; use +// `::axum::...` if a root-level reference to the external crate is ever needed. +#[cfg(all(feature = "axum", feature = "server-tls"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "axum", feature = "server-tls"))))] +pub mod axum; + // ============================================================================ // Primary exports - Tower-first API // ============================================================================ diff --git a/connectrpc/src/server.rs b/connectrpc/src/server.rs index 9198346..6d0b7e0 100644 --- a/connectrpc/src/server.rs +++ b/connectrpc/src/server.rs @@ -62,8 +62,9 @@ use crate::service::ConnectRpcService; /// Remote socket address of the connected peer. /// -/// Inserted into every request's extensions by the built-in server's accept -/// loop. Handlers read it via `ctx.extensions.get::()`. +/// Inserted into every request's extensions by the built-in [`Server`]'s +/// accept loop and by `connectrpc::axum::serve_tls`. Handlers read it via +/// `ctx.extensions.get::()`. /// /// Callers using a different HTTP stack (axum, raw hyper) in front of /// [`ConnectRpcService`] can insert this same type @@ -73,11 +74,11 @@ pub struct PeerAddr(pub SocketAddr); /// TLS client certificate chain presented by the peer (leaf first). /// -/// Inserted by the built-in server's TLS accept loop when the -/// [`rustls::ServerConfig`] requests client authentication and the peer -/// presents a valid chain. Absent on plaintext connections or when the -/// client presents no certificate. Handlers read it via -/// `ctx.extensions.get::()`. +/// Inserted by the built-in [`Server`]'s TLS accept loop and by +/// `connectrpc::axum::serve_tls` when the [`rustls::ServerConfig`] requests +/// client authentication and the peer presents a valid chain. Absent on +/// plaintext connections or when the client presents no certificate. +/// Handlers read it via `ctx.extensions.get::()`. /// /// The `Arc` makes per-request insertion cheap: all requests on a /// connection share one chain, so this is a refcount bump, not a copy. @@ -681,7 +682,7 @@ fn panic_handler(err: Box) -> Response> { /// - `EMFILE` / `ENFILE`: Too many open files (file descriptor exhaustion) /// - `ECONNABORTED`: Connection was aborted before accept completed /// - `EINTR`: Interrupted system call -fn is_transient_accept_error(err: &std::io::Error) -> bool { +pub(crate) fn is_transient_accept_error(err: &std::io::Error) -> bool { use std::io::ErrorKind; matches!( diff --git a/docs/guide.md b/docs/guide.md index 3c0fda5..4ca4bee 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -676,14 +676,33 @@ Server::new(connect_router) .await?; ``` -For the axum path, wrap the listener with `tokio_rustls::TlsAcceptor` -yourself (this is what the eliza example does). +For the axum path, `connectrpc::axum::serve_tls` (requires both the +`axum` and `server-tls` features) is a drop-in replacement for +`axum::serve` that owns the rustls accept loop and stamps `PeerAddr` / +`PeerCerts` into request extensions exactly as the standalone `Server` +does, so handler code that reads `ctx.extensions.get::()` +is portable across both hosting paths: + +```rust +let app = axum::Router::new() + .route("/health", axum::routing::get(|| async { "OK" })) + .fallback_service(connect_router.into_axum_service()); + +let listener = tokio::net::TcpListener::bind("0.0.0.0:8443").await?; +connectrpc::axum::serve_tls(listener, app, server_config) + .with_graceful_shutdown(shutdown_signal) + .await?; +``` The eliza example ([`examples/eliza/README.md`](../examples/eliza/README.md)) walks through generating self-signed certificates with openssl, configuring mTLS via `--client-ca`, and the rustls strict-PKI requirement that -your CA cert must be distinct from the server leaf cert. +your CA cert must be distinct from the server leaf cert. The +mtls-identity example +([`examples/mtls-identity/README.md`](../examples/mtls-identity/README.md)) +demonstrates `serve_tls` end-to-end with cert-SAN identity extraction +and an ACL keyed on it. ## Clients @@ -905,6 +924,7 @@ let service = ConnectRpcService::new(router).with_compression(registry); |---|---| | [`streaming-tour/`](../examples/streaming-tour) | All four RPC types (unary, server stream, client stream, bidi) on a trivial NumberService. Smallest demo of handler signatures and client invocation patterns. | | [`middleware/`](../examples/middleware) | Server-side tower middleware composition: an `axum::middleware::from_fn` bearer-token auth, identity passthrough via `RequestContext::extensions`, response trailers via `Response::with_trailer`. Client demos `ClientConfig::default_header` and `CallOptions::with_timeout`. | +| [`mtls-identity/`](../examples/mtls-identity) | mTLS twin of `middleware/`: axum hosted behind `connectrpc::axum::serve_tls`, identity from the client cert's DNS SAN via `PeerCerts` instead of a bearer token, ACL keyed on the cert-derived identity. In-memory `rcgen` PKI; no PEM files. | | [`eliza/`](../examples/eliza) | Production-shaped streaming app: a port of the `connectrpc/examples-go` ELIZA demo. Server-streaming Introduce + bidi-streaming Converse, TLS, mTLS, CORS, IPv6, both server and client binaries, interoperates with the hosted Go reference at `demo.connectrpc.com`. | | [`multiservice/`](../examples/multiservice) | Multiple proto packages compiled together with `buf generate`, multiple services on one server, well-known type usage. | | [`wasm-client/`](../examples/wasm-client) | Browser fetch transport: same generated client used from `wasm32-unknown-unknown` with a custom `ClientTransport` backed by `web-sys::fetch`. | diff --git a/examples/mtls-identity/Cargo.toml b/examples/mtls-identity/Cargo.toml new file mode 100644 index 0000000..07c6f6d --- /dev/null +++ b/examples/mtls-identity/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "mtls-identity-example" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[lib] +name = "mtls_identity_example" +path = "src/lib.rs" + +[[bin]] +name = "mtls-identity" +path = "src/main.rs" + +[dependencies] +connectrpc = { path = "../../connectrpc", features = ["axum", "client", "client-tls", "server-tls"] } + +# Protobuf +buffa = { workspace = true } +buffa-types = { workspace = true } + +# Serialization (for generated message types) +serde = { workspace = true } +serde_json = { workspace = true } + +# HTTP types (for generated client bounds) +http-body = { workspace = true } +http = { workspace = true } + +# Async +tokio = { workspace = true, features = ["full"] } +futures = { workspace = true } + +# Web framework +axum = { workspace = true } + +# TLS (test PKI generated at runtime; no PEM files in the repo) +rcgen = { workspace = true } +rustls-pki-types = { workspace = true } +# Parse the leaf cert's SubjectAlternativeName in the handler. Already in +# the lockfile transitively via rcgen, so this adds nothing to the build. +x509-parser = "0.18" + +[build-dependencies] +connectrpc-build = { path = "../../connectrpc-build" } + +[lints] +workspace = true diff --git a/examples/mtls-identity/README.md b/examples/mtls-identity/README.md new file mode 100644 index 0000000..5b2e7d0 --- /dev/null +++ b/examples/mtls-identity/README.md @@ -0,0 +1,109 @@ +# mTLS identity example + +Demonstrates cert-SAN-based identity for an axum-hosted ConnectRPC +service. The server is hosted on axum behind +`connectrpc::axum::serve_tls`, which terminates TLS, captures the +verified client certificate chain and remote address, and stamps them +into request extensions as `PeerCerts` / `PeerAddr` — the same +convention the standalone `connectrpc::Server::with_tls` uses. The +handler parses the leaf cert's DNS SAN to derive a workload identity +and enforces an ACL against it. + +This is the mTLS twin of [`examples/middleware/`](../middleware): same +secret-store-with-ACL shape, but the credential is a client certificate +instead of a `Bearer` token. The handler-side code that reads +`ctx.extensions.get::()` is unchanged; only what the layer/accept +loop puts into extensions differs. + +## Run it + +The demo is a single self-contained binary: it generates an in-memory +PKI (CA, server cert, two workload client certs) with `rcgen`, starts +the server, makes a few calls with each identity, and shuts down. No +PEM files touch disk. + +```bash +cargo run -p mtls-identity-example +``` + +Expected output (port and source ports vary): + +``` +IdentityService listening on https://127.0.0.1:PORT (mTLS required) + +[alice] WhoAmI -> identity="alice" san="alice.workloads.example.com" from="127.0.0.1:..." +[bob] WhoAmI -> identity="bob" san="bob.workloads.example.com" from="127.0.0.1:..." + +[alice] GetSecret( shared) -> "the value of teamwork" (x-served-by: alice) +[alice] GetSecret(alice-only) -> "alice's diary entry" (x-served-by: alice) +[bob] GetSecret( shared) -> "the value of teamwork" (x-served-by: bob) +[bob] GetSecret(alice-only) -> permission_denied: workload "bob" (bob.workloads.example.com) cannot read "alice-only" +``` + +## What to look at + +### `serve_tls` instead of `axum::serve` (`src/lib.rs::serve`) + +`axum::serve` accepts a plain `TcpListener` with no hook for +terminating TLS, so an axum + mTLS deployment normally has to write a +rustls accept loop by hand. `connectrpc::axum::serve_tls` is a drop-in +replacement that owns that loop and stamps `PeerAddr` / `PeerCerts` +into request extensions: + +```rust +let app = axum::Router::new().fallback_service(connect_router.into_axum_service()); +connectrpc::axum::serve_tls(listener, app, server_config) + .with_graceful_shutdown(shutdown) + .await?; +``` + +Handler code that reads `ctx.extensions.get::()` is then +portable between the standalone `Server::with_tls` and an axum app. + +### Cert-SAN identity (`src/lib.rs::extract_identity`) + +The handler reads the leaf cert from `PeerCerts`, parses its DNS SAN +with `x509-parser`, and derives a short workload name from a SAN under +`workloads.example.com`. A real deployment would typically match a +SPIFFE ID (a URI SAN) instead, or hand the whole step to an +authorization framework — the shape is the same: read `PeerCerts`, +parse the leaf, derive an identity. + +Two failure modes both surface as `Unauthenticated`: + +- No client cert presented: only reachable if the server's + `ClientCertVerifier` made client auth optional. This example uses + `WebPkiClientVerifier`, which *requires* a verified chain, so this + path is dead in practice — kept as defense in depth. +- A cert is presented but no SAN matches the workload domain. + +### In-memory PKI (`src/lib.rs::pki`) + +`pki::generate(&["alice", "bob"])` builds a CA, a server leaf +(`SAN = localhost`), and one client leaf per workload +(`SAN = .workloads.example.com`), all in memory via `rcgen`. A +deployment would load these from a secret store; the rustls types are +identical. + +The server config requires *and verifies* client certs against the +demo CA, so the chain that reaches the handler is always verified — +the SAN parsing only has to decide *which* trusted client this is, not +whether to trust it. + +## Integration test + +`tests/e2e.rs` exercises four paths: identity reflection (`WhoAmI`), +authorized read with response trailer, permission denied (bob reading +alice's secret), and a TLS client without a cert being rejected at the +handshake before the request reaches HTTP. + +```bash +cargo test -p mtls-identity-example +``` + +## Where to go next + +- See [`examples/middleware`](../middleware) for the bearer-token + equivalent of this example. +- See [`examples/eliza`](../eliza) for loading certs from PEM files + with `--cert`/`--key`/`--client-ca` CLI flags. diff --git a/examples/mtls-identity/build.rs b/examples/mtls-identity/build.rs new file mode 100644 index 0000000..cd58f25 --- /dev/null +++ b/examples/mtls-identity/build.rs @@ -0,0 +1,8 @@ +fn main() { + connectrpc_build::Config::new() + .files(&["proto/anthropic/connectrpc/mtls_identity/v1/identity.proto"]) + .includes(&["proto/"]) + .include_file("_connectrpc.rs") + .compile() + .unwrap(); +} diff --git a/examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto b/examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto new file mode 100644 index 0000000..b7f70c2 --- /dev/null +++ b/examples/mtls-identity/proto/anthropic/connectrpc/mtls_identity/v1/identity.proto @@ -0,0 +1,35 @@ +edition = "2023"; + +package anthropic.connectrpc.mtls_identity.v1; + +// IdentityService demonstrates mTLS cert-SAN-based identity. The server +// is hosted on axum behind connectrpc::axum::serve_tls, which terminates +// TLS and stamps PeerAddr/PeerCerts into request extensions. The handler +// parses the leaf cert's DNS SAN to derive a workload identity rather +// than reading a bearer token, then enforces an ACL against it. +service IdentityService { + // Echo back what the server knows about the caller from the TLS layer. + rpc WhoAmI(WhoAmIRequest) returns (WhoAmIResponse); + + // Read a secret, gated on the caller's cert-derived identity. + rpc GetSecret(GetSecretRequest) returns (GetSecretResponse); +} + +message WhoAmIRequest {} + +message WhoAmIResponse { + // Workload identity parsed from the leaf cert's DNS SAN, e.g. "alice". + string identity = 1; + // Full DNS SAN as presented, e.g. "alice.workloads.example.com". + string san = 2; + // Caller's remote socket address as observed by the accept loop. + string remote_addr = 3; +} + +message GetSecretRequest { + string name = 1; +} + +message GetSecretResponse { + string value = 1; +} diff --git a/examples/mtls-identity/src/lib.rs b/examples/mtls-identity/src/lib.rs new file mode 100644 index 0000000..7cbebd1 --- /dev/null +++ b/examples/mtls-identity/src/lib.rs @@ -0,0 +1,371 @@ +//! mTLS cert-SAN identity for axum-hosted ConnectRPC services. +//! +//! Mirrors `examples/middleware/`, swapping bearer-token auth for mTLS: +//! instead of an `axum::middleware::from_fn` reading an `Authorization` +//! header, identity comes from the verified client certificate that +//! `connectrpc::axum::serve_tls` captures during the TLS handshake and +//! stamps into request extensions as [`PeerCerts`]. The handler parses +//! the leaf cert's DNS SAN to derive a workload identity, then enforces +//! an ACL against it. +//! +//! The same handler code works unchanged on the standalone +//! [`connectrpc::Server::with_tls`], which populates [`PeerCerts`] the +//! same way — the hosting choice doesn't leak into authorization logic. + +use std::collections::HashMap; +use std::sync::Arc; + +use buffa::view::OwnedView; +use connectrpc::{ + ConnectError, ErrorCode, PeerAddr, PeerCerts, RequestContext, Router, ServiceResult, +}; + +pub mod proto { + connectrpc::include_generated!(); +} + +pub use proto::anthropic::connectrpc::mtls_identity::v1::*; + +pub type BoxError = Box; + +// ============================================================================ +// Identity: derive a workload name from the leaf cert's DNS SAN. +// ============================================================================ + +/// All clients in this demo carry a SAN under this domain. Anything else +/// is rejected as `Unauthenticated`. +pub const WORKLOAD_DOMAIN: &str = "workloads.example.com"; + +/// Caller identity, parsed from the leaf certificate's DNS SAN. +#[derive(Debug, Clone)] +pub struct Identity { + /// Short workload name, e.g. `"alice"`. + pub name: String, + /// Full DNS SAN as presented, e.g. `"alice.workloads.example.com"`. + pub san: String, +} + +/// Parse a workload identity out of the leaf certificate's DNS SANs. +/// +/// Returns the first SAN under [`WORKLOAD_DOMAIN`]. Returns +/// `Unauthenticated` when no client cert was presented (a non-mTLS +/// connection) or no SAN matches the expected domain. +/// +/// In a real deployment you'd typically match a SPIFFE ID +/// (`spiffe://trust-domain/path`, a URI SAN) instead of a DNS SAN, or +/// delegate this whole step to an authorization framework. The shape is +/// the same: read [`PeerCerts`], parse the leaf, derive an identity. +pub fn extract_identity(certs: Option<&PeerCerts>) -> Result { + use x509_parser::extensions::GeneralName; + use x509_parser::prelude::{FromDer, X509Certificate}; + + let leaf = certs.and_then(|c| c.0.first()).ok_or_else(|| { + ConnectError::new(ErrorCode::Unauthenticated, "client certificate required") + })?; + + let (_, parsed) = X509Certificate::from_der(leaf.as_ref()).map_err(|e| { + ConnectError::new(ErrorCode::Unauthenticated, format!("bad client cert: {e}")) + })?; + + let suffix = format!(".{WORKLOAD_DOMAIN}"); + parsed + .subject_alternative_name() + .ok() + .flatten() + .into_iter() + .flat_map(|ext| ext.value.general_names.iter()) + .find_map(|gn| { + // Only `.workloads.example.com` is a workload SAN. + // Reject `.workloads.example.com` and `a.b.workloads.example.com`: + // we intend to accept only direct subdomains of the workload + // domain. + let GeneralName::DNSName(dns) = gn else { + return None; + }; + let prefix = dns.strip_suffix(&suffix)?; + if prefix.is_empty() || prefix.contains('.') { + return None; + } + Some(Identity { + name: prefix.to_owned(), + san: (*dns).to_owned(), + }) + }) + .ok_or_else(|| { + ConnectError::new( + ErrorCode::Unauthenticated, + format!("client cert has no workload SAN under {WORKLOAD_DOMAIN}"), + ) + }) +} + +// ============================================================================ +// IdentityService handler +// ============================================================================ + +/// Static secret store. Each secret declares which workloads may read it. +pub fn secret_store() -> HashMap)> { + HashMap::from([ + ( + "shared".into(), + ("the value of teamwork".into(), vec!["alice", "bob"]), + ), + ( + "alice-only".into(), + ("alice's diary entry".into(), vec!["alice"]), + ), + ]) +} + +pub struct IdentityServiceImpl { + pub store: HashMap)>, +} + +impl IdentityService for IdentityServiceImpl { + async fn who_am_i( + &self, + ctx: RequestContext, + _request: OwnedView>, + ) -> ServiceResult { + // Both PeerCerts and PeerAddr are stamped per connection by + // serve_tls; the dispatcher copies request extensions verbatim + // into ctx.extensions. + let id = extract_identity(ctx.extensions.get::())?; + let remote = ctx + .extensions + .get::() + .map(|a| a.0.to_string()) + .unwrap_or_default(); + connectrpc::Response::ok(WhoAmIResponse { + identity: Some(id.name), + san: Some(id.san), + remote_addr: Some(remote), + ..Default::default() + }) + } + + async fn get_secret( + &self, + ctx: RequestContext, + request: OwnedView>, + ) -> ServiceResult { + let id = extract_identity(ctx.extensions.get::())?; + let name = request.name.unwrap_or("").to_owned(); + let (value, allowed) = self.store.get(&name).ok_or_else(|| { + ConnectError::new(ErrorCode::NotFound, format!("no secret named {name:?}")) + })?; + if !allowed.iter().any(|w| *w == id.name) { + return Err(ConnectError::new( + ErrorCode::PermissionDenied, + format!("workload {:?} ({}) cannot read {name:?}", id.name, id.san), + )); + } + Ok(connectrpc::Response::new(GetSecretResponse { + value: Some(value.clone()), + ..Default::default() + }) + .with_trailer("x-served-by", id.name)) + } +} + +// ============================================================================ +// Server hosting: axum app behind connectrpc::axum::serve_tls. +// ============================================================================ + +/// Build the axum app and serve it over TLS until `shutdown` resolves. +/// +/// This is the only line that differs from a plaintext axum app: +/// `connectrpc::axum::serve_tls` instead of `axum::serve`. +pub async fn serve( + listener: tokio::net::TcpListener, + server_config: Arc, + shutdown: impl std::future::Future + Send + 'static, +) -> std::io::Result<()> { + let svc = Arc::new(IdentityServiceImpl { + store: secret_store(), + }); + let connect_router = svc.register(Router::new()); + let app = axum::Router::new().fallback_service(connect_router.into_axum_service()); + + connectrpc::axum::serve_tls(listener, app, server_config) + .with_graceful_shutdown(shutdown) + .await +} + +// ============================================================================ +// In-memory PKI: one CA, one server leaf, N client leafs (DNS-SAN identities). +// ============================================================================ + +pub mod pki { + use std::sync::Arc; + + use connectrpc::rustls; + use rcgen::{BasicConstraints, CertificateParams, CertifiedIssuer, IsCa, KeyPair, SanType}; + use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + + /// One client credential: a leaf cert (with a SAN under the workload + /// domain) and its private key. + pub struct ClientCredential { + pub cert: CertificateDer<'static>, + pub key: PrivateKeyDer<'static>, + } + + /// In-memory PKI for the demo. + pub struct Pki { + /// Server [`rustls::ServerConfig`] requiring client certs signed + /// by the demo CA. + pub server_config: Arc, + /// Trust roots containing only the demo CA, for building client + /// configs. + pub roots: Arc, + /// Per-workload client credentials, keyed by short name. + pub clients: std::collections::HashMap, + } + + impl Pki { + /// Build a [`rustls::ClientConfig`] that trusts the demo CA and + /// presents the named workload's client cert during the handshake. + /// + /// # Panics + /// + /// Panics if `workload` isn't one of the names passed to [`generate`]. + pub fn client_config(&self, workload: &str) -> Arc { + let cred = self + .clients + .get(workload) + .unwrap_or_else(|| panic!("no credential for workload {workload:?}")); + Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(Arc::clone(&self.roots)) + .with_client_auth_cert(vec![cred.cert.clone()], cred.key.clone_key()) + .expect("valid client cert"), + ) + } + } + + /// Generate a fresh CA, server leaf (`SAN = localhost`), and one client + /// leaf per name in `workloads` (`SAN = .workloads.example.com`). + /// + /// No PEM files touch disk: this is the same shape a deployment would + /// load from a secret store, but generated in-process for the demo. + pub fn generate(workloads: &[&str]) -> Pki { + // Idempotent; err == already installed (tests share process state). + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + + let ca_key = KeyPair::generate().expect("generate CA key"); + let mut ca_params = CertificateParams::default(); + ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + let ca = CertifiedIssuer::self_signed(ca_params, ca_key).expect("self-sign CA"); + + // Issue a leaf with the given DNS SANs, signed by the demo CA. + let issue = |sans: &[&str]| -> (CertificateDer<'static>, PrivateKeyDer<'static>) { + let key = KeyPair::generate().expect("generate leaf key"); + let mut params = CertificateParams::default(); + params.subject_alt_names = sans + .iter() + .map(|s| SanType::DnsName((*s).try_into().expect("valid DNS SAN"))) + .collect(); + let cert = params.signed_by(&key, &ca).expect("sign leaf"); + ( + CertificateDer::from(cert.der().to_vec()), + PrivatePkcs8KeyDer::from(key.serialized_der().to_vec()).into(), + ) + }; + + let (server_cert, server_key) = issue(&["localhost"]); + let clients = workloads + .iter() + .map(|name| { + let san = format!("{name}.{}", super::WORKLOAD_DOMAIN); + let (cert, key) = issue(&[&san]); + (name.to_string(), ClientCredential { cert, key }) + }) + .collect(); + + let mut roots = rustls::RootCertStore::empty(); + roots + .add(CertificateDer::from(ca.der().to_vec())) + .expect("add CA to roots"); + let roots = Arc::new(roots); + + // Require *and verify* client certs. WebPkiClientVerifier rejects + // anything not chained to the demo CA before the request reaches + // the handler, so PeerCerts is always a verified chain. + let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots)) + .build() + .expect("build client verifier"); + let server_config = Arc::new( + rustls::ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(vec![server_cert], server_key) + .expect("valid server cert"), + ); + + Pki { + server_config, + roots, + clients, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, KeyPair, SanType}; + use rustls_pki_types::CertificateDer; + + /// Self-signed leaf with arbitrary DNS SANs, wrapped as PeerCerts. + fn peer_certs_with_dns_sans(sans: &[&str]) -> PeerCerts { + let key = KeyPair::generate().unwrap(); + let mut params = CertificateParams::default(); + params.subject_alt_names = sans + .iter() + .map(|s| SanType::DnsName((*s).try_into().unwrap())) + .collect(); + let cert = params.self_signed(&key).unwrap(); + PeerCerts(Arc::from(vec![CertificateDer::from(cert.der().to_vec())])) + } + + #[test] + fn extract_identity_rejects_no_cert() { + assert_eq!( + extract_identity(None).unwrap_err().code, + ErrorCode::Unauthenticated + ); + } + + #[test] + fn extract_identity_parses_single_label_workload_san() { + let certs = + peer_certs_with_dns_sans(&["ignored.example.org", "alice.workloads.example.com"]); + let id = extract_identity(Some(&certs)).unwrap(); + assert_eq!(id.name, "alice"); + assert_eq!(id.san, "alice.workloads.example.com"); + } + + #[test] + fn extract_identity_rejects_empty_or_multi_label_prefix() { + // Empty prefix: ".workloads.example.com" must not yield name = "". + let empty = peer_certs_with_dns_sans(&[".workloads.example.com"]); + assert_eq!( + extract_identity(Some(&empty)).unwrap_err().code, + ErrorCode::Unauthenticated + ); + // Multi-label prefix: not a direct subdomain; reject it. + let multi = peer_certs_with_dns_sans(&["a.b.workloads.example.com"]); + assert_eq!( + extract_identity(Some(&multi)).unwrap_err().code, + ErrorCode::Unauthenticated + ); + } + + #[test] + fn extract_identity_rejects_unrelated_domain() { + let certs = peer_certs_with_dns_sans(&["service.elsewhere.example.org"]); + assert_eq!( + extract_identity(Some(&certs)).unwrap_err().code, + ErrorCode::Unauthenticated + ); + } +} diff --git a/examples/mtls-identity/src/main.rs b/examples/mtls-identity/src/main.rs new file mode 100644 index 0000000..a7feaec --- /dev/null +++ b/examples/mtls-identity/src/main.rs @@ -0,0 +1,99 @@ +//! Self-contained mTLS identity demo. +//! +//! Generates an in-memory PKI, serves IdentityService on axum behind +//! `connectrpc::axum::serve_tls`, then calls it with two client +//! certificates to show the cert-SAN identity flow: +//! +//! - `alice` reads `shared` and `alice-only` successfully +//! - `bob` reads `shared` but is denied `alice-only` +//! +//! Run with: +//! +//! ```sh +//! cargo run -p mtls-identity-example +//! ``` + +use std::sync::Arc; + +use connectrpc::ErrorCode; +use connectrpc::client::{ClientConfig, HttpClient}; +use mtls_identity_example::{ + BoxError, GetSecretRequest, IdentityServiceClient, WhoAmIRequest, pki, serve, +}; + +#[tokio::main] +async fn main() -> Result<(), BoxError> { + // 1. PKI: a fresh CA, server cert, and two workload client certs. + // No PEM files on disk — see `pki::generate`. + let pki = Arc::new(pki::generate(&["alice", "bob"])); + + // 2. Server: axum app behind connectrpc::axum::serve_tls. + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let server = tokio::spawn(serve(listener, Arc::clone(&pki.server_config), async { + shutdown_rx.await.ok(); + })); + println!("IdentityService listening on https://{addr} (mTLS required)\n"); + + // 3. Clients: one HttpClient per workload, each presenting its own + // client cert. The handler derives the identity from the cert SAN — + // the request bodies carry no credentials at all. + let client_for = |workload: &str| { + let http = HttpClient::with_tls(pki.client_config(workload)); + // The server cert's SAN is "localhost"; rustls verifies hostname. + let cfg = ClientConfig::new( + format!("https://localhost:{}", addr.port()) + .parse() + .unwrap(), + ); + IdentityServiceClient::new(http, cfg) + }; + let alice = client_for("alice"); + let bob = client_for("bob"); + + // WhoAmI: identity comes purely from the TLS layer. + for (label, client) in [("alice", &alice), ("bob", &bob)] { + let resp = client.who_am_i(WhoAmIRequest::default()).await?; + let v = resp.view(); + println!( + "[{label}] WhoAmI -> identity={:?} san={:?} from={:?}", + v.identity.unwrap_or(""), + v.san.unwrap_or(""), + v.remote_addr.unwrap_or(""), + ); + } + println!(); + + // GetSecret: the ACL is keyed on the cert-derived identity. + for (label, client) in [("alice", &alice), ("bob", &bob)] { + for name in ["shared", "alice-only"] { + let req = GetSecretRequest { + name: Some(name.into()), + ..Default::default() + }; + match client.get_secret(req).await { + Ok(resp) => { + let served_by = resp + .trailers() + .get("x-served-by") + .and_then(|v| v.to_str().ok()) + .unwrap_or("?"); + println!( + "[{label}] GetSecret({name:>10}) -> {:?} (x-served-by: {served_by})", + resp.view().value.unwrap_or("") + ); + } + Err(err) => { + debug_assert_eq!(err.code, ErrorCode::PermissionDenied); + println!("[{label}] GetSecret({name:>10}) -> {err}"); + } + } + } + } + + // 4. Graceful shutdown: stop accepting and drain in-flight connections. + shutdown_tx.send(()).ok(); + server.await??; + Ok(()) +} diff --git a/examples/mtls-identity/tests/e2e.rs b/examples/mtls-identity/tests/e2e.rs new file mode 100644 index 0000000..f905bf4 --- /dev/null +++ b/examples/mtls-identity/tests/e2e.rs @@ -0,0 +1,138 @@ +//! End-to-end test: spin up IdentityService behind +//! `connectrpc::axum::serve_tls`, make calls with two client identities, +//! and check that the cert-SAN identity flows through `PeerCerts` to +//! the handler and gates the ACL correctly. + +use std::sync::Arc; + +use connectrpc::ErrorCode; +use connectrpc::client::{ClientConfig, HttpClient}; +use mtls_identity_example::{GetSecretRequest, IdentityServiceClient, WhoAmIRequest, pki, serve}; + +struct Harness { + pki: Arc, + addr: std::net::SocketAddr, + shutdown: tokio::sync::oneshot::Sender<()>, + serve_task: tokio::task::JoinHandle>, +} + +impl Harness { + async fn start(workloads: &[&str]) -> Self { + let pki = Arc::new(pki::generate(workloads)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let serve_task = tokio::spawn(serve(listener, Arc::clone(&pki.server_config), async { + rx.await.ok(); + })); + Harness { + pki, + addr, + shutdown: tx, + serve_task, + } + } + + fn client(&self, workload: &str) -> IdentityServiceClient { + let http = HttpClient::with_tls(self.pki.client_config(workload)); + let cfg = ClientConfig::new( + format!("https://localhost:{}", self.addr.port()) + .parse() + .unwrap(), + ); + IdentityServiceClient::new(http, cfg) + } + + async fn shutdown(self) { + self.shutdown.send(()).ok(); + tokio::time::timeout(std::time::Duration::from_secs(5), self.serve_task) + .await + .expect("server should shut down within timeout") + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn whoami_reflects_cert_san_and_remote_addr() { + let h = Harness::start(&["alice"]).await; + let resp = h + .client("alice") + .who_am_i(WhoAmIRequest::default()) + .await + .unwrap(); + let v = resp.view(); + assert_eq!(v.identity, Some("alice")); + assert_eq!(v.san, Some("alice.workloads.example.com")); + // PeerAddr should be the actual TCP source address. + let remote = v.remote_addr.unwrap(); + assert!(remote.starts_with("127.0.0.1:"), "remote_addr={remote}"); + h.shutdown().await; +} + +#[tokio::test] +async fn authorized_call_returns_value_and_trailer() { + let h = Harness::start(&["alice", "bob"]).await; + let resp = h + .client("alice") + .get_secret(GetSecretRequest { + name: Some("shared".into()), + ..Default::default() + }) + .await + .unwrap(); + assert_eq!(resp.view().value, Some("the value of teamwork")); + assert_eq!( + resp.trailers() + .get("x-served-by") + .unwrap() + .to_str() + .unwrap(), + "alice" + ); + h.shutdown().await; +} + +#[tokio::test] +async fn permission_denied_for_other_workloads_secret() { + let h = Harness::start(&["alice", "bob"]).await; + let err = h + .client("bob") + .get_secret(GetSecretRequest { + name: Some("alice-only".into()), + ..Default::default() + }) + .await + .expect_err("bob cannot read alice's secret"); + assert_eq!(err.code, ErrorCode::PermissionDenied); + h.shutdown().await; +} + +#[tokio::test] +async fn client_without_cert_is_rejected_at_handshake() { + // WebPkiClientVerifier requires a client cert; a TLS client that + // presents none never reaches the handler. Distinct from the + // `extract_identity(None)` branch, which only fires for hosting + // setups that make client auth optional. + let h = Harness::start(&["alice"]).await; + let no_cert_cfg = Arc::new( + connectrpc::rustls::ClientConfig::builder() + .with_root_certificates(Arc::clone(&h.pki.roots)) + .with_no_client_auth(), + ); + let http = HttpClient::with_tls(no_cert_cfg); + let cfg = ClientConfig::new( + format!("https://localhost:{}", h.addr.port()) + .parse() + .unwrap(), + ); + let client = IdentityServiceClient::new(http, cfg); + let err = client + .who_am_i(WhoAmIRequest::default()) + .await + .expect_err("server must reject a client with no cert"); + // The handshake failure surfaces as a connection-level Unavailable, + // not a Connect-protocol error: the request never reaches HTTP. + assert_eq!(err.code, ErrorCode::Unavailable, "got: {err}"); + h.shutdown().await; +}