diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..f8fff10 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,31 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/rust +{ + "name": "Rust", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/rust:1-1-bullseye" + + // Use 'mounts' to make the cargo cache persistent in a Docker Volume. + // "mounts": [ + // { + // "source": "devcontainer-cargo-cache-${devcontainerId}", + // "target": "/usr/local/cargo", + // "type": "volume" + // } + // ] + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "rustc --version", + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f33a02c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/Cargo.toml b/Cargo.toml index 06fdcb9..f808e53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,15 +18,22 @@ httparse = { version = "1.8.0", default-features = false } rand_core = { version = "0.6.4", default-features = false } base64 = { version = "0.13.1", default-features = false } futures = { version = "0.3.28", default-features = false } +embedded-io-async = { version = "0.6.1", default-features = false, optional = true } [dev-dependencies] rand = "0.8.5" bytes = "1.4.0" tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] } tokio-util = { version = "0.7.8", features = ["net", "codec"] } +embedded-io-adapters = { version = "0.6.1", features = ["tokio-1"] } # see readme for no_std support [features] default = ["std"] -# default = [] std = ["httparse/std"] +embedded-io-async = ["dep:embedded-io-async"] + +[[example]] +name = "client_async_embedded_io_async" +path = "examples/client_async_embedded_io_async.rs" +required-features = ["embedded-io-async"] diff --git a/examples/client_async_embedded_io_async.rs b/examples/client_async_embedded_io_async.rs new file mode 100644 index 0000000..aac102c --- /dev/null +++ b/examples/client_async_embedded_io_async.rs @@ -0,0 +1,71 @@ +use embedded_io_adapters::tokio_1::FromTokio; +use std::error::Error; + +use tokio::net::TcpStream; + +use embedded_websocket::{ + framer_async::{Framer, FramerError, ReadResult}, + WebSocketClient, WebSocketCloseStatusCode, WebSocketOptions, WebSocketSendMessageType, +}; + +#[tokio::main] +async fn main() -> Result<(), FramerError> { + // Connect to a peer + let address = "127.0.0.1:1337"; + let mut buffer = [0u8; 4000]; + let tcp_stream = TcpStream::connect(address).await.map_err(FramerError::Io)?; + let mut stream = FromTokio::new(tcp_stream); + let websocket = WebSocketClient::new_client(rand::thread_rng()); + + // initiate a websocket opening handshake + let websocket_options = WebSocketOptions { + path: "/chat", + host: "localhost", + origin: "http://localhost:1337", + sub_protocols: None, + additional_headers: None, + }; + + let mut framer = Framer::new(websocket); + + framer + .connect(&mut stream, &mut buffer, &websocket_options) + .await?; + + println!("ws handshake complete"); + + framer + .write( + &mut stream, + &mut buffer, + WebSocketSendMessageType::Text, + true, + "Hello, world".as_bytes(), + ) + .await?; + + println!("sent message"); + + while let Some(read_result) = framer.read(&mut stream, &mut buffer).await { + let read_result = read_result?; + match read_result { + ReadResult::Text(text) => { + println!("received text: {text}"); + + framer + .close( + &mut stream, + &mut buffer, + WebSocketCloseStatusCode::NormalClosure, + None, + ) + .await? + } + _ => { // ignore other kinds of messages + } + } + } + + println!("closed"); + Ok(()) +} diff --git a/src/framer_async.rs b/src/framer_async.rs index 20a0063..97cd0f6 100644 --- a/src/framer_async.rs +++ b/src/framer_async.rs @@ -1,5 +1,9 @@ -use core::{fmt::Debug, ops::Deref, str::Utf8Error}; - +#[cfg(not(feature = "embedded-io-async"))] +use core::ops::Deref; +use core::{fmt::Debug, str::Utf8Error}; +#[cfg(feature = "embedded-io-async")] +use embedded_io_async::{ErrorType, Read, Write}; +#[cfg(not(feature = "embedded-io-async"))] use futures::{Sink, SinkExt, Stream, StreamExt}; use rand_core::RngCore; @@ -50,6 +54,7 @@ where rx_remainder_len: usize, } +#[cfg(not(feature = "embedded-io-async"))] impl Framer where TRng: RngCore, @@ -111,6 +116,7 @@ where } } +#[cfg(not(feature = "embedded-io-async"))] impl Framer where TRng: RngCore, @@ -306,3 +312,255 @@ where None } } + +#[cfg(feature = "embedded-io-async")] +impl Framer +where + TRng: RngCore, +{ + pub async fn connect<'a, S>( + &mut self, + stream: &mut S, + buffer: &'a mut [u8], + websocket_options: &WebSocketOptions<'_>, + ) -> Result, FramerError<::Error>> + where + S: Read + Write + Unpin, + { + let (tx_len, web_socket_key) = self + .websocket + .client_connect(websocket_options, buffer) + .map_err(FramerError::WebSocket)?; + + let (tx_buf, _rx_buf) = buffer.split_at_mut(tx_len); + stream.write(tx_buf).await.map_err(FramerError::Io)?; + stream.flush().await.map_err(FramerError::Io)?; + + loop { + let read_len = stream.read(buffer).await.map_err(FramerError::Io)?; + + match self.websocket.client_accept(&web_socket_key, buffer) { + Ok((len, sub_protocol)) => { + // "consume" the HTTP header that we have read from the stream + // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else + + // copy the remaining bytes to the end of the rx_buf (which is also the end of the buffer) because they are the contents of the next websocket frame(s) + let from = len; + let to = read_len; + let remaining_len = to - from; + + if remaining_len > 0 { + // let rx_start = read_len - remaining_len; + // rx_buf[rx_start..].copy_from_slice(&buf[from..to]); + self.rx_remainder_len = remaining_len; + } + + return Ok(sub_protocol); + } + Err(crate::Error::HttpHeaderIncomplete) => { + // TODO: continue reading HTTP header in loop + panic!("oh no"); + } + Err(e) => { + return Err(FramerError::WebSocket(e)); + } + } + } + } +} + +#[cfg(feature = "embedded-io-async")] +impl Framer +where + TRng: RngCore, + TWebSocketType: WebSocketType, +{ + pub fn new(websocket: WebSocket) -> Self { + Self { + websocket, + frame_cursor: 0, + rx_remainder_len: 0, + } + } + + pub fn encode( + &mut self, + message_type: WebSocketSendMessageType, + end_of_message: bool, + from: &[u8], + to: &mut [u8], + ) -> Result> { + let len = self + .websocket + .write(message_type, end_of_message, from, to) + .map_err(FramerError::WebSocket)?; + + Ok(len) + } + + pub async fn write<'b, T>( + &mut self, + tx: &mut T, + tx_buf: &'b mut [u8], + message_type: WebSocketSendMessageType, + end_of_message: bool, + frame_buf: &[u8], + ) -> Result<(), FramerError<::Error>> + where + T: Write + Unpin, + { + let len = self + .websocket + .write(message_type, end_of_message, frame_buf, tx_buf) + .map_err(FramerError::WebSocket)?; + + tx.write(&tx_buf[..len]) + .await + .map_err(FramerError::Io) + .unwrap(); + tx.flush().await.map_err(FramerError::Io).unwrap(); + Ok(()) + } + + pub async fn close<'b, T>( + &mut self, + tx: &mut T, + tx_buf: &'b mut [u8], + close_status: WebSocketCloseStatusCode, + status_description: Option<&str>, + ) -> Result<(), FramerError<::Error>> + where + T: Write + Unpin, + { + let len = self + .websocket + .close(close_status, status_description, tx_buf) + .map_err(FramerError::WebSocket)?; + + tx.write(&tx_buf[..len]) + .await + .map_err(FramerError::Io) + .unwrap(); + tx.flush().await.map_err(FramerError::Io).unwrap(); + Ok(()) + } + + // NOTE: any unused bytes read from the stream but not decoded are stored at the end + // of the buffer to be used next time this read function is called. This also applies to + // any unused bytes read when the connect handshake was made. Therefore it is important that + // the caller does not clear this buffer between calls or use it for anthing other than reads. + pub async fn read<'a, S>( + &mut self, + stream: &mut S, + buffer: &'a mut [u8], + ) -> Option, FramerError<::Error>>> + where + S: Read + Write + Unpin, + { + if self.rx_remainder_len == 0 { + match stream.read(buffer).await { + Ok(read_len) => { + if buffer.len() < read_len { + return Some(Err(FramerError::RxBufferTooSmall(read_len))); + } + + self.rx_remainder_len = read_len + } + Err(error) => { + return Some(Err(FramerError::Io(error))); + } + } + } + + let (rx_buf, frame_buf) = buffer.split_at_mut(self.rx_remainder_len); + let ws_result = match self.websocket.read(rx_buf, frame_buf) { + Ok(ws_result) => ws_result, + Err(e) => return Some(Err(FramerError::WebSocket(e))), + }; + + self.rx_remainder_len -= ws_result.len_from; + + match ws_result.message_type { + WebSocketReceiveMessageType::Binary => { + self.frame_cursor += ws_result.len_to; + if ws_result.end_of_message { + let range = 0..self.frame_cursor; + self.frame_cursor = 0; + return Some(Ok(ReadResult::Binary(&frame_buf[range]))); + } + } + WebSocketReceiveMessageType::Text => { + self.frame_cursor += ws_result.len_to; + if ws_result.end_of_message { + let range = 0..self.frame_cursor; + self.frame_cursor = 0; + match core::str::from_utf8(&frame_buf[range]) { + Ok(text) => return Some(Ok(ReadResult::Text(text))), + Err(e) => return Some(Err(FramerError::Utf8(e))), + } + } + } + WebSocketReceiveMessageType::CloseMustReply => { + let range = self.frame_cursor..self.frame_cursor + ws_result.len_to; + + // create a tx_buf from the end of the frame_buf + let tx_buf_len = ws_result.len_to + 14; // for extra websocket header + let split_at = frame_buf.len() - tx_buf_len; + let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at); + + match self.websocket.write( + WebSocketSendMessageType::CloseReply, + true, + &frame_buf[range.start..range.end], + tx_buf, + ) { + Ok(len) => match stream.write(&tx_buf[..len]).await { + Ok(_write_len) => { + self.frame_cursor = 0; + let status_code = ws_result + .close_status + .expect("close message must have code"); + let reason = &frame_buf[range]; + return Some(Ok(ReadResult::Close(CloseMessage { + status_code, + reason, + }))); + } + Err(e) => return Some(Err(FramerError::Io(e))), + }, + Err(e) => return Some(Err(FramerError::WebSocket(e))), + } + } + WebSocketReceiveMessageType::CloseCompleted => return None, + WebSocketReceiveMessageType::Pong => { + let range = self.frame_cursor..self.frame_cursor + ws_result.len_to; + return Some(Ok(ReadResult::Pong(&frame_buf[range]))); + } + WebSocketReceiveMessageType::Ping => { + let range = self.frame_cursor..self.frame_cursor + ws_result.len_to; + + // create a tx_buf from the end of the frame_buf + let tx_buf_len = ws_result.len_to + 14; // for extra websocket header + let split_at = frame_buf.len() - tx_buf_len; + let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at); + + match self.websocket.write( + WebSocketSendMessageType::Pong, + true, + &frame_buf[range.start..range.end], + tx_buf, + ) { + Ok(len) => match stream.write(&tx_buf[..len]).await { + Ok(_write_len) => { + return Some(Ok(ReadResult::Ping(&frame_buf[range]))); + } + Err(e) => return Some(Err(FramerError::Io(e))), + }, + Err(e) => return Some(Err(FramerError::WebSocket(e))), + } + } + } + + None + } +}