From bf22fcdbd0657b78c0e7f24cc755c8d231d74349 Mon Sep 17 00:00:00 2001 From: Eric Sheppard Date: Tue, 23 Aug 2022 21:20:45 +1000 Subject: [PATCH] allow multiple tls features fix imports fix vendored-openssl fmt --- src/client/config.rs | 48 ++++++++ src/client/connection.rs | 23 +++- src/client/tls_stream.rs | 121 ++++++++++++++++++--- src/client/tls_stream/rustls_tls_stream.rs | 2 +- 4 files changed, 176 insertions(+), 18 deletions(-) diff --git a/src/client/config.rs b/src/client/config.rs index 581676f4..a2d0feaa 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -31,6 +31,48 @@ pub struct Config { pub(crate) encryption: EncryptionLevel, pub(crate) trust: TrustConfig, pub(crate) auth: AuthMethod, + pub(crate) tls_choice: TlsChoice, +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum TlsChoice { + #[cfg(not(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + )))] + None, + #[cfg(feature = "rustls")] + Rustls, + #[cfg(feature = "native-tls")] + NativeTls, + #[cfg(feature = "vendored-openssl")] + Openssl, +} + +impl Default for TlsChoice { + #[allow(unreachable_code, clippy::needless_return)] + fn default() -> TlsChoice { + #[cfg(feature = "rustls")] + { + return TlsChoice::Rustls; + } + #[cfg(feature = "native-tls")] + { + return TlsChoice::NativeTls; + } + #[cfg(feature = "vendored-openssl")] + { + return TlsChoice::Openssl; + } + + #[cfg(not(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + )))] + TlsChoice::None + } } #[derive(Clone, Debug)] @@ -62,6 +104,7 @@ impl Default for Config { encryption: EncryptionLevel::NotSupported, trust: TrustConfig::Default, auth: AuthMethod::None, + tls_choice: TlsChoice::default(), } } } @@ -120,6 +163,11 @@ impl Config { self.encryption = encryption; } + /// Set the choice of Tls + pub fn tls_choice(&mut self, tls_choice: TlsChoice) { + self.tls_choice = tls_choice; + } + /// If set, the server certificate will not be validated and it is accepted /// as-is. /// diff --git a/src/client/connection.rs b/src/client/connection.rs index ac4e14b3..1b3061a0 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -3,7 +3,7 @@ feature = "native-tls", feature = "vendored-openssl" ))] -use crate::client::{tls::TlsPreloginWrapper, tls_stream::create_tls_stream}; +use crate::client::{config::TlsChoice, tls::TlsPreloginWrapper, tls_stream}; use crate::{ client::{tls::MaybeTlsStream, AuthMethod, Config}, tds::{ @@ -442,10 +442,25 @@ impl Connection { let Self { transport, context, .. } = self; - let mut stream = match transport.into_inner() { - MaybeTlsStream::Raw(tcp) => { - create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await? + + let mut stream = match (transport.into_inner(), config.tls_choice) { + #[cfg(feature = "rustls")] + (MaybeTlsStream::Raw(tcp), TlsChoice::Rustls) => { + tls_stream::create_tls_stream_rustls(config, TlsPreloginWrapper::new(tcp)) + .await? + } + #[cfg(feature = "vendored-openssl")] + (MaybeTlsStream::Raw(tcp), TlsChoice::Openssl) => { + tls_stream::create_tls_stream_openssl(config, TlsPreloginWrapper::new(tcp)) + .await? + } + #[cfg(feature = "native-tls")] + (MaybeTlsStream::Raw(tcp), TlsChoice::NativeTls) => { + tls_stream::create_tls_stream_native_tls(config, TlsPreloginWrapper::new(tcp)) + .await? } + // this should still be fine as the relevant TlsChoices are only + // enabled when the equivalent tls crate is enabled _ => unreachable!(), }; diff --git a/src/client/tls_stream.rs b/src/client/tls_stream.rs index 9e363ed5..ffc96413 100644 --- a/src/client/tls_stream.rs +++ b/src/client/tls_stream.rs @@ -1,6 +1,10 @@ use crate::Config; use futures::{AsyncRead, AsyncWrite}; - +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; #[cfg(feature = "native-tls")] mod native_tls_stream; @@ -10,35 +14,126 @@ mod rustls_tls_stream; #[cfg(feature = "vendored-openssl")] mod opentls_tls_stream; -#[cfg(feature = "native-tls")] -pub(crate) use native_tls_stream::TlsStream; +// #[cfg(feature = "native-tls")] +// pub(crate) use native_tls_stream::TlsStream as NativeTlsStream; -#[cfg(feature = "rustls")] -pub(crate) use rustls_tls_stream::TlsStream; +// #[cfg(feature = "rustls")] +// pub(crate) use rustls_tls_stream::TlsStream as RustlsTlsStream; -#[cfg(feature = "vendored-openssl")] -pub(crate) use opentls_tls_stream::TlsStream; +// #[cfg(feature = "vendored-openssl")] +// pub(crate) use opentls_tls_stream::TlsStream as OptenSslTlsStream; + +pub(crate) enum TlsStream { + #[cfg(feature = "vendored-openssl")] + Openssl(opentls_tls_stream::TlsStream), + #[cfg(feature = "rustls")] + Rustls(rustls_tls_stream::TlsStream), + #[cfg(feature = "native-tls")] + NativeTls(native_tls_stream::TlsStream), +} + +impl TlsStream +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + pub(crate) fn get_mut(&mut self) -> &mut S { + match self { + #[cfg(feature = "vendored-openssl")] + TlsStream::Openssl(s) => s.get_mut(), + #[cfg(feature = "rustls")] + TlsStream::Rustls(s) => s.get_mut(), + #[cfg(feature = "native-tls")] + TlsStream::NativeTls(s) => s.get_mut(), + } + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let inner = Pin::get_mut(self); + match inner { + #[cfg(feature = "vendored-openssl")] + TlsStream::Openssl(s) => Pin::new(s).poll_read(cx, buf), + #[cfg(feature = "rustls")] + TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_read(cx, buf), + #[cfg(feature = "native-tls")] + TlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let inner = Pin::get_mut(self); + match inner { + #[cfg(feature = "vendored-openssl")] + TlsStream::Openssl(s) => Pin::new(s).poll_write(cx, buf), + #[cfg(feature = "rustls")] + TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_write(cx, buf), + #[cfg(feature = "native-tls")] + TlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let inner = Pin::get_mut(self); + match inner { + #[cfg(feature = "vendored-openssl")] + TlsStream::Openssl(s) => Pin::new(s).poll_flush(cx), + #[cfg(feature = "rustls")] + TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_flush(cx), + #[cfg(feature = "native-tls")] + TlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let inner = Pin::get_mut(self); + match inner { + #[cfg(feature = "vendored-openssl")] + TlsStream::Openssl(s) => Pin::new(s).poll_close(cx), + #[cfg(feature = "rustls")] + TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_close(cx), + #[cfg(feature = "native-tls")] + TlsStream::NativeTls(s) => Pin::new(s).poll_close(cx), + } + } +} #[cfg(feature = "rustls")] -pub(crate) async fn create_tls_stream( +pub(crate) async fn create_tls_stream_rustls( config: &Config, stream: S, ) -> crate::Result> { - TlsStream::new(config, stream).await + rustls_tls_stream::TlsStream::new(config, stream) + .await + .map(TlsStream::Rustls) } #[cfg(feature = "native-tls")] -pub(crate) async fn create_tls_stream( +pub(crate) async fn create_tls_stream_native_tls( config: &Config, stream: S, ) -> crate::Result> { - native_tls_stream::create_tls_stream(config, stream).await + native_tls_stream::create_tls_stream(config, stream) + .await + .map(TlsStream::NativeTls) } #[cfg(feature = "vendored-openssl")] -pub(crate) async fn create_tls_stream( +pub(crate) async fn create_tls_stream_openssl( config: &Config, stream: S, ) -> crate::Result> { - opentls_tls_stream::create_tls_stream(config, stream).await + opentls_tls_stream::create_tls_stream(config, stream) + .await + .map(TlsStream::Openssl) } diff --git a/src/client/tls_stream/rustls_tls_stream.rs b/src/client/tls_stream/rustls_tls_stream.rs index b537a2fa..3322cee1 100644 --- a/src/client/tls_stream/rustls_tls_stream.rs +++ b/src/client/tls_stream/rustls_tls_stream.rs @@ -33,7 +33,7 @@ impl From for Error { } pub(crate) struct TlsStream( - Compat>>, + pub(super) Compat>>, ); struct NoCertVerifier;