diff --git a/src/api/mod.rs b/src/api/mod.rs index c168ce63..d6cab7a9 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -65,6 +65,8 @@ pub trait ClientInfo { fn metadata_mut(&mut self) -> &mut HashMap; + fn sni_server_name(&self) -> Option<&str>; + #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]>; } @@ -89,6 +91,7 @@ pub struct DefaultClient { pub state: PgWireConnectionState, pub transaction_status: TransactionStatus, pub metadata: HashMap, + pub sni_server_name: Option, pub portal_store: store::MemPortalStore, } @@ -141,6 +144,10 @@ impl ClientInfo for DefaultClient { self.transaction_status = new_status } + fn sni_server_name(&self) -> Option<&str> { + self.sni_server_name.as_deref() + } + #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> { None @@ -157,6 +164,7 @@ impl DefaultClient { state: PgWireConnectionState::default(), transaction_status: TransactionStatus::Idle, metadata: HashMap::new(), + sni_server_name: None, portal_store: store::MemPortalStore::new(), } } diff --git a/src/tokio/server.rs b/src/tokio/server.rs index f2c9c49d..0e433915 100644 --- a/src/tokio/server.rs +++ b/src/tokio/server.rs @@ -126,6 +126,10 @@ impl ClientInfo for Framed> { .set_transaction_status(new_status); } + fn sni_server_name(&self) -> Option<&str> { + self.codec().client_info.sni_server_name() + } + #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> { if !self.is_secure() { @@ -540,7 +544,7 @@ where { if let Some(tls_acceptor) = tls_acceptor { // mention the use of ssl - let client_info = DefaultClient::new(addr, true); + let mut client_info = DefaultClient::new(addr, true); let ssl_socket = tokio::select! { _ = &mut startup_timeout => { @@ -557,6 +561,15 @@ where check_alpn_for_direct_ssl(&ssl_socket)?; } + // capture SNI (server name) from the underlying TLS connection + let sni = { + let (_, conn) = ssl_socket.get_ref(); + conn.server_name().map(|s| s.to_string()) + }; + if let Some(s) = sni { + client_info.sni_server_name = Some(s); + } + let mut socket = Framed::new( BufStream::new(ssl_socket), PgWireMessageServerCodec::new(client_info), @@ -584,3 +597,247 @@ where Ok(()) } } + +#[cfg(all(test, any(feature = "_ring", feature = "_aws-lc-rs")))] +mod tests { + use super::*; + use std::fs::File; + use std::io::{BufReader, Error as IOError}; + use std::sync::Arc; + use tokio::net::TcpListener; + use tokio::sync::oneshot; + use tokio_rustls::rustls; + use tokio_rustls::TlsAcceptor; + use tokio_rustls::TlsConnector; + + fn load_test_server_config() -> Result { + use rustls_pemfile::{certs, pkcs8_private_keys}; + use rustls_pki_types::{CertificateDer, PrivateKeyDer}; + + let certs = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?)) + .collect::, _>>()?; + let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?)) + .map(|key| key.map(PrivateKeyDer::from)) + .collect::, _>>()? + .remove(0); + + let mut cfg = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?; + // ALPN is optional for this test; SNI extraction doesn't depend on it. + cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()]; + Ok(cfg) + } + + fn make_test_client_connector() -> Result { + // For this unit test we are only validating SNI plumbing, not cert validation. + // Use a custom verifier that accepts any certificate. + #[derive(Debug)] + struct NoCertVerifier; + impl rustls::client::danger::ServerCertVerifier for NoCertVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result + { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result + { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + ] + } + } + + let mut cfg = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoCertVerifier)) + .with_no_client_auth(); + // Align ALPN to server to reduce negotiation variance + cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()]; + Ok(TlsConnector::from(Arc::new(cfg))) + } + + #[tokio::test] + #[ignore] + async fn server_name_metadata_is_set_from_tls_sni() { + use std::net::SocketAddr; + use tokio::io::duplex; + + // set up TLS server and client configs + let server_cfg = load_test_server_config().expect("server config"); + let acceptor = TlsAcceptor::from(Arc::new(server_cfg)); + let connector = make_test_client_connector().expect("client connector"); + + // in-memory full-duplex stream pair (use a larger buffer for TLS handshake) + let (server_io, client_io) = duplex(64 * 1024); + + let (tx, rx) = oneshot::channel::>(); + + // spawn server task to accept TLS over in-memory IO + tokio::spawn(async move { + let tls = acceptor.accept(server_io).await.unwrap(); + + // mimic production path: capture SNI and store on client_info + let sni = { + let (_, conn) = tls.get_ref(); + conn.server_name().map(|s| s.to_string()) + }; + let peer: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let mut ci: DefaultClient<()> = DefaultClient::new(peer, true); + if let Some(s) = sni { + ci.sni_server_name = Some(s); + } + let framed = Framed::new(BufStream::new(tls), PgWireMessageServerCodec::new(ci)); + let server_name = framed.sni_server_name().map(str::to_string); + let _ = tx.send(server_name); + }); + + // client side: connect with SNI=localhost over in-memory IO + let server_name = rustls_pki_types::ServerName::try_from("localhost").unwrap(); + let _ = connector.connect(server_name, client_io).await.unwrap(); + + // verify server observed SNI and stored as `server_name` + let observed = rx.await.expect("server_name from server"); + assert_eq!(observed.as_deref(), Some("localhost")); + } + + #[tokio::test] + async fn server_name_metadata_is_set_from_tls_sni_in_memory() { + use std::net::SocketAddr; + + // server and client rustls configs + let server_cfg = Arc::new(load_test_server_config().expect("server config")); + + // no-op verifier to focus on SNI plumbing + #[derive(Debug)] + struct NoCertVerifier; + impl rustls::client::danger::ServerCertVerifier for NoCertVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result + { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result + { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + ] + } + } + + let mut client_cfg = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoCertVerifier)) + .with_no_client_auth(); + client_cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()]; + let client_cfg = Arc::new(client_cfg); + + // build rustls connections directly and drive handshake in-memory + let mut server_conn = rustls::ServerConnection::new(server_cfg).unwrap(); + let mut client_conn = rustls::ClientConnection::new( + client_cfg, + rustls_pki_types::ServerName::try_from("localhost").unwrap(), + ) + .unwrap(); + + // in-memory pipes for TLS records + let mut c2s = Vec::new(); + let mut s2c = Vec::new(); + + // drive handshake until both sides complete + for _ in 0..1000 { + // client -> server + let _ = client_conn.write_tls(&mut c2s); + if !c2s.is_empty() { + let mut cur = std::io::Cursor::new(&c2s); + let _ = server_conn.read_tls(&mut cur); + c2s.clear(); + server_conn.process_new_packets().unwrap(); + } + + // server -> client + let _ = server_conn.write_tls(&mut s2c); + if !s2c.is_empty() { + let mut cur = std::io::Cursor::new(&s2c); + let _ = client_conn.read_tls(&mut cur); + s2c.clear(); + client_conn.process_new_packets().unwrap(); + } + + if !client_conn.is_handshaking() && !server_conn.is_handshaking() { + break; + } + } + + // capture SNI from server side and store on client info + let sni = server_conn.server_name().map(|s| s.to_string()); + let peer: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let mut ci: DefaultClient<()> = DefaultClient::new(peer, true); + if let Some(s) = sni { + ci.sni_server_name = Some(s); + } + assert_eq!(ci.sni_server_name(), Some("localhost")); + } +}