Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/async_/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ impl Client {
|| uri.starts_with("rss://")
|| uri.starts_with("tcp://")
|| uri.starts_with("tcps://")
|| uri.starts_with("unix://")
|| uri.starts_with("unix+rs://")
{
let serializer = self.serializer.serializer();
let joiner = RawSocketJoiner::new(self.serializer, self.authenticator);
Expand Down
75 changes: 56 additions & 19 deletions src/async_/rawsocket.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::async_::peer::Peer;
use crate::common::types::{Error, SerializerSpec, TRANSPORT_RAW_SOCKET, TransportType};
use async_trait::async_trait;
use std::fmt::Debug;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::net::UnixStream;
use tokio::sync::Mutex;

use url::Url;
Expand All @@ -12,14 +14,23 @@ use wampproto::transports::rawsocket::{
send_handshake, send_message_header,
};

#[derive(Debug, Clone)]
pub struct RawSocketPeer {
reader: Arc<Mutex<ReadHalf<TcpStream>>>,
writer: Arc<Mutex<WriteHalf<TcpStream>>>,
#[derive(Debug)]
pub struct RawSocketPeer<S: AsyncRead + AsyncWrite + Send + Sync + Debug + 'static> {
reader: Arc<Mutex<ReadHalf<S>>>,
writer: Arc<Mutex<WriteHalf<S>>>,
}

impl<S: AsyncRead + AsyncWrite + Send + Sync + Debug + 'static> Clone for RawSocketPeer<S> {
fn clone(&self) -> Self {
RawSocketPeer {
reader: Arc::clone(&self.reader),
writer: Arc::clone(&self.writer),
}
}
}

#[async_trait]
impl Peer for RawSocketPeer {
impl<S: AsyncRead + AsyncWrite + Send + Sync + Debug + 'static> Peer for RawSocketPeer<S> {
fn kind(&self) -> TransportType {
TRANSPORT_RAW_SOCKET
}
Expand Down Expand Up @@ -65,25 +76,19 @@ impl Peer for RawSocketPeer {
}

#[allow(clippy::new_ret_no_self)]
impl RawSocketPeer {
pub fn new(reader: ReadHalf<TcpStream>, writer: WriteHalf<TcpStream>) -> Box<dyn Peer> {
impl<S: AsyncRead + AsyncWrite + Send + Sync + Debug + 'static> RawSocketPeer<S> {
pub fn new(reader: ReadHalf<S>, writer: WriteHalf<S>) -> Box<dyn Peer> {
Box::new(RawSocketPeer {
reader: Arc::new(Mutex::new(reader)),
writer: Arc::new(Mutex::new(writer)),
})
}
}

pub async fn connect_rawsocket(uri: &str, serializer: Box<dyn SerializerSpec>) -> Result<Box<dyn Peer>, Error> {
let parsed = Url::parse(uri).map_err(|e| Error::new(format!("invalid uri: {e}")))?;
let host = parsed.host_str().unwrap();
let port = parsed.port_or_known_default().unwrap();

let addr = format!("{host}:{port}");
let mut stream = TcpStream::connect(addr)
.await
.map_err(|e| Error::new(format!("connect error: {e}")))?;

async fn perform_handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut S,
serializer: &dyn SerializerSpec,
) -> Result<(), Error> {
let handshake = Handshake::new(serializer.serializer_id(), DEFAULT_MAX_MSG_SIZE);

let handshake_raw =
Expand All @@ -100,7 +105,39 @@ pub async fn connect_rawsocket(uri: &str, serializer: Box<dyn SerializerSpec>) -
.await
.map_err(|e| Error::new(format!("failed to read handshake response: {e}")))?;

_ = receive_handshake(&buf).map_err(|e| Error::new(format!("failed to parse handshake response: {e}")))?;
receive_handshake(&buf).map_err(|e| Error::new(format!("failed to parse handshake response: {e}")))?;

Ok(())
}

pub async fn connect_rawsocket(uri: &str, serializer: Box<dyn SerializerSpec>) -> Result<Box<dyn Peer>, Error> {
let parsed = Url::parse(uri).map_err(|e| Error::new(format!("invalid uri: {e}")))?;

if parsed.scheme() == "unix" || parsed.scheme() == "unix+rs" {
let path = parsed.path();
let mut stream = UnixStream::connect(path)
.await
.map_err(|e| Error::new(format!("connect error: {e}")))?;

perform_handshake(&mut stream, serializer.as_ref()).await?;

let (reader, writer) = tokio::io::split(stream);
return Ok(RawSocketPeer::new(reader, writer));
}

let host = parsed
.host_str()
.ok_or_else(|| Error::new("missing host in uri".to_string()))?;
let port = parsed
.port_or_known_default()
.ok_or_else(|| Error::new("missing port in uri".to_string()))?;

let addr = format!("{host}:{port}");
let mut stream = TcpStream::connect(addr)
.await
.map_err(|e| Error::new(format!("connect error: {e}")))?;

perform_handshake(&mut stream, serializer.as_ref()).await?;

let (reader, writer) = tokio::io::split(stream);
Ok(RawSocketPeer::new(reader, writer))
Expand Down
Loading