Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ features = [
"wincon",
"winerror",
"winnt",
"winsock2",
]

[package.metadata.docs.rs]
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ pub mod console;
#[cfg(windows)]
pub mod file;
#[cfg(windows)]
pub mod socket;
#[cfg(windows)]
mod win;
71 changes: 71 additions & 0 deletions src/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::convert::TryInto;
use std::io;

use winapi::shared::minwindef::INT;
use winapi::um::winsock2::{WSAPoll, WSAPOLLFD};

/// `wsa_poll` waits for one of a set of file descriptors to become ready to
/// perform I/O.
///
/// This corresponds to calling [`WSAPoll`].
///
/// [`WSAPoll`]: https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsapoll
pub fn wsa_poll(
fd_array: &mut [WSAPOLLFD],
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm... I see why making winapi a public dependency is convenient here. Wrapping WSAPOLLFD would be pretty annoying. I'm not quite sure what to do, as I don't really grok these Windows APIs. Do you see any way of buttoning this up in a more convenient fashion?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your call. Is the following better or worse? (I haven't looked at your other comments yet. That will have to wait until tomorrow. But thanks a lot for this quick feedback.)

diff --git a/src/socket.rs b/src/socket.rs
index 78d216a..a43b7f6 100644
--- a/src/socket.rs
+++ b/src/socket.rs
@@ -1,21 +1,55 @@
 use std::convert::TryInto;
 use std::io;
+use std::os::raw::c_short;
+use std::os::windows::io::RawSocket;
 
 use winapi::shared::minwindef::INT;
-use winapi::um::{winnt::SHORT, winsock2};
+use winapi::um::winsock2::{self, WSAPOLLFD};
 
-pub const POLLRDNORM: SHORT = winsock2::POLLRDNORM;
-pub const POLLRDBAND: SHORT = winsock2::POLLRDBAND;
-pub const POLLIN: SHORT = winsock2::POLLIN;
-pub const POLLPRI: SHORT = winsock2::POLLPRI;
-pub const POLLWRNORM: SHORT = winsock2::POLLWRNORM;
-pub const POLLOUT: SHORT = winsock2::POLLOUT;
-pub const POLLWRBAND: SHORT = winsock2::POLLWRBAND;
-pub const POLLERR: SHORT = winsock2::POLLERR;
-pub const POLLHUP: SHORT = winsock2::POLLHUP;
-pub const POLLNVAL: SHORT = winsock2::POLLNVAL;
+pub const POLLRDNORM: c_short = winsock2::POLLRDNORM;
+pub const POLLRDBAND: c_short = winsock2::POLLRDBAND;
+pub const POLLIN: c_short = winsock2::POLLIN;
+pub const POLLPRI: c_short = winsock2::POLLPRI;
+pub const POLLWRNORM: c_short = winsock2::POLLWRNORM;
+pub const POLLOUT: c_short = winsock2::POLLOUT;
+pub const POLLWRBAND: c_short = winsock2::POLLWRBAND;
+pub const POLLERR: c_short = winsock2::POLLERR;
+pub const POLLHUP: c_short = winsock2::POLLHUP;
+pub const POLLNVAL: c_short = winsock2::POLLNVAL;
 
-pub use winsock2::WSAPOLLFD;
+/// Information necessary for the [`wsa_poll`] function.
+#[derive(Debug, Clone)]
+pub struct WSAPollFD {
+    fd: RawSocket,
+    events: c_short,
+    revents: c_short,
+}
+
+impl WSAPollFD {
+    /// Create a new `WSAPollFD` apecifying the events of interest for a given socket.
+    pub fn new(fd: RawSocket, events: c_short) -> Self {
+        Self { fd, events, revents: 0 }
+    }
+
+    /// Returns the events that occurred in the last call to [`wsa_poll`].
+    pub fn revents(&self) -> c_short {
+        self.revents
+    }
+
+    fn to_winapi(&self) -> WSAPOLLFD {
+        WSAPOLLFD {
+            fd: self.fd.try_into().unwrap(),
+            events: self.events,
+            revents: self.revents,
+        }
+    }
+
+    fn from_winapi(&mut self, poll: &WSAPOLLFD) {
+        self.fd = poll.fd.try_into().unwrap();
+        self.events = poll.events;
+        self.revents = poll.revents;
+    }
+}
 
 /// `wsa_poll` waits for one of a set of file descriptors to become ready to perform I/O.
 ///
@@ -23,15 +57,18 @@ pub use winsock2::WSAPOLLFD;
 ///
 /// [`WSAPoll`]: https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsapoll
 pub fn wsa_poll(
-    fd_array: &mut [WSAPOLLFD],
+    fd_array: &mut [WSAPollFD],
     timeout: INT,
 ) -> io::Result<usize> {
-    unsafe {
-        let length = fd_array.len().try_into().unwrap();
-        let rc = winsock2::WSAPoll(fd_array.as_mut_ptr(), length, timeout);
-        if rc < 0 {
-            return Err(io::Error::last_os_error());
-        };
-        Ok(rc.try_into().unwrap())
-    }
+    let mut array2 =
+        fd_array.iter().map(WSAPollFD::to_winapi).collect::<Vec<_>>();
+    let length = array2.len().try_into().unwrap();
+    let rc =
+        unsafe { winsock2::WSAPoll(array2.as_mut_ptr(), length, timeout) };
+    if rc < 0 {
+        return Err(io::Error::last_os_error());
+    };
+    // Copy the results to the output array
+    fd_array.iter_mut().zip(array2.iter()).for_each(|(a, b)| a.from_winapi(b));
+    Ok(rc.try_into().unwrap())
 }
diff --git a/tests/poll.rs b/tests/poll.rs
index bf67360..8882080 100644
--- a/tests/poll.rs
+++ b/tests/poll.rs
@@ -2,9 +2,9 @@
 mod windows {
     use std::io::{Result, Write};
     use std::net::{TcpListener, TcpStream};
+    use std::os::raw::c_short;
     use std::os::windows::io::{AsRawSocket, RawSocket};
 
-    use winapi::um::winnt::SHORT;
     use winapi_util::socket::*;
 
     /// Get a pair of connected TcpStreams
@@ -16,11 +16,15 @@ mod windows {
         Ok((stream1, stream2))
     }
 
-    fn poll(socket: RawSocket, events: SHORT, revents: SHORT) -> Result<()> {
-        let mut sockets = [WSAPOLLFD { fd: socket as _, events, revents: 0 }];
+    fn poll(
+        socket: RawSocket,
+        events: c_short,
+        revents: c_short,
+    ) -> Result<()> {
+        let mut sockets = [WSAPollFD::new(socket, events)];
         let count = wsa_poll(&mut sockets, -1)?;
         assert_eq!(count, 1);
-        assert_eq!(sockets[0].revents, revents);
+        assert_eq!(sockets[0].revents(), revents);
 
         Ok(())
     }

timeout: INT,
) -> io::Result<usize> {
unsafe {
let length = fd_array.len().try_into().unwrap();
let rc = WSAPoll(fd_array.as_mut_ptr(), length, timeout);
if rc < 0 {
return Err(io::Error::last_os_error());
};
Ok(rc.try_into().unwrap())
}
}

#[cfg(test)]
mod test {
use std::io::{Result, Write};
use std::net::{TcpListener, TcpStream};
use std::os::windows::io::{AsRawSocket, RawSocket};

use winapi::um::winnt::SHORT;
use winapi::um::winsock2::{POLLIN, POLLOUT, POLLRDNORM, WSAPOLLFD};

use super::wsa_poll;

/// Get a pair of connected TcpStreams
fn get_connection_pair() -> Result<(TcpStream, TcpStream)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let stream1 = TcpStream::connect(listener.local_addr()?)?;
let stream2 = listener.accept()?.0;

Ok((stream1, stream2))
}

fn poll(socket: RawSocket, events: SHORT, revents: SHORT) -> Result<()> {
let mut sockets = [WSAPOLLFD { fd: socket as _, events, revents: 0 }];
let count = wsa_poll(&mut sockets, -1)?;
assert_eq!(count, 1);
assert_eq!(sockets[0].revents, revents);

Ok(())
}

#[test]
fn test_poll() -> Result<()> {
let (mut stream1, stream2) = get_connection_pair()?;

// Check that stream1 is writable
poll(stream1.as_raw_socket(), POLLOUT, POLLOUT)?;

// Write something to the stream
stream1.write_all(b"1")?;

// stream2 should now be readable and writable
poll(stream2.as_raw_socket(), POLLIN | POLLOUT, POLLOUT | POLLRDNORM)?;

Ok(())
}
}