diff --git a/src/client.rs b/src/client.rs index 688721d1..521847af 100644 --- a/src/client.rs +++ b/src/client.rs @@ -68,6 +68,9 @@ impl Client { /// options required to connect to the database using an established /// tcp connection /// + /// Note: `tcp_stream` is a connected stream, so some parts of the `Config` + /// must be handled outside of this constructor. + /// /// [`Config`]: struct.Config.html pub async fn connect(config: Config, tcp_stream: S) -> crate::Result> { Ok(Client { diff --git a/src/client/config.rs b/src/client/config.rs index fff68bc1..dd48e478 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -32,6 +32,7 @@ pub struct Config { pub(crate) trust: TrustConfig, pub(crate) auth: AuthMethod, pub(crate) readonly: bool, + pub(crate) multi_subnet_failover: bool, } #[derive(Clone, Debug)] @@ -65,6 +66,7 @@ impl Default for Config { trust: TrustConfig::Default, auth: AuthMethod::None, readonly: false, + multi_subnet_failover: false, } } } @@ -171,6 +173,17 @@ impl Config { self.readonly = readnoly; } + /// Sets multiSubnetFailover flag. + /// + /// - Defaults to `false`. + pub fn multi_subnet_failover(&mut self, multi_subnet_failover: bool) { + self.multi_subnet_failover = multi_subnet_failover; + } + /// Gets multiSubnetFailover flag. + pub fn get_multi_subnet_failover(&self) -> bool { + self.multi_subnet_failover + } + pub(crate) fn get_host(&self) -> &str { self.host .as_deref() @@ -269,6 +282,8 @@ impl Config { builder.readonly(s.readonly()); + builder.multi_subnet_failover(s.multi_subnet_failover()?); + Ok(builder) } } @@ -388,4 +403,11 @@ pub(crate) trait ConfigString { .filter(|val| *val == "ReadOnly") .is_some() } + + fn multi_subnet_failover(&self) -> crate::Result { + self.dict() + .get("multisubnetfailover") + .map(Self::parse_bool) + .unwrap_or(Ok(false)) + } } diff --git a/src/client/connection.rs b/src/client/connection.rs index c6ce1d66..b8ef8667 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -72,6 +72,9 @@ impl Debug for Connection { impl Connection { /// Creates a new connection + /// + /// Note: `tcp_stream` is a connected stream, so some parts of the + /// [`Config`] needs to be handled outside of this method. pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result> { let context = { let mut context = Context::new(); diff --git a/src/sql_browser/async_std.rs b/src/sql_browser/async_std.rs index 14f55de5..74e10d12 100644 --- a/src/sql_browser/async_std.rs +++ b/src/sql_browser/async_std.rs @@ -1,72 +1,98 @@ use super::SqlBrowser; +use std::net::SocketAddr; use async_std::{ io, net::{self, ToSocketAddrs}, }; use async_trait::async_trait; use futures_util::future::TryFutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use std::time; use tracing::Level; -#[async_trait] -impl SqlBrowser for net::TcpStream { - /// This method can be used to connect to SQL Server named instances - /// when on a Windows platform with the `sql-browser-async-std` feature - /// enabled. Please see the crate examples for more detailed examples. - async fn connect_named(builder: &crate::client::Config) -> crate::Result { - let addrs = builder.get_addr().to_socket_addrs().await?; +async fn connect_addr(builder: &crate::client::Config, mut addr: SocketAddr) -> crate::Result { + if let Some(ref instance_name) = builder.instance_name { + // First resolve the instance to a port via the + // SSRP protocol/MS-SQLR protocol [1] + // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - for mut addr in addrs { - if let Some(ref instance_name) = builder.instance_name { - // First resolve the instance to a port via the - // SSRP protocol/MS-SQLR protocol [1] - // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx + let local_bind: std::net::SocketAddr = if addr.is_ipv4() { + "0.0.0.0:0".parse().unwrap() + } else { + "[::]:0".parse().unwrap() + }; - let local_bind: std::net::SocketAddr = if addr.is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() - }; + tracing::event!( + Level::TRACE, + "Connecting to instance `{}` using SQL Browser in port `{}`", + instance_name, + builder.get_port() + ); - tracing::event!( - Level::TRACE, - "Connecting to instance `{}` using SQL Browser in port `{}`", - instance_name, - builder.get_port() - ); + let msg = [&[4u8], instance_name.as_bytes()].concat(); + let mut buf = vec![0u8; 4096]; - let msg = [&[4u8], instance_name.as_bytes()].concat(); - let mut buf = vec![0u8; 4096]; + let socket = net::UdpSocket::bind(&local_bind).await?; + socket.send_to(&msg, &addr).await?; - let socket = net::UdpSocket::bind(&local_bind).await?; - socket.send_to(&msg, &addr).await?; + let timeout = time::Duration::from_millis(1000); - let timeout = time::Duration::from_millis(1000); + let len = io::timeout(timeout, socket.recv(&mut buf)) + .map_err(|_| { + crate::error::Error::Conversion( + format!( + "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", + instance_name, + builder.get_port(), + ) + .into(), + ) + }) + .await?; - let len = io::timeout(timeout, socket.recv(&mut buf)) - .map_err(|_| { - crate::error::Error::Conversion( - format!( - "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", - instance_name, - builder.get_port(), - ) - .into(), - ) - }) - .await?; + let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; + tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); + addr.set_port(port); + }; - let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; - tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); - addr.set_port(port); - }; + if let Ok(stream) = net::TcpStream::connect(addr).await { + stream.set_nodelay(true)?; + return Ok(stream); + } else { + Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + } +} - if let Ok(stream) = net::TcpStream::connect(addr).await { - stream.set_nodelay(true)?; - return Ok(stream); +#[async_trait] +impl SqlBrowser for net::TcpStream { + /// This method can be used to connect to SQL Server named instances + /// when on a Windows platform with the `sql-browser-async-std` feature + /// enabled. Please see the crate examples for more detailed examples. + async fn connect_named(builder: &crate::client::Config) -> crate::Result { + let addrs = builder.get_addr().to_socket_addrs().await?; + let mut first_error = None; + + if builder.multi_subnet_failover { + let mut futures = addrs + .map(|addr| connect_addr(builder, addr)) + .collect::>(); + while let Some(connection) = futures.next().await { + match connection { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; + } + } else { + for addr in addrs { + match connect_addr(builder, addr).await { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; } } - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + // If we end up here, there was no successfull connection. + Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())) } } diff --git a/src/sql_browser/smol.rs b/src/sql_browser/smol.rs index 252b834e..c19403cc 100644 --- a/src/sql_browser/smol.rs +++ b/src/sql_browser/smol.rs @@ -1,10 +1,12 @@ use super::SqlBrowser; use crate::client::Config; use async_io::Timer; -use async_net::{resolve, TcpStream, UdpSocket}; +use async_net::{resolve, TcpStream, UdpSocket, SocketAddr}; use async_trait::async_trait; use futures_lite::FutureExt; use futures_util::future::TryFutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use std::io; use std::time::Duration; use tracing::Level; @@ -16,64 +18,85 @@ impl SqlBrowser for TcpStream { /// enabled. Please see the crate examples for more detailed examples. async fn connect_named(builder: &Config) -> crate::Result { let addrs = resolve(builder.get_addr()).await?; + let mut first_error = None; - for mut addr in addrs { - if let Some(ref instance_name) = builder.instance_name { - // First resolve the instance to a port via the - // SSRP protocol/MS-SQLR protocol [1] - // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - - let local_bind: std::net::SocketAddr = if addr.is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() + if builder.multi_subnet_failover { + let mut futures = addrs + .into_iter() + .map(|addr| connect_addr(builder, addr)) + .collect::>(); + while let Some(connection) = futures.next().await { + match connection { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; + } + } else { + for addr in addrs { + match connect_addr(builder, addr).await { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), }; + } + } + + // If we end up here, there was no successfull connection. + Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())) + } +} - tracing::event!( - Level::TRACE, - "Connecting to instance `{}` using SQL Browser in port `{}`", - instance_name, - builder.get_port() - ); +async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result { + if let Some(ref instance_name) = builder.instance_name { + // First resolve the instance to a port via the + // SSRP protocol/MS-SQLR protocol [1] + // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - let msg = [&[4u8], instance_name.as_bytes()].concat(); - let mut buf = vec![0u8; 4096]; + let local_bind: std::net::SocketAddr = if addr.is_ipv4() { + "0.0.0.0:0".parse().unwrap() + } else { + "[::]:0".parse().unwrap() + }; - let socket = UdpSocket::bind(&local_bind).await?; - socket.send_to(&msg, &addr).await?; + tracing::event!( + Level::TRACE, + "Connecting to instance `{}` using SQL Browser in port `{}`", + instance_name, + builder.get_port() + ); - let timeout = Duration::from_millis(1000); + let msg = [&[4u8], instance_name.as_bytes()].concat(); + let mut buf = vec![0u8; 4096]; - let len = socket.recv(&mut buf).or(async { - Timer::after(timeout).await; - Err(std::io::ErrorKind::TimedOut.into()) - }) - .map_err(|e| { - if e.kind() == std::io::ErrorKind::TimedOut { - crate::error::Error::Conversion( - format!( - "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", - instance_name, - builder.get_port(), - ) - .into(), - ) - } else { - e.into() - } - }).await?; + let socket = UdpSocket::bind(&local_bind).await?; + socket.send_to(&msg, &addr).await?; - let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; - tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); - addr.set_port(port); - }; + let timeout = Duration::from_millis(1000); - if let Ok(stream) = TcpStream::connect(addr).await { - stream.set_nodelay(true)?; - return Ok(stream); - } - } + let len = socket.recv(&mut buf).or(async { + Timer::after(timeout).await; + Err(std::io::ErrorKind::TimedOut.into()) + }) + .map_err(|e| { + if e.kind() == std::io::ErrorKind::TimedOut { + crate::error::Error::Conversion( + format!( + "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", + instance_name, + builder.get_port(), + ) + .into(), + ) + } else { + e.into() + } + }).await?; - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) - } + let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; + tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); + addr.set_port(port); + }; + + let stream = TcpStream::connect(addr).await?; + stream.set_nodelay(true)?; + Ok(stream) } diff --git a/src/sql_browser/tokio.rs b/src/sql_browser/tokio.rs index 1fbf6e0e..9bc61152 100644 --- a/src/sql_browser/tokio.rs +++ b/src/sql_browser/tokio.rs @@ -2,8 +2,10 @@ use super::SqlBrowser; use crate::client::Config; use async_trait::async_trait; use futures_util::future::TryFutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use net::{TcpStream, UdpSocket}; -use std::io; +use std::{io, net::SocketAddr}; use tokio::{ net, time::{self, error::Elapsed, Duration}, @@ -17,58 +19,78 @@ impl SqlBrowser for TcpStream { /// enabled. Please see the crate examples for more detailed examples. async fn connect_named(builder: &Config) -> crate::Result { let addrs = net::lookup_host(builder.get_addr()).await?; + let mut first_error = None; - for mut addr in addrs { - if let Some(ref instance_name) = builder.instance_name { - // First resolve the instance to a port via the - // SSRP protocol/MS-SQLR protocol [1] - // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - - let local_bind: std::net::SocketAddr = if addr.is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() + if builder.multi_subnet_failover { + let mut futures = addrs + .map(|addr| connect_addr(builder, addr)) + .collect::>(); + while let Some(connection) = futures.next().await { + match connection { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; + } + } else { + for addr in addrs { + match connect_addr(builder, addr).await { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), }; + } + } + + // If we end up here, there was no successfull connection. + Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())) + } +} - tracing::event!( - Level::TRACE, - "Connecting to instance `{}` using SQL Browser in port `{}`", - instance_name, - builder.get_port() - ); +async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result { + if let Some(ref instance_name) = builder.instance_name { + // First resolve the instance to a port via the + // SSRP protocol/MS-SQLR protocol [1] + // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - let msg = [&[4u8], instance_name.as_bytes()].concat(); - let mut buf = vec![0u8; 4096]; + let local_bind: std::net::SocketAddr = if addr.is_ipv4() { + "0.0.0.0:0".parse().unwrap() + } else { + "[::]:0".parse().unwrap() + }; - let socket = UdpSocket::bind(&local_bind).await?; - socket.send_to(&msg, &addr).await?; + tracing::event!( + Level::TRACE, + "Connecting to instance `{}` using SQL Browser in port `{}`", + instance_name, + builder.get_port() + ); - let timeout = Duration::from_millis(1000); + let msg = [&[4u8], instance_name.as_bytes()].concat(); + let mut buf = vec![0u8; 4096]; - let len = time::timeout(timeout, socket.recv(&mut buf)) - .map_err(|_: Elapsed| { - crate::error::Error::Conversion( - format!( - "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", - instance_name, - builder.get_port(), - ) - .into(), - ) - }) - .await??; + let socket = UdpSocket::bind(&local_bind).await?; + socket.send_to(&msg, &addr).await?; - let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; - tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); - addr.set_port(port); - }; + let timeout = Duration::from_millis(1000); - if let Ok(stream) = TcpStream::connect(addr).await { - stream.set_nodelay(true)?; - return Ok(stream); - } - } + let len = time::timeout(timeout, socket.recv(&mut buf)) + .map_err(|_: Elapsed| { + crate::error::Error::Conversion( + format!( + "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", + instance_name, + builder.get_port(), + ) + .into(), + ) + }) + .await??; - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) - } + let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; + tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); + addr.set_port(port); + }; + + let stream = TcpStream::connect(addr).await?; + stream.set_nodelay(true)?; + Ok(stream) }