Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub trait ClientInfo {

fn metadata_mut(&mut self) -> &mut HashMap<String, String>;

fn sni_server_name(&self) -> Option<&str>;

#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]>;
}
Expand All @@ -89,6 +91,7 @@ pub struct DefaultClient<S> {
pub state: PgWireConnectionState,
pub transaction_status: TransactionStatus,
pub metadata: HashMap<String, String>,
pub sni_server_name: Option<String>,
pub portal_store: store::MemPortalStore<S>,
}

Expand Down Expand Up @@ -141,6 +144,10 @@ impl<S> ClientInfo for DefaultClient<S> {
self.transaction_status = new_status
}

fn sni_server_name(&self) -> Option<&str> {
self.sni_server_name.as_deref()
}

#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
None
Expand All @@ -157,6 +164,7 @@ impl<S> DefaultClient<S> {
state: PgWireConnectionState::default(),
transaction_status: TransactionStatus::Idle,
metadata: HashMap::new(),
sni_server_name: None,
portal_store: store::MemPortalStore::new(),
}
}
Expand Down
259 changes: 258 additions & 1 deletion src/tokio/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ impl<T: 'static, S> ClientInfo for Framed<T, PgWireMessageServerCodec<S>> {
.set_transaction_status(new_status);
}

fn sni_server_name(&self) -> Option<&str> {
self.codec().client_info.sni_server_name()
}

#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn client_certificates<'a>(&self) -> Option<&[CertificateDer<'a>]> {
if !self.is_secure() {
Expand Down Expand Up @@ -540,7 +544,7 @@ where
{
if let Some(tls_acceptor) = tls_acceptor {
// mention the use of ssl
let client_info = DefaultClient::new(addr, true);
let mut client_info = DefaultClient::new(addr, true);

let ssl_socket = tokio::select! {
_ = &mut startup_timeout => {
Expand All @@ -557,6 +561,15 @@ where
check_alpn_for_direct_ssl(&ssl_socket)?;
}

// capture SNI (server name) from the underlying TLS connection
let sni = {
let (_, conn) = ssl_socket.get_ref();
conn.server_name().map(|s| s.to_string())
};
if let Some(s) = sni {
client_info.sni_server_name = Some(s);
}
Comment thread
sunng87 marked this conversation as resolved.

let mut socket = Framed::new(
BufStream::new(ssl_socket),
PgWireMessageServerCodec::new(client_info),
Expand Down Expand Up @@ -584,3 +597,247 @@ where
Ok(())
}
}

#[cfg(all(test, any(feature = "_ring", feature = "_aws-lc-rs")))]
mod tests {
use super::*;
use std::fs::File;
use std::io::{BufReader, Error as IOError};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_rustls::rustls;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::TlsConnector;

fn load_test_server_config() -> Result<rustls::ServerConfig, IOError> {
use rustls_pemfile::{certs, pkcs8_private_keys};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};

let certs = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?))
.collect::<Result<Vec<CertificateDer>, _>>()?;
let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?))
.map(|key| key.map(PrivateKeyDer::from))
.collect::<Result<Vec<PrivateKeyDer>, _>>()?
.remove(0);

let mut cfg = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
// ALPN is optional for this test; SNI extraction doesn't depend on it.
cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
Ok(cfg)
}

fn make_test_client_connector() -> Result<TlsConnector, IOError> {
// For this unit test we are only validating SNI plumbing, not cert validation.
// Use a custom verifier that accepts any certificate.
#[derive(Debug)]
struct NoCertVerifier;
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}

let mut cfg = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertVerifier))
.with_no_client_auth();
// Align ALPN to server to reduce negotiation variance
cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
Ok(TlsConnector::from(Arc::new(cfg)))
}

#[tokio::test]
#[ignore]
async fn server_name_metadata_is_set_from_tls_sni() {
use std::net::SocketAddr;
use tokio::io::duplex;

// set up TLS server and client configs
let server_cfg = load_test_server_config().expect("server config");
let acceptor = TlsAcceptor::from(Arc::new(server_cfg));
let connector = make_test_client_connector().expect("client connector");

// in-memory full-duplex stream pair (use a larger buffer for TLS handshake)
let (server_io, client_io) = duplex(64 * 1024);

let (tx, rx) = oneshot::channel::<Option<String>>();

// spawn server task to accept TLS over in-memory IO
tokio::spawn(async move {
let tls = acceptor.accept(server_io).await.unwrap();

// mimic production path: capture SNI and store on client_info
let sni = {
let (_, conn) = tls.get_ref();
conn.server_name().map(|s| s.to_string())
};
let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
if let Some(s) = sni {
ci.sni_server_name = Some(s);
}
let framed = Framed::new(BufStream::new(tls), PgWireMessageServerCodec::new(ci));
let server_name = framed.sni_server_name().map(str::to_string);
let _ = tx.send(server_name);
});

// client side: connect with SNI=localhost over in-memory IO
let server_name = rustls_pki_types::ServerName::try_from("localhost").unwrap();
let _ = connector.connect(server_name, client_io).await.unwrap();

// verify server observed SNI and stored as `server_name`
let observed = rx.await.expect("server_name from server");
assert_eq!(observed.as_deref(), Some("localhost"));
}

#[tokio::test]
async fn server_name_metadata_is_set_from_tls_sni_in_memory() {
use std::net::SocketAddr;

// server and client rustls configs
let server_cfg = Arc::new(load_test_server_config().expect("server config"));

// no-op verifier to focus on SNI plumbing
#[derive(Debug)]
struct NoCertVerifier;
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}

let mut client_cfg = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertVerifier))
.with_no_client_auth();
client_cfg.alpn_protocols = vec![crate::tokio::POSTGRESQL_ALPN_NAME.to_vec()];
let client_cfg = Arc::new(client_cfg);

// build rustls connections directly and drive handshake in-memory
let mut server_conn = rustls::ServerConnection::new(server_cfg).unwrap();
let mut client_conn = rustls::ClientConnection::new(
client_cfg,
rustls_pki_types::ServerName::try_from("localhost").unwrap(),
)
.unwrap();

// in-memory pipes for TLS records
let mut c2s = Vec::new();
let mut s2c = Vec::new();

// drive handshake until both sides complete
for _ in 0..1000 {
// client -> server
let _ = client_conn.write_tls(&mut c2s);
if !c2s.is_empty() {
let mut cur = std::io::Cursor::new(&c2s);
let _ = server_conn.read_tls(&mut cur);
c2s.clear();
server_conn.process_new_packets().unwrap();
}

// server -> client
let _ = server_conn.write_tls(&mut s2c);
if !s2c.is_empty() {
let mut cur = std::io::Cursor::new(&s2c);
let _ = client_conn.read_tls(&mut cur);
s2c.clear();
client_conn.process_new_packets().unwrap();
}

if !client_conn.is_handshaking() && !server_conn.is_handshaking() {
break;
}
}

// capture SNI from server side and store on client info
let sni = server_conn.server_name().map(|s| s.to_string());
let peer: SocketAddr = "127.0.0.1:0".parse().unwrap();
let mut ci: DefaultClient<()> = DefaultClient::new(peer, true);
if let Some(s) = sni {
ci.sni_server_name = Some(s);
}
assert_eq!(ci.sni_server_name(), Some("localhost"));
}
}
Loading