Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
/// options required to connect to the database using an established
/// tcp connection
///
/// Note: `tcp_stream` is a connected stream, so some parts of the `Config`
/// must be handled outside of this constructor.
///
/// [`Config`]: struct.Config.html
pub async fn connect(config: Config, tcp_stream: S) -> crate::Result<Client<S>> {
Ok(Client {
Expand Down
22 changes: 22 additions & 0 deletions src/client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub struct Config {
pub(crate) trust: TrustConfig,
pub(crate) auth: AuthMethod,
pub(crate) readonly: bool,
pub(crate) multi_subnet_failover: bool,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -65,6 +66,7 @@ impl Default for Config {
trust: TrustConfig::Default,
auth: AuthMethod::None,
readonly: false,
multi_subnet_failover: false,
}
}
}
Expand Down Expand Up @@ -171,6 +173,17 @@ impl Config {
self.readonly = readnoly;
}

/// Sets multiSubnetFailover flag.
///
/// - Defaults to `false`.
pub fn multi_subnet_failover(&mut self, multi_subnet_failover: bool) {
self.multi_subnet_failover = multi_subnet_failover;
}
/// Gets multiSubnetFailover flag.
pub fn get_multi_subnet_failover(&self) -> bool {
self.multi_subnet_failover
}

pub(crate) fn get_host(&self) -> &str {
self.host
.as_deref()
Expand Down Expand Up @@ -269,6 +282,8 @@ impl Config {

builder.readonly(s.readonly());

builder.multi_subnet_failover(s.multi_subnet_failover()?);

Ok(builder)
}
}
Expand Down Expand Up @@ -388,4 +403,11 @@ pub(crate) trait ConfigString {
.filter(|val| *val == "ReadOnly")
.is_some()
}

fn multi_subnet_failover(&self) -> crate::Result<bool> {
self.dict()
.get("multisubnetfailover")
.map(Self::parse_bool)
.unwrap_or(Ok(false))
}
}
3 changes: 3 additions & 0 deletions src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Debug for Connection<S> {

impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
/// Creates a new connection
///
/// Note: `tcp_stream` is a connected stream, so some parts of the
/// [`Config`] needs to be handled outside of this method.
pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result<Connection<S>> {
let context = {
let mut context = Context::new();
Expand Down
122 changes: 74 additions & 48 deletions src/sql_browser/async_std.rs
Original file line number Diff line number Diff line change
@@ -1,72 +1,98 @@
use super::SqlBrowser;
use std::net::SocketAddr;
use async_std::{
io,
net::{self, ToSocketAddrs},
};
use async_trait::async_trait;
use futures_util::future::TryFutureExt;
use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;
use std::time;
use tracing::Level;

#[async_trait]
impl SqlBrowser for net::TcpStream {
/// This method can be used to connect to SQL Server named instances
/// when on a Windows platform with the `sql-browser-async-std` feature
/// enabled. Please see the crate examples for more detailed examples.
async fn connect_named(builder: &crate::client::Config) -> crate::Result<Self> {
let addrs = builder.get_addr().to_socket_addrs().await?;
async fn connect_addr(builder: &crate::client::Config, mut addr: SocketAddr) -> crate::Result<net::TcpStream> {
if let Some(ref instance_name) = builder.instance_name {
// First resolve the instance to a port via the
// SSRP protocol/MS-SQLR protocol [1]
// [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx

for mut addr in addrs {
if let Some(ref instance_name) = builder.instance_name {
// First resolve the instance to a port via the
// SSRP protocol/MS-SQLR protocol [1]
// [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx
let local_bind: std::net::SocketAddr = if addr.is_ipv4() {
"0.0.0.0:0".parse().unwrap()
} else {
"[::]:0".parse().unwrap()
};

let local_bind: std::net::SocketAddr = if addr.is_ipv4() {
"0.0.0.0:0".parse().unwrap()
} else {
"[::]:0".parse().unwrap()
};
tracing::event!(
Level::TRACE,
"Connecting to instance `{}` using SQL Browser in port `{}`",
instance_name,
builder.get_port()
);

tracing::event!(
Level::TRACE,
"Connecting to instance `{}` using SQL Browser in port `{}`",
instance_name,
builder.get_port()
);
let msg = [&[4u8], instance_name.as_bytes()].concat();
let mut buf = vec![0u8; 4096];

let msg = [&[4u8], instance_name.as_bytes()].concat();
let mut buf = vec![0u8; 4096];
let socket = net::UdpSocket::bind(&local_bind).await?;
socket.send_to(&msg, &addr).await?;

let socket = net::UdpSocket::bind(&local_bind).await?;
socket.send_to(&msg, &addr).await?;
let timeout = time::Duration::from_millis(1000);

let timeout = time::Duration::from_millis(1000);
let len = io::timeout(timeout, socket.recv(&mut buf))
.map_err(|_| {
crate::error::Error::Conversion(
format!(
"SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.",
instance_name,
builder.get_port(),
)
.into(),
)
})
.await?;

let len = io::timeout(timeout, socket.recv(&mut buf))
.map_err(|_| {
crate::error::Error::Conversion(
format!(
"SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.",
instance_name,
builder.get_port(),
)
.into(),
)
})
.await?;
let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?;
tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port);
addr.set_port(port);
};

let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?;
tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port);
addr.set_port(port);
};
if let Ok(stream) = net::TcpStream::connect(addr).await {
stream.set_nodelay(true)?;
return Ok(stream);
} else {
Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())
}
}

if let Ok(stream) = net::TcpStream::connect(addr).await {
stream.set_nodelay(true)?;
return Ok(stream);
#[async_trait]
impl SqlBrowser for net::TcpStream {
/// This method can be used to connect to SQL Server named instances
/// when on a Windows platform with the `sql-browser-async-std` feature
/// enabled. Please see the crate examples for more detailed examples.
async fn connect_named(builder: &crate::client::Config) -> crate::Result<Self> {
let addrs = builder.get_addr().to_socket_addrs().await?;
let mut first_error = None;

if builder.multi_subnet_failover {
let mut futures = addrs
.map(|addr| connect_addr(builder, addr))
.collect::<FuturesUnordered<_>>();
while let Some(connection) = futures.next().await {
match connection {
Ok(connection) => return Ok(connection),
Err(error) => first_error.get_or_insert(error),
};
}
} else {
for addr in addrs {
match connect_addr(builder, addr).await {
Ok(connection) => return Ok(connection),
Err(error) => first_error.get_or_insert(error),
};
}
}

Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())
// If we end up here, there was no successfull connection.
Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()))
}
}
125 changes: 74 additions & 51 deletions src/sql_browser/smol.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use super::SqlBrowser;
use crate::client::Config;
use async_io::Timer;
use async_net::{resolve, TcpStream, UdpSocket};
use async_net::{resolve, TcpStream, UdpSocket, SocketAddr};
use async_trait::async_trait;
use futures_lite::FutureExt;
use futures_util::future::TryFutureExt;
use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;
use std::io;
use std::time::Duration;
use tracing::Level;
Expand All @@ -16,64 +18,85 @@ impl SqlBrowser for TcpStream {
/// enabled. Please see the crate examples for more detailed examples.
async fn connect_named(builder: &Config) -> crate::Result<Self> {
let addrs = resolve(builder.get_addr()).await?;
let mut first_error = None;

for mut addr in addrs {
if let Some(ref instance_name) = builder.instance_name {
// First resolve the instance to a port via the
// SSRP protocol/MS-SQLR protocol [1]
// [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx

let local_bind: std::net::SocketAddr = if addr.is_ipv4() {
"0.0.0.0:0".parse().unwrap()
} else {
"[::]:0".parse().unwrap()
if builder.multi_subnet_failover {
let mut futures = addrs
.into_iter()
.map(|addr| connect_addr(builder, addr))
.collect::<FuturesUnordered<_>>();
while let Some(connection) = futures.next().await {
match connection {
Ok(connection) => return Ok(connection),
Err(error) => first_error.get_or_insert(error),
};
}
} else {
for addr in addrs {
match connect_addr(builder, addr).await {
Ok(connection) => return Ok(connection),
Err(error) => first_error.get_or_insert(error),
};
}
}

// If we end up here, there was no successfull connection.
Err(first_error.unwrap_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into()))
}
}

tracing::event!(
Level::TRACE,
"Connecting to instance `{}` using SQL Browser in port `{}`",
instance_name,
builder.get_port()
);
async fn connect_addr(builder: &Config, mut addr: SocketAddr) -> crate::Result<TcpStream> {
if let Some(ref instance_name) = builder.instance_name {
// First resolve the instance to a port via the
// SSRP protocol/MS-SQLR protocol [1]
// [1] https://msdn.microsoft.com/en-us/library/cc219703.aspx

let msg = [&[4u8], instance_name.as_bytes()].concat();
let mut buf = vec![0u8; 4096];
let local_bind: std::net::SocketAddr = if addr.is_ipv4() {
"0.0.0.0:0".parse().unwrap()
} else {
"[::]:0".parse().unwrap()
};

let socket = UdpSocket::bind(&local_bind).await?;
socket.send_to(&msg, &addr).await?;
tracing::event!(
Level::TRACE,
"Connecting to instance `{}` using SQL Browser in port `{}`",
instance_name,
builder.get_port()
);

let timeout = Duration::from_millis(1000);
let msg = [&[4u8], instance_name.as_bytes()].concat();
let mut buf = vec![0u8; 4096];

let len = socket.recv(&mut buf).or(async {
Timer::after(timeout).await;
Err(std::io::ErrorKind::TimedOut.into())
})
.map_err(|e| {
if e.kind() == std::io::ErrorKind::TimedOut {
crate::error::Error::Conversion(
format!(
"SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.",
instance_name,
builder.get_port(),
)
.into(),
)
} else {
e.into()
}
}).await?;
let socket = UdpSocket::bind(&local_bind).await?;
socket.send_to(&msg, &addr).await?;

let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?;
tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port);
addr.set_port(port);
};
let timeout = Duration::from_millis(1000);

if let Ok(stream) = TcpStream::connect(addr).await {
stream.set_nodelay(true)?;
return Ok(stream);
}
}
let len = socket.recv(&mut buf).or(async {
Timer::after(timeout).await;
Err(std::io::ErrorKind::TimedOut.into())
})
.map_err(|e| {
if e.kind() == std::io::ErrorKind::TimedOut {
crate::error::Error::Conversion(
format!(
"SQL browser timeout during resolving instance {}. Please check if browser is running in port {} and does the instance exist.",
instance_name,
builder.get_port(),
)
.into(),
)
} else {
e.into()
}
}).await?;

Err(io::Error::new(io::ErrorKind::NotFound, "Could not resolve server host").into())
}
let port = super::get_port_from_sql_browser_reply(buf, len, instance_name)?;
tracing::event!(Level::TRACE, "Found port `{}` from SQL Browser", port);
addr.set_port(port);
};

let stream = TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
Ok(stream)
}
Loading