From db081597aa8c3299dc36b6ba82bae6d73d6fcf1a Mon Sep 17 00:00:00 2001 From: Tyler Clendenin Date: Mon, 30 Sep 2024 11:26:17 -0400 Subject: [PATCH 1/4] Add support for MultiSubnetFailover connection property When the MultiSubnetFailover=Yes property is added to the connection string, the TCP connection should be attempted for each resolved IP address in parallel rather than in sequence. This creates a race where the first connection to be established wins and becomes the target server. https://learn.microsoft.com/en-us/sql/relational-databases/native-client/features/sql-server-native-client-support-for-high-availability-disaster-recovery?view=sql-server-ver15#connecting-with-multisubnetfailover --- Cargo.toml | 4 ++ src/client/config.rs | 18 ++++++ src/sql_browser/async_std.rs | 107 ++++++++++++++++++++--------------- 3 files changed, 82 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9d8ccf95..98c95874 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,10 @@ default-features = false version = "0.3" optional = true +[dependencies.futures] +version = "0.3" +default-features = false + [dependencies.futures-util] version = "0.3" default-features = false diff --git a/src/client/config.rs b/src/client/config.rs index fff68bc1..595898a0 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -32,6 +32,7 @@ pub struct Config { pub(crate) trust: TrustConfig, pub(crate) auth: AuthMethod, pub(crate) readonly: bool, + pub(crate) multi_subnet_failover: bool, } #[derive(Clone, Debug)] @@ -65,6 +66,7 @@ impl Default for Config { trust: TrustConfig::Default, auth: AuthMethod::None, readonly: false, + multi_subnet_failover: false, } } } @@ -171,6 +173,13 @@ impl Config { self.readonly = readnoly; } + /// Sets multiSubnetFailover flag. + /// + /// - Defaults to `false`. + pub fn multi_subnet_failover(&mut self, multi_subnet_failover: bool) { + self.multi_subnet_failover = multi_subnet_failover; + } + pub(crate) fn get_host(&self) -> &str { self.host .as_deref() @@ -269,6 +278,8 @@ impl Config { builder.readonly(s.readonly()); + builder.multi_subnet_failover(s.multi_subnet_failover()?); + Ok(builder) } } @@ -388,4 +399,11 @@ pub(crate) trait ConfigString { .filter(|val| *val == "ReadOnly") .is_some() } + + fn multi_subnet_failover(&self) -> crate::Result { + self.dict() + .get("multisubnetfailover") + .map(Self::parse_bool) + .unwrap_or(Ok(false)) + } } diff --git a/src/sql_browser/async_std.rs b/src/sql_browser/async_std.rs index 14f55de5..5236c924 100644 --- a/src/sql_browser/async_std.rs +++ b/src/sql_browser/async_std.rs @@ -1,69 +1,82 @@ use super::SqlBrowser; +use std::net::SocketAddr; use async_std::{ io, net::{self, ToSocketAddrs}, }; use async_trait::async_trait; +use futures::{future::select_all, FutureExt}; use futures_util::future::TryFutureExt; use std::time; use tracing::Level; -#[async_trait] -impl SqlBrowser for net::TcpStream { - /// This method can be used to connect to SQL Server named instances - /// when on a Windows platform with the `sql-browser-async-std` feature - /// enabled. Please see the crate examples for more detailed examples. - async fn connect_named(builder: &crate::client::Config) -> crate::Result { - let addrs = builder.get_addr().to_socket_addrs().await?; +async fn connect_addr(builder: &crate::client::Config, mut addr: SocketAddr) -> crate::Result { + if let Some(ref instance_name) = builder.instance_name { + // First resolve the instance to a port via the + // SSRP protocol/MS-SQLR protocol [1] + // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - for mut addr in addrs { - if let Some(ref instance_name) = builder.instance_name { - // First resolve the instance to a port via the - // SSRP protocol/MS-SQLR protocol [1] - // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx + let local_bind: std::net::SocketAddr = if addr.is_ipv4() { + "0.0.0.0:0".parse().unwrap() + } else { + "[::]:0".parse().unwrap() + }; - let local_bind: std::net::SocketAddr = if addr.is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() - }; + tracing::event!( + Level::TRACE, + "Connecting to instance `{}` using SQL Browser in port `{}`", + instance_name, + builder.get_port() + ); - tracing::event!( - Level::TRACE, - "Connecting to instance `{}` using SQL Browser in port `{}`", - instance_name, - builder.get_port() - ); + let msg = [&[4u8], instance_name.as_bytes()].concat(); + let mut buf = vec![0u8; 4096]; - let msg = [&[4u8], instance_name.as_bytes()].concat(); - let mut buf = vec![0u8; 4096]; + let socket = net::UdpSocket::bind(&local_bind).await?; + socket.send_to(&msg, &addr).await?; - let socket = net::UdpSocket::bind(&local_bind).await?; - socket.send_to(&msg, &addr).await?; + let timeout = time::Duration::from_millis(1000); - let timeout = time::Duration::from_millis(1000); + let len = io::timeout(timeout, socket.recv(&mut buf)) + .map_err(|_| { + crate::error::Error::Conversion( + format!( + "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", + instance_name, + builder.get_port(), + ) + .into(), + ) + }) + .await?; - let len = io::timeout(timeout, socket.recv(&mut buf)) - .map_err(|_| { - crate::error::Error::Conversion( - format!( - "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", - instance_name, - builder.get_port(), - ) - .into(), - ) - }) - .await?; + let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; + tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); + addr.set_port(port); + }; + + if let Ok(stream) = net::TcpStream::connect(addr).await { + stream.set_nodelay(true)?; + return Ok(stream); + } else { + Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + } +} - let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; - tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); - addr.set_port(port); - }; +#[async_trait] +impl SqlBrowser for net::TcpStream { + /// This method can be used to connect to SQL Server named instances + /// when on a Windows platform with the `sql-browser-async-std` feature + /// enabled. Please see the crate examples for more detailed examples. + async fn connect_named(builder: &crate::client::Config) -> crate::Result { + let addrs = builder.get_addr().to_socket_addrs().await?; - if let Ok(stream) = net::TcpStream::connect(addr).await { - stream.set_nodelay(true)?; - return Ok(stream); + if builder.multi_subnet_failover { + let futures = addrs.map(|addr| connect_addr(builder, addr).boxed()); + select_all(futures).await; + } else { + for mut addr in addrs { + connect_addr(builder, addr).await?; } } From bf276876d3c04cef5ba00f3adfdf84a6d52375a9 Mon Sep 17 00:00:00 2001 From: Rasmus Kaj Date: Mon, 24 Mar 2025 11:21:50 +0100 Subject: [PATCH 2/4] Implement multi_subnet_failover for tokio. --- src/sql_browser/tokio.rs | 119 ++++++++++++++++++++++++--------------- 1 file changed, 75 insertions(+), 44 deletions(-) diff --git a/src/sql_browser/tokio.rs b/src/sql_browser/tokio.rs index 1fbf6e0e..c7c52738 100644 --- a/src/sql_browser/tokio.rs +++ b/src/sql_browser/tokio.rs @@ -2,8 +2,10 @@ use super::SqlBrowser; use crate::client::Config; use async_trait::async_trait; use futures_util::future::TryFutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use net::{TcpStream, UdpSocket}; -use std::io; +use std::{io, net::SocketAddr}; use tokio::{ net, time::{self, error::Elapsed, Duration}, @@ -18,57 +20,86 @@ impl SqlBrowser for TcpStream { async fn connect_named(builder: &Config) -> crate::Result { let addrs = net::lookup_host(builder.get_addr()).await?; - for mut addr in addrs { - if let Some(ref instance_name) = builder.instance_name { - // First resolve the instance to a port via the - // SSRP protocol/MS-SQLR protocol [1] - // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx + if builder.multi_subnet_failover { + let mut futures = addrs + .map(|addr| connect_addr(builder, addr)) + .collect::>(); + let mut first_error = None; + for f in futures.next().await { + match f { + Ok(Some(result)) => return Ok(result), + Ok(None) => break, + Err(error) => { + if first_error.is_none() { + first_error = Some(error); + } + } + } + } + if let Some(error) = first_error { + return Err(error); + } + } else { + for addr in addrs { + if let Some(stream) = connect_addr(builder, addr).await? { + return Ok(stream); + } + } + } - let local_bind: std::net::SocketAddr = if addr.is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() - }; + Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + } +} - tracing::event!( - Level::TRACE, - "Connecting to instance `{}` using SQL Browser in port `{}`", - instance_name, - builder.get_port() - ); +async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result> { + if let Some(ref instance_name) = builder.instance_name { + // First resolve the instance to a port via the + // SSRP protocol/MS-SQLR protocol [1] + // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - let msg = [&[4u8], instance_name.as_bytes()].concat(); - let mut buf = vec![0u8; 4096]; + let local_bind: std::net::SocketAddr = if addr.is_ipv4() { + "0.0.0.0:0".parse().unwrap() + } else { + "[::]:0".parse().unwrap() + }; - let socket = UdpSocket::bind(&local_bind).await?; - socket.send_to(&msg, &addr).await?; + tracing::event!( + Level::TRACE, + "Connecting to instance `{}` using SQL Browser in port `{}`", + instance_name, + builder.get_port() + ); - let timeout = Duration::from_millis(1000); + let msg = [&[4u8], instance_name.as_bytes()].concat(); + let mut buf = vec![0u8; 4096]; - let len = time::timeout(timeout, socket.recv(&mut buf)) - .map_err(|_: Elapsed| { - crate::error::Error::Conversion( - format!( - "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", - instance_name, - builder.get_port(), - ) - .into(), - ) - }) - .await??; + let socket = UdpSocket::bind(&local_bind).await?; + socket.send_to(&msg, &addr).await?; - let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; - tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); - addr.set_port(port); - }; + let timeout = Duration::from_millis(1000); - if let Ok(stream) = TcpStream::connect(addr).await { - stream.set_nodelay(true)?; - return Ok(stream); - } - } + let len = time::timeout(timeout, socket.recv(&mut buf)) + .map_err(|_: Elapsed| { + crate::error::Error::Conversion( + format!( + "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", + instance_name, + builder.get_port(), + ) + .into(), + ) + }) + .await??; - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; + tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); + addr.set_port(port); + }; + + if let Ok(stream) = TcpStream::connect(addr).await { + stream.set_nodelay(true)?; + return Ok(Some(stream)); + } else { + return Ok(None); } } From f803c238e61b4b81675fcc5a32f2039f8cdfa4cf Mon Sep 17 00:00:00 2001 From: Rasmus Kaj Date: Tue, 25 Mar 2025 12:23:19 +0100 Subject: [PATCH 3/4] Add a config getter for the multi-subnet-failover flag --- src/client/config.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/client/config.rs b/src/client/config.rs index 595898a0..dd48e478 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -179,6 +179,10 @@ impl Config { pub fn multi_subnet_failover(&mut self, multi_subnet_failover: bool) { self.multi_subnet_failover = multi_subnet_failover; } + /// Gets multiSubnetFailover flag. + pub fn get_multi_subnet_failover(&self) -> bool { + self.multi_subnet_failover + } pub(crate) fn get_host(&self) -> &str { self.host From 0d1660febff01033042a8c1d04bee57473e6daa4 Mon Sep 17 00:00:00 2001 From: Rasmus Kaj Date: Fri, 11 Apr 2025 14:35:41 +0200 Subject: [PATCH 4/4] More multi_subnet_failover implementation. Now better implemented in sql_brower with tokio, async_std and smol. --- Cargo.toml | 4 -- src/client.rs | 3 + src/client/connection.rs | 3 + src/sql_browser/async_std.rs | 25 +++++-- src/sql_browser/smol.rs | 125 +++++++++++++++++++++-------------- src/sql_browser/tokio.rs | 41 +++++------- 6 files changed, 115 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98c95874..9d8ccf95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,10 +87,6 @@ default-features = false version = "0.3" optional = true -[dependencies.futures] -version = "0.3" -default-features = false - [dependencies.futures-util] version = "0.3" default-features = false diff --git a/src/client.rs b/src/client.rs index 688721d1..521847af 100644 --- a/src/client.rs +++ b/src/client.rs @@ -68,6 +68,9 @@ impl Client { /// options required to connect to the database using an established /// tcp connection /// + /// Note: `tcp_stream` is a connected stream, so some parts of the `Config` + /// must be handled outside of this constructor. + /// /// [`Config`]: struct.Config.html pub async fn connect(config: Config, tcp_stream: S) -> crate::Result> { Ok(Client { diff --git a/src/client/connection.rs b/src/client/connection.rs index c6ce1d66..b8ef8667 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -72,6 +72,9 @@ impl Debug for Connection { impl Connection { /// Creates a new connection + /// + /// Note: `tcp_stream` is a connected stream, so some parts of the + /// [`Config`] needs to be handled outside of this method. pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result> { let context = { let mut context = Context::new(); diff --git a/src/sql_browser/async_std.rs b/src/sql_browser/async_std.rs index 5236c924..74e10d12 100644 --- a/src/sql_browser/async_std.rs +++ b/src/sql_browser/async_std.rs @@ -5,8 +5,9 @@ use async_std::{ net::{self, ToSocketAddrs}, }; use async_trait::async_trait; -use futures::{future::select_all, FutureExt}; use futures_util::future::TryFutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use std::time; use tracing::Level; @@ -70,16 +71,28 @@ impl SqlBrowser for net::TcpStream { /// enabled. Please see the crate examples for more detailed examples. async fn connect_named(builder: &crate::client::Config) -> crate::Result { let addrs = builder.get_addr().to_socket_addrs().await?; + let mut first_error = None; if builder.multi_subnet_failover { - let futures = addrs.map(|addr| connect_addr(builder, addr).boxed()); - select_all(futures).await; + let mut futures = addrs + .map(|addr| connect_addr(builder, addr)) + .collect::>(); + while let Some(connection) = futures.next().await { + match connection { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; + } } else { - for mut addr in addrs { - connect_addr(builder, addr).await?; + for addr in addrs { + match connect_addr(builder, addr).await { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; } } - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + // If we end up here, there was no successfull connection. + Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())) } } diff --git a/src/sql_browser/smol.rs b/src/sql_browser/smol.rs index 252b834e..c19403cc 100644 --- a/src/sql_browser/smol.rs +++ b/src/sql_browser/smol.rs @@ -1,10 +1,12 @@ use super::SqlBrowser; use crate::client::Config; use async_io::Timer; -use async_net::{resolve, TcpStream, UdpSocket}; +use async_net::{resolve, TcpStream, UdpSocket, SocketAddr}; use async_trait::async_trait; use futures_lite::FutureExt; use futures_util::future::TryFutureExt; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use std::io; use std::time::Duration; use tracing::Level; @@ -16,64 +18,85 @@ impl SqlBrowser for TcpStream { /// enabled. Please see the crate examples for more detailed examples. async fn connect_named(builder: &Config) -> crate::Result { let addrs = resolve(builder.get_addr()).await?; + let mut first_error = None; - for mut addr in addrs { - if let Some(ref instance_name) = builder.instance_name { - // First resolve the instance to a port via the - // SSRP protocol/MS-SQLR protocol [1] - // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - - let local_bind: std::net::SocketAddr = if addr.is_ipv4() { - "0.0.0.0:0".parse().unwrap() - } else { - "[::]:0".parse().unwrap() + if builder.multi_subnet_failover { + let mut futures = addrs + .into_iter() + .map(|addr| connect_addr(builder, addr)) + .collect::>(); + while let Some(connection) = futures.next().await { + match connection { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; + } + } else { + for addr in addrs { + match connect_addr(builder, addr).await { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), }; + } + } + + // If we end up here, there was no successfull connection. + Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())) + } +} - tracing::event!( - Level::TRACE, - "Connecting to instance `{}` using SQL Browser in port `{}`", - instance_name, - builder.get_port() - ); +async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result { + if let Some(ref instance_name) = builder.instance_name { + // First resolve the instance to a port via the + // SSRP protocol/MS-SQLR protocol [1] + // [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx - let msg = [&[4u8], instance_name.as_bytes()].concat(); - let mut buf = vec![0u8; 4096]; + let local_bind: std::net::SocketAddr = if addr.is_ipv4() { + "0.0.0.0:0".parse().unwrap() + } else { + "[::]:0".parse().unwrap() + }; - let socket = UdpSocket::bind(&local_bind).await?; - socket.send_to(&msg, &addr).await?; + tracing::event!( + Level::TRACE, + "Connecting to instance `{}` using SQL Browser in port `{}`", + instance_name, + builder.get_port() + ); - let timeout = Duration::from_millis(1000); + let msg = [&[4u8], instance_name.as_bytes()].concat(); + let mut buf = vec![0u8; 4096]; - let len = socket.recv(&mut buf).or(async { - Timer::after(timeout).await; - Err(std::io::ErrorKind::TimedOut.into()) - }) - .map_err(|e| { - if e.kind() == std::io::ErrorKind::TimedOut { - crate::error::Error::Conversion( - format!( - "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", - instance_name, - builder.get_port(), - ) - .into(), - ) - } else { - e.into() - } - }).await?; + let socket = UdpSocket::bind(&local_bind).await?; + socket.send_to(&msg, &addr).await?; - let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; - tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); - addr.set_port(port); - }; + let timeout = Duration::from_millis(1000); - if let Ok(stream) = TcpStream::connect(addr).await { - stream.set_nodelay(true)?; - return Ok(stream); - } - } + let len = socket.recv(&mut buf).or(async { + Timer::after(timeout).await; + Err(std::io::ErrorKind::TimedOut.into()) + }) + .map_err(|e| { + if e.kind() == std::io::ErrorKind::TimedOut { + crate::error::Error::Conversion( + format!( + "SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.", + instance_name, + builder.get_port(), + ) + .into(), + ) + } else { + e.into() + } + }).await?; - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) - } + let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?; + tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port); + addr.set_port(port); + }; + + let stream = TcpStream::connect(addr).await?; + stream.set_nodelay(true)?; + Ok(stream) } diff --git a/src/sql_browser/tokio.rs b/src/sql_browser/tokio.rs index c7c52738..9bc61152 100644 --- a/src/sql_browser/tokio.rs +++ b/src/sql_browser/tokio.rs @@ -19,39 +19,33 @@ impl SqlBrowser for TcpStream { /// enabled. Please see the crate examples for more detailed examples. async fn connect_named(builder: &Config) -> crate::Result { let addrs = net::lookup_host(builder.get_addr()).await?; + let mut first_error = None; if builder.multi_subnet_failover { let mut futures = addrs .map(|addr| connect_addr(builder, addr)) .collect::>(); - let mut first_error = None; - for f in futures.next().await { - match f { - Ok(Some(result)) => return Ok(result), - Ok(None) => break, - Err(error) => { - if first_error.is_none() { - first_error = Some(error); - } - } - } - } - if let Some(error) = first_error { - return Err(error); + while let Some(connection) = futures.next().await { + match connection { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; } } else { for addr in addrs { - if let Some(stream) = connect_addr(builder, addr).await? { - return Ok(stream); - } + match connect_addr(builder, addr).await { + Ok(connection) => return Ok(connection), + Err(error) => first_error.get_or_insert(error), + }; } } - Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()) + // If we end up here, there was no successfull connection. + Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())) } } -async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result> { +async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result { if let Some(ref instance_name) = builder.instance_name { // First resolve the instance to a port via the // SSRP protocol/MS-SQLR protocol [1] @@ -96,10 +90,7 @@ async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result