From 31c9651c5792ef1a3a2c41f1f5cdb0cff1f691ec Mon Sep 17 00:00:00 2001 From: Uli Schlachter Date: Sat, 16 May 2020 18:17:33 +0200 Subject: [PATCH 1/2] Add a safe binding to WSAPoll() Signed-off-by: Uli Schlachter --- Cargo.toml | 1 + src/lib.rs | 2 ++ src/socket.rs | 37 +++++++++++++++++++++++++++++++++++++ tests/poll.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+) create mode 100644 src/socket.rs create mode 100644 tests/poll.rs diff --git a/Cargo.toml b/Cargo.toml index e1fb776..5eb91bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ features = [ "wincon", "winerror", "winnt", + "winsock2", ] [package.metadata.docs.rs] diff --git a/src/lib.rs b/src/lib.rs index 0bb259d..f1a8535 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,4 +29,6 @@ pub mod console; #[cfg(windows)] pub mod file; #[cfg(windows)] +pub mod socket; +#[cfg(windows)] mod win; diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..78d216a --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,37 @@ +use std::convert::TryInto; +use std::io; + +use winapi::shared::minwindef::INT; +use winapi::um::{winnt::SHORT, winsock2}; + +pub const POLLRDNORM: SHORT = winsock2::POLLRDNORM; +pub const POLLRDBAND: SHORT = winsock2::POLLRDBAND; +pub const POLLIN: SHORT = winsock2::POLLIN; +pub const POLLPRI: SHORT = winsock2::POLLPRI; +pub const POLLWRNORM: SHORT = winsock2::POLLWRNORM; +pub const POLLOUT: SHORT = winsock2::POLLOUT; +pub const POLLWRBAND: SHORT = winsock2::POLLWRBAND; +pub const POLLERR: SHORT = winsock2::POLLERR; +pub const POLLHUP: SHORT = winsock2::POLLHUP; +pub const POLLNVAL: SHORT = winsock2::POLLNVAL; + +pub use winsock2::WSAPOLLFD; + +/// `wsa_poll` waits for one of a set of file descriptors to become ready to perform I/O. +/// +/// This corresponds to calling [`WSAPoll`]. +/// +/// [`WSAPoll`]: https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsapoll +pub fn wsa_poll( + fd_array: &mut [WSAPOLLFD], + timeout: INT, +) -> io::Result { + unsafe { + let length = fd_array.len().try_into().unwrap(); + let rc = winsock2::WSAPoll(fd_array.as_mut_ptr(), length, timeout); + if rc < 0 { + return Err(io::Error::last_os_error()); + }; + Ok(rc.try_into().unwrap()) + } +} diff --git a/tests/poll.rs b/tests/poll.rs new file mode 100644 index 0000000..bf67360 --- /dev/null +++ b/tests/poll.rs @@ -0,0 +1,43 @@ +#[cfg(windows)] +mod windows { + use std::io::{Result, Write}; + use std::net::{TcpListener, TcpStream}; + use std::os::windows::io::{AsRawSocket, RawSocket}; + + use winapi::um::winnt::SHORT; + use winapi_util::socket::*; + + /// Get a pair of connected TcpStreams + fn get_connection_pair() -> Result<(TcpStream, TcpStream)> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let stream1 = TcpStream::connect(listener.local_addr()?)?; + let stream2 = listener.accept()?.0; + + Ok((stream1, stream2)) + } + + fn poll(socket: RawSocket, events: SHORT, revents: SHORT) -> Result<()> { + let mut sockets = [WSAPOLLFD { fd: socket as _, events, revents: 0 }]; + let count = wsa_poll(&mut sockets, -1)?; + assert_eq!(count, 1); + assert_eq!(sockets[0].revents, revents); + + Ok(()) + } + + #[test] + fn test_poll() -> Result<()> { + let (mut stream1, stream2) = get_connection_pair()?; + + // Check that stream1 is writable + poll(stream1.as_raw_socket(), POLLOUT, POLLOUT)?; + + // Write something to the stream + stream1.write_all(b"1")?; + + // stream2 should now be readable and writable + poll(stream2.as_raw_socket(), POLLIN | POLLOUT, POLLOUT | POLLRDNORM)?; + + Ok(()) + } +} From fbdf09ac1c3d8793eafda005afdfc37046920143 Mon Sep 17 00:00:00 2001 From: Uli Schlachter Date: Sun, 17 May 2020 11:21:41 +0200 Subject: [PATCH 2/2] Address review comments Signed-off-by: Uli Schlachter --- src/socket.rs | 68 ++++++++++++++++++++++++++++++++++++++------------- tests/poll.rs | 43 -------------------------------- 2 files changed, 51 insertions(+), 60 deletions(-) delete mode 100644 tests/poll.rs diff --git a/src/socket.rs b/src/socket.rs index 78d216a..9fb34ae 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -2,22 +2,10 @@ use std::convert::TryInto; use std::io; use winapi::shared::minwindef::INT; -use winapi::um::{winnt::SHORT, winsock2}; - -pub const POLLRDNORM: SHORT = winsock2::POLLRDNORM; -pub const POLLRDBAND: SHORT = winsock2::POLLRDBAND; -pub const POLLIN: SHORT = winsock2::POLLIN; -pub const POLLPRI: SHORT = winsock2::POLLPRI; -pub const POLLWRNORM: SHORT = winsock2::POLLWRNORM; -pub const POLLOUT: SHORT = winsock2::POLLOUT; -pub const POLLWRBAND: SHORT = winsock2::POLLWRBAND; -pub const POLLERR: SHORT = winsock2::POLLERR; -pub const POLLHUP: SHORT = winsock2::POLLHUP; -pub const POLLNVAL: SHORT = winsock2::POLLNVAL; - -pub use winsock2::WSAPOLLFD; - -/// `wsa_poll` waits for one of a set of file descriptors to become ready to perform I/O. +use winapi::um::winsock2::{WSAPoll, WSAPOLLFD}; + +/// `wsa_poll` waits for one of a set of file descriptors to become ready to +/// perform I/O. /// /// This corresponds to calling [`WSAPoll`]. /// @@ -28,10 +16,56 @@ pub fn wsa_poll( ) -> io::Result { unsafe { let length = fd_array.len().try_into().unwrap(); - let rc = winsock2::WSAPoll(fd_array.as_mut_ptr(), length, timeout); + let rc = WSAPoll(fd_array.as_mut_ptr(), length, timeout); if rc < 0 { return Err(io::Error::last_os_error()); }; Ok(rc.try_into().unwrap()) } } + +#[cfg(test)] +mod test { + use std::io::{Result, Write}; + use std::net::{TcpListener, TcpStream}; + use std::os::windows::io::{AsRawSocket, RawSocket}; + + use winapi::um::winnt::SHORT; + use winapi::um::winsock2::{POLLIN, POLLOUT, POLLRDNORM, WSAPOLLFD}; + + use super::wsa_poll; + + /// Get a pair of connected TcpStreams + fn get_connection_pair() -> Result<(TcpStream, TcpStream)> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let stream1 = TcpStream::connect(listener.local_addr()?)?; + let stream2 = listener.accept()?.0; + + Ok((stream1, stream2)) + } + + fn poll(socket: RawSocket, events: SHORT, revents: SHORT) -> Result<()> { + let mut sockets = [WSAPOLLFD { fd: socket as _, events, revents: 0 }]; + let count = wsa_poll(&mut sockets, -1)?; + assert_eq!(count, 1); + assert_eq!(sockets[0].revents, revents); + + Ok(()) + } + + #[test] + fn test_poll() -> Result<()> { + let (mut stream1, stream2) = get_connection_pair()?; + + // Check that stream1 is writable + poll(stream1.as_raw_socket(), POLLOUT, POLLOUT)?; + + // Write something to the stream + stream1.write_all(b"1")?; + + // stream2 should now be readable and writable + poll(stream2.as_raw_socket(), POLLIN | POLLOUT, POLLOUT | POLLRDNORM)?; + + Ok(()) + } +} diff --git a/tests/poll.rs b/tests/poll.rs deleted file mode 100644 index bf67360..0000000 --- a/tests/poll.rs +++ /dev/null @@ -1,43 +0,0 @@ -#[cfg(windows)] -mod windows { - use std::io::{Result, Write}; - use std::net::{TcpListener, TcpStream}; - use std::os::windows::io::{AsRawSocket, RawSocket}; - - use winapi::um::winnt::SHORT; - use winapi_util::socket::*; - - /// Get a pair of connected TcpStreams - fn get_connection_pair() -> Result<(TcpStream, TcpStream)> { - let listener = TcpListener::bind("127.0.0.1:0")?; - let stream1 = TcpStream::connect(listener.local_addr()?)?; - let stream2 = listener.accept()?.0; - - Ok((stream1, stream2)) - } - - fn poll(socket: RawSocket, events: SHORT, revents: SHORT) -> Result<()> { - let mut sockets = [WSAPOLLFD { fd: socket as _, events, revents: 0 }]; - let count = wsa_poll(&mut sockets, -1)?; - assert_eq!(count, 1); - assert_eq!(sockets[0].revents, revents); - - Ok(()) - } - - #[test] - fn test_poll() -> Result<()> { - let (mut stream1, stream2) = get_connection_pair()?; - - // Check that stream1 is writable - poll(stream1.as_raw_socket(), POLLOUT, POLLOUT)?; - - // Write something to the stream - stream1.write_all(b"1")?; - - // stream2 should now be readable and writable - poll(stream2.as_raw_socket(), POLLIN | POLLOUT, POLLOUT | POLLRDNORM)?; - - Ok(()) - } -}