From 0d767a5cf5bfac909fc6afbdc360ce24c56effe8 Mon Sep 17 00:00:00 2001 From: JKearnsl Date: Fri, 23 May 2025 16:55:25 +0300 Subject: [PATCH 1/3] add buff pool --- crates/client/Cargo.toml | 2 +- crates/client/src/runtime/buffer.rs | 191 ++++++++++++++++++++++++++++ crates/client/src/runtime/mod.rs | 1 + 3 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 crates/client/src/runtime/buffer.rs diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 9b714a7..0876982 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -29,6 +29,7 @@ anstyle = { version = "1.0", optional = true } tokio-tungstenite = { workspace = true, optional = true } futures = { workspace = true, optional = true } ipnetwork = "0.21.1" +crossbeam = "0.8.4" # console-subscriber = "0.4.1" # Crypto @@ -44,4 +45,3 @@ tracing-subscriber = { workspace = true } tracing = { workspace = true } tracing-appender = { workspace = true } async-trait = "0.1.88" - diff --git a/crates/client/src/runtime/buffer.rs b/crates/client/src/runtime/buffer.rs new file mode 100644 index 0000000..9cc34e9 --- /dev/null +++ b/crates/client/src/runtime/buffer.rs @@ -0,0 +1,191 @@ +use std::io::{Read, Write}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use crossbeam::queue::SegQueue; + +pub struct Buffer { + data: Vec +} + +impl Buffer { + + pub fn new() -> Self { + Self { + data: Vec::new() + } + } + + fn reset(&mut self) { + unsafe { self.data.set_len(0) }; + } +} + +impl Read for Buffer { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + if self.data.len() > buf.len() { + return Ok(0); + } + + unsafe { + std::ptr::copy_nonoverlapping( + self.data.as_ptr(), + buf.as_mut_ptr(), + self.data.len() + ); + } + Ok(self.data.len()) + } +} + +impl Write for Buffer { + #[inline] + fn write(&mut self, buf: &[u8]) -> std::io::Result { + if self.data.capacity() < buf.len() { + self.data.reserve(buf.len()); + } + + unsafe { + std::ptr::copy_nonoverlapping( + buf.as_ptr(), + self.data.as_mut_ptr(), + buf.len() + ); + self.data.set_len(buf.len()); + } + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { Ok(()) } +} + +impl Default for Buffer { + fn default() -> Self { + Self::new() + } +} + +pub struct BufferPool { + buffers: SegQueue, + allocated: AtomicUsize, + limit: Option, +} + +impl BufferPool { + pub fn new(limit: Option) -> Self { + Self { + buffers: SegQueue::new(), + allocated: AtomicUsize::new(0), + limit, + } + } + + pub fn alloc(&self) -> Option { + if let Some(limit) = self.limit { + if self.allocated.load(Ordering::Relaxed) >= limit { + return None; + } + self.allocated.fetch_add(1, Ordering::Relaxed); + } + Some(self.buffers.pop().unwrap_or_default()) + } + + pub fn release(&self, mut buffer: Buffer) { + if self.limit.is_some() { + self.allocated.fetch_sub(1, Ordering::Relaxed); + } + buffer.reset(); + self.buffers.push(buffer); + } +} + +impl Default for BufferPool { + fn default() -> Self { + Self::new(None) + } +} + +#[cfg(test)] +#[allow(clippy::unused_io_amount)] +mod tests { + use std::sync::Arc; + use super::*; + use std::thread; + + #[test] + fn basic_read_write() { + let mut buf = Buffer::new(); + buf.write(&[1, 2, 3]).unwrap(); + let mut output = [0u8; 3]; + assert_eq!(buf.read(&mut output).unwrap(), 3); + assert_eq!(output, [1, 2, 3]); + } + + #[test] + fn buffer_reuse() { + let pool = BufferPool::new(Some(1)); + + // first alloc + let mut buf1 = pool.alloc().unwrap(); + buf1.write(&[1; 512]).unwrap(); + let buf1_cap = buf1.data.capacity(); + pool.release(buf1); + + // reuse + let buf2 = pool.alloc().unwrap(); + assert_eq!(buf2.data.capacity(), buf1_cap); + assert!(buf2.data.is_empty()); + } + + #[test] + fn pool_limit() { + let pool = BufferPool::new(Some(2)); + + let b1 = pool.alloc().unwrap(); + let b2 = pool.alloc().unwrap(); + assert!(pool.alloc().is_none()); + + pool.release(b1); + pool.release(b2); + + assert!(pool.alloc().is_some()); + assert!(pool.alloc().is_some()); + assert!(pool.alloc().is_none()); + } + + #[test] + fn concurrent_access() { + let pool = BufferPool::new(Some(10)); + let pool = Arc::new(pool); + + let handles: Vec<_> = (0..10).map(|_| { + let pool = Arc::clone(&pool); + thread::spawn(move || { + let mut buf = pool.alloc().unwrap(); + buf.write(&[1; 100]).unwrap(); + pool.release(buf); + }) + }).collect(); + + handles.into_iter().for_each(|h| h.join().unwrap()); + } + + #[test] + fn zero_alloc_after_warmup() { + let pool = BufferPool::new(Some(10)); + const REQ_CAP: usize = 1024; + + // warmup + for _ in 0..10 { + let mut buf = pool.alloc().unwrap(); + buf.write(&[0; REQ_CAP]).unwrap(); + pool.release(buf); + } + + // check if we can allocate without exceeding the limit + for _ in 0..10 { + let buf = pool.alloc().unwrap(); + assert_eq!(buf.data.capacity(), REQ_CAP); + pool.release(buf); + } + } +} \ No newline at end of file diff --git a/crates/client/src/runtime/mod.rs b/crates/client/src/runtime/mod.rs index ab359dd..af82447 100644 --- a/crates/client/src/runtime/mod.rs +++ b/crates/client/src/runtime/mod.rs @@ -3,6 +3,7 @@ mod worker; mod transport; pub mod state; mod handshake; +mod buffer; use std::net::SocketAddr; use std::ops::Deref; From 6ef3e1deff675e28278f41b8ee8c755cd316a5e8 Mon Sep 17 00:00:00 2001 From: JKearnsl Date: Sat, 5 Jul 2025 21:07:27 +0300 Subject: [PATCH 2/3] improve buffer allocation performance and reuse testing --- crates/client/src/runtime/buffer.rs | 42 ++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/crates/client/src/runtime/buffer.rs b/crates/client/src/runtime/buffer.rs index 9cc34e9..7c6dd85 100644 --- a/crates/client/src/runtime/buffer.rs +++ b/crates/client/src/runtime/buffer.rs @@ -122,18 +122,36 @@ mod tests { #[test] fn buffer_reuse() { - let pool = BufferPool::new(Some(1)); - - // first alloc - let mut buf1 = pool.alloc().unwrap(); - buf1.write(&[1; 512]).unwrap(); - let buf1_cap = buf1.data.capacity(); - pool.release(buf1); - - // reuse - let buf2 = pool.alloc().unwrap(); - assert_eq!(buf2.data.capacity(), buf1_cap); - assert!(buf2.data.is_empty()); + let pool = BufferPool::new(Some(214748)); + let mut results = vec![(0u128, 0u128); 214748]; + let mut buffers= Vec::new(); + for i in 0..214748 { + let mut buf = pool.alloc().unwrap(); + buf.write(&[i as u8; 9000]).unwrap(); + buffers.push(buf); + } + for i in buffers.into_iter() { + pool.release(i); + } + + let mut start = std::time::Instant::now(); + for i in 0..214748 { + start = std::time::Instant::now(); + let mut buf = pool.alloc().unwrap(); + results[i] = (results[i].0, start.elapsed().as_nanos()); + pool.release(buf); + } + + for i in 0..214748 { + start = std::time::Instant::now(); + let buf = Vec::::with_capacity(9000); + results[i] = (start.elapsed().as_nanos(), results[i].1); + } + for (alloc_time, reuse_time) in results { + if alloc_time > 1000 || reuse_time > 1000 { + panic!("Allocation or reuse took too long: alloc={}ns, reuse={}ns", alloc_time, reuse_time); + } + } } #[test] From 72309da05fe04bfc541394ad7a57399c98f12fe0 Mon Sep 17 00:00:00 2001 From: JKearnsl Date: Sun, 7 Dec 2025 15:46:55 +0300 Subject: [PATCH 3/3] upd --- Cargo.toml | 2 +- crates/client/src/runtime/transport.rs | 22 -- crates/client/src/runtime/transport/udp.rs | 62 ---- crates/client/src/runtime/transport/ws.rs | 103 ------- crates/client/src/runtime/worker/mod.rs | 124 -------- .../src/runtime => holynet-sdk/src}/error.rs | 21 +- crates/holynet-sdk/src/gateway/mod.rs | 2 + crates/holynet-sdk/src/gateway/network.rs | 18 ++ .../src/gateway}/transport.rs | 5 + .../holynet-sdk/src/gateway/transport/mock.rs | 227 ++++++++++++++ .../src/gateway}/transport/udp.rs | 56 +++- .../src/gateway}/transport/ws.rs | 81 ++++- .../mod.rs => holynet-sdk/src/protocol.rs} | 13 +- .../src/protocol/data.rs | 5 +- .../src/protocol/handshake.rs | 2 +- .../src/protocol/primitives.rs} | 0 .../src/protocol}/session.rs | 8 + crates/holynet-sdk/src/runtime/client.rs | 287 ++++++++++++++++++ .../src/runtime/client}/connector.rs | 28 +- .../src/runtime/client}/data.rs | 2 + .../src/runtime/client/network.rs} | 53 ++-- .../src/runtime/client}/transport.rs | 74 +++-- crates/holynet-sdk/src/runtime/cred.rs | 9 + crates/holynet-sdk/src/runtime/mod.rs | 3 + .../src/runtime/state.rs | 7 +- crates/server/src/runtime/error.rs | 20 -- crates/shared/src/keys/handshake.rs | 42 --- crates/shared/src/keys/mod.rs | 105 ------- crates/shared/src/lib.rs | 4 - 29 files changed, 807 insertions(+), 578 deletions(-) delete mode 100644 crates/client/src/runtime/transport.rs delete mode 100644 crates/client/src/runtime/transport/udp.rs delete mode 100644 crates/client/src/runtime/transport/ws.rs delete mode 100644 crates/client/src/runtime/worker/mod.rs rename crates/{client/src/runtime => holynet-sdk/src}/error.rs (53%) create mode 100644 crates/holynet-sdk/src/gateway/mod.rs create mode 100644 crates/holynet-sdk/src/gateway/network.rs rename crates/{server/src/runtime => holynet-sdk/src/gateway}/transport.rs (73%) create mode 100644 crates/holynet-sdk/src/gateway/transport/mock.rs rename crates/{server/src/runtime => holynet-sdk/src/gateway}/transport/udp.rs (52%) rename crates/{server/src/runtime => holynet-sdk/src/gateway}/transport/ws.rs (66%) rename crates/{shared/src/protocol/mod.rs => holynet-sdk/src/protocol.rs} (78%) rename crates/{shared => holynet-sdk}/src/protocol/data.rs (65%) rename crates/{shared => holynet-sdk}/src/protocol/handshake.rs (96%) rename crates/{shared/src/types.rs => holynet-sdk/src/protocol/primitives.rs} (100%) rename crates/{shared/src => holynet-sdk/src/protocol}/session.rs (60%) create mode 100644 crates/holynet-sdk/src/runtime/client.rs rename crates/{client/src/runtime/worker => holynet-sdk/src/runtime/client}/connector.rs (81%) rename crates/{client/src/runtime/worker => holynet-sdk/src/runtime/client}/data.rs (98%) rename crates/{client/src/runtime/worker/tun.rs => holynet-sdk/src/runtime/client/network.rs} (68%) rename crates/{client/src/runtime/worker => holynet-sdk/src/runtime/client}/transport.rs (65%) create mode 100644 crates/holynet-sdk/src/runtime/cred.rs create mode 100644 crates/holynet-sdk/src/runtime/mod.rs rename crates/{client => holynet-sdk}/src/runtime/state.rs (61%) delete mode 100644 crates/server/src/runtime/error.rs delete mode 100644 crates/shared/src/keys/handshake.rs delete mode 100644 crates/shared/src/keys/mod.rs diff --git a/Cargo.toml b/Cargo.toml index b65718a..67312bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] resolver = "2" members = [ - "crates/client", + "crates/client", "crates/holynet-sdk", "crates/server", "crates/shared" ] diff --git a/crates/client/src/runtime/transport.rs b/crates/client/src/runtime/transport.rs deleted file mode 100644 index b17febf..0000000 --- a/crates/client/src/runtime/transport.rs +++ /dev/null @@ -1,22 +0,0 @@ -#[cfg(feature = "udp")] -pub mod udp; -#[cfg(feature = "ws")] -pub mod ws; - -use async_trait::async_trait; -use std::io; - -#[async_trait] -pub trait TransportSender: Send + Sync { - async fn send(&self, data: &[u8]) -> io::Result; -} - -#[async_trait] -pub trait TransportReceiver: Send + Sync { - async fn recv(&self, buffer: &mut [u8]) -> io::Result; -} - -#[async_trait] -pub trait Transport: TransportSender + TransportReceiver + Send + Sync { - async fn connect(&self) -> io::Result<()>; -} diff --git a/crates/client/src/runtime/transport/udp.rs b/crates/client/src/runtime/transport/udp.rs deleted file mode 100644 index 76f7392..0000000 --- a/crates/client/src/runtime/transport/udp.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::runtime::error::RuntimeError; -use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; -use async_trait::async_trait; -use socket2::{Domain, Protocol, Socket, Type}; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::time::Duration; -use tokio::net::UdpSocket; -use tracing::info; - - -pub struct UdpTransport { - socket: UdpSocket -} - -impl UdpTransport { - pub fn new( - addr: SocketAddr, - so_rcvbuf: usize, - so_sndbuf: usize, - ) -> Result { - let socket = Socket::new( - Domain::for_address(addr), - Type::DGRAM, - Some(Protocol::UDP) - )?; - socket.set_nonblocking(true)?; - socket.set_recv_buffer_size(so_rcvbuf)?; - socket.set_send_buffer_size(so_sndbuf)?; - socket.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), 0).into())?; - socket.connect(&addr.into())?; - - Ok(Self { socket: UdpSocket::from_std(socket.into())? }) - } -} - -#[async_trait] -impl TransportReceiver for UdpTransport { - - #[inline(always)] - async fn recv(&self, buffer: &mut [u8]) -> std::io::Result { - self.socket.recv(buffer).await - } -} - -#[async_trait] -impl TransportSender for UdpTransport { - #[inline(always)] - async fn send(&self, data: &[u8]) -> std::io::Result { - self.socket.send(data).await - } -} - -#[async_trait] -impl Transport for UdpTransport { - async fn connect(&self) -> std::io::Result<()> { - info!("connecting to udp://{}", self.socket.peer_addr()?); - tokio::select! { - _ = self.socket.connect(self.socket.peer_addr()?) => Ok(()), - _ = tokio::time::sleep(Duration::from_secs(5)) => Err(std::io::Error::other("connection timeout")) - } - } -} diff --git a/crates/client/src/runtime/transport/ws.rs b/crates/client/src/runtime/transport/ws.rs deleted file mode 100644 index 02dda56..0000000 --- a/crates/client/src/runtime/transport/ws.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::io; -use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; -use async_trait::async_trait; -use futures::stream::{SplitSink, SplitStream}; -use futures::{SinkExt, StreamExt}; -use std::net::SocketAddr; -use std::sync::Arc; -use anyhow::anyhow; -use tokio::net::TcpStream; -use tokio::sync::Mutex; -use tokio_tungstenite::tungstenite::client::IntoClientRequest; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; -use tracing::info; - - -pub struct WsTransport { - addr: SocketAddr, - write: Arc>, Message>>>>, - read: Arc>>>>>, -} - -impl WsTransport { - pub fn new(addr: SocketAddr) -> Self { - Self {addr, write: Arc::new(Mutex::new(None)) , read: Arc::new(Mutex::new(None)) } - } -} - -#[async_trait] -impl TransportReceiver for WsTransport { - - #[inline(always)] - async fn recv(&self, buffer: &mut [u8]) -> io::Result { - match self.read.lock().await.as_mut() { - Some(read) => { - while let Some(Ok(msg)) = read.next().await { - if let Message::Binary(data) = msg { - let len = data.len().min(buffer.len()); - buffer[..len].copy_from_slice(&data[..len]); - return Ok(len); - } - } - - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "WebSocket connection closed" - )) - }, - None => Err(io::Error::new( - io::ErrorKind::NotConnected, - "WebSocket connection not established" - )) - } - } -} - -#[async_trait] -impl TransportSender for WsTransport { - - #[inline(always)] - async fn send(&self, data: &[u8]) -> io::Result { - match self.write.lock().await.as_mut() { - Some(write) => write - .send(Message::Binary(data.to_vec().into())) - .await - .map(|_| data.len()) - .map_err(|e| io::Error::new( - io::ErrorKind::BrokenPipe, - e.to_string() - )), - None => Err(io::Error::new( - io::ErrorKind::NotConnected, - "WebSocket connection not established" - )) - } - } -} - -#[async_trait] -impl Transport for WsTransport{ - async fn connect(&self) -> io::Result<()> { - info!("connecting to ws://{}", self.addr); - let request = format!("ws://{}", self.addr).into_client_request().map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - anyhow!("failed to create WebSocket request: {}", e) - ) - })?; - - let (ws_stream, _) = connect_async(request) - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - - let (write, read) = ws_stream.split(); - - let mut write_lock = self.write.lock().await; - *write_lock = Some(write); - let mut read_lock = self.read.lock().await; - *read_lock = Some(read); - - Ok(()) - } -} diff --git a/crates/client/src/runtime/worker/mod.rs b/crates/client/src/runtime/worker/mod.rs deleted file mode 100644 index 689e1a1..0000000 --- a/crates/client/src/runtime/worker/mod.rs +++ /dev/null @@ -1,124 +0,0 @@ -mod connector; -mod data; -mod tun; -mod transport; - -#[cfg(feature = "udp")] -pub use crate::runtime::transport::udp::UdpTransport; - -#[cfg(feature = "ws")] -pub use crate::runtime::transport::ws::WsTransport; - -use crate::{ - runtime::{ - error::RuntimeError, - worker::{ - data::{data_tun_executor, data_udp_executor, keepalive_sender}, - transport::{transport_listener, transport_sender}, - tun::{tun_listener, tun_sender}, - } - }, -}; -use shared::connection_config::{CredentialsConfig, RuntimeConfig}; -use shared::protocol::{EncryptedData, Packet}; -use shared::session::Alg; -use std::time::Duration; -use std::{ - net::SocketAddr, - sync::Arc -}; -use tokio::sync::{mpsc, watch}; -use tracing::{debug, warn}; -use tun_rs::AsyncDevice; -use crate::runtime::state::RuntimeState; -use crate::runtime::transport::Transport; - -pub(crate) async fn create( - addr: SocketAddr, - tun: Arc, - state_tx: watch::Sender, - cred: CredentialsConfig, - alg: Alg, - runtime_config: RuntimeConfig, -) -> Result<(), RuntimeError> { - - let transport: Arc = match () { - #[cfg(feature = "udp")] - _ if cfg!(feature = "udp") => { - Arc::new(UdpTransport::new( - addr, - runtime_config.so_rcvbuf, - runtime_config.so_sndbuf, - )?) - } - #[cfg(feature = "ws")] - _ if cfg!(feature = "ws") => { - Arc::new(WsTransport::new(addr)) - } - _ => unreachable!("transport is not enabled, please enable transport features") - }; - - let (udp_sender_tx, udp_sender_rx) = mpsc::channel::(runtime_config.out_udp_buf); - let (tun_sender_tx, tun_sender_rx) = mpsc::channel::>(runtime_config.out_tun_buf); - let (data_udp_tx, data_udp_rx) = mpsc::channel::(runtime_config.data_udp_buf); - let (data_tun_tx, data_tun_rx) = mpsc::channel::>(runtime_config.data_tun_buf); - - // Handle incoming UDP packets - tokio::spawn(transport_listener(state_tx.clone(), transport.clone(), data_udp_tx)); - - // Handle outgoing UDP packets - tokio::spawn(transport_sender(state_tx.clone(), transport.clone(), udp_sender_rx)); - - - // Executors - tokio::spawn(data_tun_executor( - state_tx.clone(), - data_tun_rx, - udp_sender_tx.clone(), - )); - - tokio::spawn(data_udp_executor( - state_tx.clone(), - data_udp_rx, - tun_sender_tx - )); - - - // Handle incoming TUN packets - tokio::spawn(tun_listener( - state_tx.clone(), - tun.clone(), - data_tun_tx - )); - - // Handle outgoing TUN packets - tokio::spawn(tun_sender( - state_tx.clone(), - tun.clone(), - tun_sender_rx - )); - - - match runtime_config.keepalive { - Some(duration) => { - debug!("starting keepalive with interval {:?}", duration); - tokio::spawn(keepalive_sender( - state_tx.clone(), - udp_sender_tx, - Duration::from_secs(duration), - )); - }, - None => warn!("keepalive is disabled") - } - - // handshake_executor - connector::executor( - transport.clone(), - state_tx.clone(), - cred, - alg, - Duration::from_millis(runtime_config.handshake_timeout) - ).await; - - Ok(()) -} diff --git a/crates/client/src/runtime/error.rs b/crates/holynet-sdk/src/error.rs similarity index 53% rename from crates/client/src/runtime/error.rs rename to crates/holynet-sdk/src/error.rs index 80b0977..a814b5c 100644 --- a/crates/client/src/runtime/error.rs +++ b/crates/holynet-sdk/src/error.rs @@ -1,15 +1,14 @@ -use thiserror::Error; -#[derive(Error, Debug, Clone)] +pub enum BuildError { + MissingRequiredField(&'static str), +} + +#[derive(Debug, Clone)] pub enum RuntimeError { - #[error("IO: {0}")] IO(String), - #[error("Handshake: {0}")] Handshake(String), - #[error("Unexpected: {0}")] Unexpected(String), - #[error("StopSignal")] StopSignal } @@ -21,18 +20,12 @@ impl From for RuntimeError { impl From for RuntimeError { fn from(err: snow::Error) -> Self { - RuntimeError::Handshake(format!("snow error: {}", err)) - } -} - -impl From for RuntimeError { - fn from(err: anyhow::Error) -> Self { - RuntimeError::Unexpected(err.to_string()) + RuntimeError::Handshake(format!("snow error: {err}")) } } impl From> for RuntimeError { fn from(err: tokio::sync::broadcast::error::SendError) -> Self { - RuntimeError::IO(format!("broadcast send: {}", err)) + RuntimeError::IO(format!("broadcast send: {err}")) } } diff --git a/crates/holynet-sdk/src/gateway/mod.rs b/crates/holynet-sdk/src/gateway/mod.rs new file mode 100644 index 0000000..03fe74f --- /dev/null +++ b/crates/holynet-sdk/src/gateway/mod.rs @@ -0,0 +1,2 @@ +pub mod network; +pub mod transport; diff --git a/crates/holynet-sdk/src/gateway/network.rs b/crates/holynet-sdk/src/gateway/network.rs new file mode 100644 index 0000000..8cd4289 --- /dev/null +++ b/crates/holynet-sdk/src/gateway/network.rs @@ -0,0 +1,18 @@ +use std::io; +use std::net::SocketAddr; +use async_trait::async_trait; + +#[async_trait] +pub trait NetworkSender: Send + Sync { + async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> io::Result; + async fn send(&self, data: &[u8]) -> io::Result; +} + +#[async_trait] +pub trait NetworkReceiver: Send + Sync { + async fn recv_from(&self, buffer: &mut [u8]) -> io::Result<(usize, SocketAddr)>; + async fn recv(&self, buffer: &mut [u8]) -> io::Result; +} + +#[async_trait] +pub trait Network: NetworkSender + NetworkReceiver{} diff --git a/crates/server/src/runtime/transport.rs b/crates/holynet-sdk/src/gateway/transport.rs similarity index 73% rename from crates/server/src/runtime/transport.rs rename to crates/holynet-sdk/src/gateway/transport.rs index f32aead..223ec88 100644 --- a/crates/server/src/runtime/transport.rs +++ b/crates/holynet-sdk/src/gateway/transport.rs @@ -3,6 +3,7 @@ pub mod udp; #[cfg(feature = "ws")] pub mod ws; +mod mock; use std::any::Any; use std::io; @@ -12,13 +13,17 @@ use async_trait::async_trait; #[async_trait] pub trait TransportSender: Send + Sync { async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> io::Result; + async fn send(&self, data: &[u8]) -> io::Result; } #[async_trait] pub trait TransportReceiver: Send + Sync { async fn recv_from(&self, buffer: &mut [u8]) -> io::Result<(usize, SocketAddr)>; + async fn recv(&self, buffer: &mut [u8]) -> io::Result; } +#[async_trait] pub trait Transport: TransportSender + TransportReceiver{ fn as_any(&self) -> &dyn Any; + async fn connect(&self) -> io::Result<()>; } diff --git a/crates/holynet-sdk/src/gateway/transport/mock.rs b/crates/holynet-sdk/src/gateway/transport/mock.rs new file mode 100644 index 0000000..1ea5322 --- /dev/null +++ b/crates/holynet-sdk/src/gateway/transport/mock.rs @@ -0,0 +1,227 @@ +use crate::error::RuntimeError; +use crate::gateway::transport::{Transport, TransportReceiver, TransportSender}; +use async_trait::async_trait; +use std::any::Any; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::mpsc; +use tracing::info; + +struct MockTransportInner { + tx: mpsc::Sender>, + rx: mpsc::Receiver>, + peer_addr: SocketAddr, +} + +pub struct MockTransport { + inner: Arc>, + local_addr: SocketAddr, +} + +impl MockTransport { + pub fn new() -> Self { + Self::with_capacity(100) + } + + pub fn with_capacity(buffer_size: usize) -> Self { + let (tx, rx) = mpsc::channel(buffer_size); + let local_addr = "127.0.0.1:0".parse().unwrap(); + let peer_addr = "127.0.0.1:0".parse().unwrap(); + + MockTransport { + inner: Arc::new(tokio::sync::Mutex::new(MockTransportInner { + tx, + rx, + peer_addr, + })), + local_addr, + } + } + + pub fn create_pair() -> (Self, Self) { + let (tx1, rx1) = mpsc::channel(100); + let (tx2, rx2) = mpsc::channel(100); + + let addr1 = "127.0.0.1:10001".parse().unwrap(); + let addr2 = "127.0.0.1:10002".parse().unwrap(); + + let transport1 = MockTransport { + inner: Arc::new(tokio::sync::Mutex::new(MockTransportInner { + tx: tx1, + rx: rx2, + peer_addr: addr2, + })), + local_addr: addr1, + }; + + let transport2 = MockTransport { + inner: Arc::new(tokio::sync::Mutex::new(MockTransportInner { + tx: tx2, + rx: rx1, + peer_addr: addr1, + })), + local_addr: addr2, + }; + + (transport1, transport2) + } + + pub fn set_peer(&self, peer: &MockTransport) { + let peer_inner = Arc::clone(&peer.inner); + let mut inner_guard = futures::executor::block_on(self.inner.lock()); + + let peer_guard = futures::executor::block_on(peer_inner.lock()); + inner_guard.peer_addr = peer.local_addr; + + // Note: В реальности нужно аккуратно обменяться каналами + // Для простоты используем создание новой пары + } + + pub fn create_sender(&self) -> MockTransportSender { + let inner = Arc::clone(&self.inner); + MockTransportSender { inner } + } + + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + + pub fn peer_addr(&self) -> SocketAddr { + futures::executor::block_on(async { + let inner = self.inner.lock().await; + inner.peer_addr + }) + } +} + +pub struct MockTransportSender { + inner: Arc>, +} + +impl MockTransportSender { + pub async fn send(&self, data: Vec) -> Result<(), RuntimeError> { + let inner = self.inner.lock().await; + inner.tx.send(data).await.map_err(|e| { + RuntimeError::IO(format!("Failed to send data: {}", e)) + })?; + Ok(()) + } +} + +#[async_trait] +impl TransportSender for MockTransport { + async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> std::io::Result { + let inner = self.inner.lock().await; + inner.tx.send(data.to_vec()).await.map_err(|e| { + std::io::Error::new(std::io::ErrorKind::Other, format!("Send error: {}", e)) + })?; + Ok(data.len()) + } + + async fn send(&self, data: &[u8]) -> std::io::Result { + let inner = self.inner.lock().await; + inner.tx.send(data.to_vec()).await.map_err(|e| { + std::io::Error::new(std::io::ErrorKind::Other, format!("Send error: {}", e)) + })?; + Ok(data.len()) + } +} + +#[async_trait] +impl TransportReceiver for MockTransport { + async fn recv_from(&self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + let mut inner = self.inner.lock().await; + + match inner.rx.recv().await { + Some(data) => { + let len = data.len().min(buffer.len()); + buffer[..len].copy_from_slice(&data[..len]); + Ok((len, inner.peer_addr)) + } + None => Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "Channel closed" + )), + } + } + + async fn recv(&self, buffer: &mut [u8]) -> std::io::Result { + let mut inner = self.inner.lock().await; + + match inner.rx.recv().await { + Some(data) => { + let len = data.len().min(buffer.len()); + buffer[..len].copy_from_slice(&data[..len]); + Ok(len) + } + None => Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "Channel closed" + )), + } + } +} + +#[async_trait] +impl Transport for MockTransport { + async fn connect(&self) -> std::io::Result<()> { + info!("MockTransport::connect called - ready for communication"); + Ok(()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Clone for MockTransport { + fn clone(&self) -> Self { + MockTransport { + inner: Arc::clone(&self.inner), + local_addr: self.local_addr, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_transport_pair() { + let (mut transport1, mut transport2) = MockTransport::create_pair(); + + // Тест отправки от transport1 к transport2 + let test_data = b"Hello from transport1"; + transport1.send_to(test_data, &transport2.local_addr()).await.unwrap(); + + let mut buffer = [0u8; 1024]; + let (size, addr) = transport2.recv_from(&mut buffer).await.unwrap(); + + assert_eq!(&buffer[..size], test_data); + assert_eq!(addr, transport1.local_addr()); + + // Тест отправки от transport2 к transport1 + let test_data2 = b"Hello from transport2"; + transport2.send_to(test_data2, &transport1.local_addr()).await.unwrap(); + + let (size2, addr2) = transport1.recv_from(&mut buffer).await.unwrap(); + + assert_eq!(&buffer[..size2], test_data2); + assert_eq!(addr2, transport2.local_addr()); + } + + #[tokio::test] + async fn test_mock_transport_sender() { + let transport = MockTransport::new(); + let sender = transport.create_sender(); + + let test_data = b"Test message"; + sender.send(test_data.to_vec()).await.unwrap(); + + let mut buffer = [0u8; 1024]; + let size = transport.recv(&mut buffer).await.unwrap(); + + assert_eq!(&buffer[..size], test_data); + } +} \ No newline at end of file diff --git a/crates/server/src/runtime/transport/udp.rs b/crates/holynet-sdk/src/gateway/transport/udp.rs similarity index 52% rename from crates/server/src/runtime/transport/udp.rs rename to crates/holynet-sdk/src/gateway/transport/udp.rs index 55a3531..0a68076 100644 --- a/crates/server/src/runtime/transport/udp.rs +++ b/crates/holynet-sdk/src/gateway/transport/udp.rs @@ -1,9 +1,10 @@ -use std::any::Any; -use crate::runtime::error::RuntimeError; -use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; +use crate::error::RuntimeError; +use crate::gateway::transport::{Transport, TransportReceiver, TransportSender}; use async_trait::async_trait; use socket2::{Domain, Protocol, Socket, Type}; -use std::net::SocketAddr; +use std::any::Any; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::time::Duration; use tokio::net::UdpSocket; use tracing::info; @@ -12,6 +13,12 @@ pub struct UdpTransport { } impl UdpTransport { + + /// Create a new UDP transport with a pool of sockets (requires `udp-reuse-port` feature) + /// Workers will share the same port using `SO_REUSEPORT` option + /// + /// Only available on Linux and some BSD systems + #[cfg(feature = "udp-reuse-port")] pub fn new_pool( addr: SocketAddr, so_rcvbuf: usize, @@ -51,6 +58,28 @@ impl UdpTransport { Ok(sockets) } + + /// Create a new UDP transport with single socket + /// + /// Available on all platforms + pub fn new( + addr: SocketAddr, + so_rcvbuf: usize, + so_sndbuf: usize, + ) -> Result { + let socket = Socket::new( + Domain::for_address(addr), + Type::DGRAM, + Some(Protocol::UDP) + )?; + socket.set_nonblocking(true)?; + socket.set_recv_buffer_size(so_rcvbuf)?; + socket.set_send_buffer_size(so_sndbuf)?; + socket.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), 0).into())?; + socket.connect(&addr.into())?; + + Ok(Self { socket: UdpSocket::from_std(socket.into())? }) + } } #[async_trait] @@ -60,6 +89,11 @@ impl TransportReceiver for UdpTransport { async fn recv_from(&self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { self.socket.recv_from(buffer).await } + + #[inline(always)] + async fn recv(&self, buffer: &mut [u8]) -> std::io::Result { + self.socket.recv(buffer).await + } } #[async_trait] @@ -68,10 +102,24 @@ impl TransportSender for UdpTransport { async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> std::io::Result { self.socket.send_to(data, addr).await } + + #[inline(always)] + async fn send(&self, data: &[u8]) -> std::io::Result { + self.socket.send(data).await + } } +#[async_trait] impl Transport for UdpTransport{ fn as_any(&self) -> &dyn Any { self } + + async fn connect(&self) -> std::io::Result<()> { + info!("connecting to udp://{}", self.socket.peer_addr()?); + tokio::select! { + _ = self.socket.connect(self.socket.peer_addr()?) => Ok(()), + _ = tokio::time::sleep(Duration::from_secs(5)) => Err(std::io::Error::other("connection timeout")) + } + } } diff --git a/crates/server/src/runtime/transport/ws.rs b/crates/holynet-sdk/src/gateway/transport/ws.rs similarity index 66% rename from crates/server/src/runtime/transport/ws.rs rename to crates/holynet-sdk/src/gateway/transport/ws.rs index 3bfe11b..3251d0b 100644 --- a/crates/server/src/runtime/transport/ws.rs +++ b/crates/holynet-sdk/src/gateway/transport/ws.rs @@ -1,6 +1,7 @@ use std::any::Any; -use crate::runtime::error::RuntimeError; -use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; +use std::io; +use crate::error::RuntimeError; +use crate::transport::{Transport, TransportReceiver, TransportSender}; use async_trait::async_trait; use dashmap::DashMap; use futures::stream::SplitSink; @@ -11,7 +12,7 @@ use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, Mutex}; use tokio_tungstenite::tungstenite::{Bytes, Message}; -use tokio_tungstenite::{accept_async, WebSocketStream}; +use tokio_tungstenite::{accept_async, connect_async, WebSocketStream}; use tracing::{debug, info}; pub struct WsTransport { @@ -22,6 +23,8 @@ pub struct WsTransport { } impl WsTransport { + + #[cfg(feature = "ws-reuse-port")] pub fn new_pool( addr: SocketAddr, so_rcvbuf: usize, @@ -77,6 +80,10 @@ impl WsTransport { Ok(listeners) } + pub fn new(addr: SocketAddr) -> Self { + Self {addr, write: Arc::new(Mutex::new(None)) , read: Arc::new(Mutex::new(None)) } + } + pub async fn start(&self) -> Result<(), RuntimeError> { loop { let (tcp_stream, addr) = self.listener.accept().await?; @@ -120,6 +127,30 @@ impl TransportReceiver for WsTransport { None => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel closed")), } } + + #[inline(always)] + async fn recv(&self, buffer: &mut [u8]) -> io::Result { + match self.read.lock().await.as_mut() { + Some(read) => { + while let Some(Ok(msg)) = read.next().await { + if let Message::Binary(data) = msg { + let len = data.len().min(buffer.len()); + buffer[..len].copy_from_slice(&data[..len]); + return Ok(len); + } + } + + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "WebSocket connection closed" + )) + }, + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "WebSocket connection not established" + )) + } + } } #[async_trait] @@ -134,12 +165,54 @@ impl TransportSender for WsTransport { Err(std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, "Address not found")) } } + + #[inline(always)] + async fn send(&self, data: &[u8]) -> io::Result { + match self.write.lock().await.as_mut() { + Some(write) => write + .send(Message::Binary(data.to_vec().into())) + .await + .map(|_| data.len()) + .map_err(|e| io::Error::new( + io::ErrorKind::BrokenPipe, + e.to_string() + )), + None => Err(io::Error::new( + io::ErrorKind::NotConnected, + "WebSocket connection not established" + )) + } + } } +#[async_trait] impl Transport for WsTransport { fn as_any(&self) -> &dyn Any { self } + + async fn connect(&self) -> io::Result<()> { + info!("connecting to ws://{}", self.addr); + let request = format!("ws://{}", self.addr).into_client_request().map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + anyhow!("failed to create WebSocket request: {}", e) + ) + })?; + + let (ws_stream, _) = connect_async(request) + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + let (write, read) = ws_stream.split(); + + let mut write_lock = self.write.lock().await; + *write_lock = Some(write); + let mut read_lock = self.read.lock().await; + *read_lock = Some(read); + + Ok(()) + } } impl dyn Transport { @@ -151,4 +224,4 @@ impl dyn Transport { Err(self) } } -} \ No newline at end of file +} diff --git a/crates/shared/src/protocol/mod.rs b/crates/holynet-sdk/src/protocol.rs similarity index 78% rename from crates/shared/src/protocol/mod.rs rename to crates/holynet-sdk/src/protocol.rs index f53e515..79b7ba8 100644 --- a/crates/shared/src/protocol/mod.rs +++ b/crates/holynet-sdk/src/protocol.rs @@ -1,5 +1,7 @@ mod handshake; mod data; +mod primitives; +mod session; use bincode::{Decode, Encode}; pub use data::{ @@ -11,8 +13,8 @@ pub use handshake::{ HandshakeResponderPayload, HandshakeError }; -use crate::session::SessionId; -use crate::types::VecU16; +use primitives::VecU16; +pub use session::{SessionId, Alg}; pub type EncryptedHandshake = VecU16; pub type EncryptedData = VecU16; @@ -30,20 +32,21 @@ pub enum Packet { impl TryFrom<&[u8]> for Packet { - type Error = anyhow::Error; + type Error = &'static str; fn try_from(data: &[u8]) -> Result { match bincode::decode_from_slice( data, - bincode::config::standard() + bincode::config::standard() // todo: can we reuse config? ) { Ok((obj, _)) => Ok(obj), - Err(err) => Err(anyhow::anyhow!(err)) + Err(_) => Err("error decoding packet") } } } impl Packet { + /// todo: handle error pub fn to_bytes(&self) -> Vec { bincode::encode_to_vec( self, diff --git a/crates/shared/src/protocol/data.rs b/crates/holynet-sdk/src/protocol/data.rs similarity index 65% rename from crates/shared/src/protocol/data.rs rename to crates/holynet-sdk/src/protocol/data.rs index 3025321..65c271c 100644 --- a/crates/shared/src/protocol/data.rs +++ b/crates/holynet-sdk/src/protocol/data.rs @@ -1,15 +1,18 @@ use serde::{Deserialize, Serialize}; -use crate::types::VecU16; +use super::primitives::VecU16; #[derive(Serialize, Deserialize)] pub enum DataServerBody { Packet(VecU16), + /// Contains the client's timestamp KeepAlive(u128), + /// Contains the shutdown initiation code Disconnect(u8) } #[derive(Serialize, Deserialize)] pub enum DataClientBody { Packet(VecU16), + /// Contains timestamp KeepAlive(u128) } diff --git a/crates/shared/src/protocol/handshake.rs b/crates/holynet-sdk/src/protocol/handshake.rs similarity index 96% rename from crates/shared/src/protocol/handshake.rs rename to crates/holynet-sdk/src/protocol/handshake.rs index e79d7cd..2103ab9 100644 --- a/crates/shared/src/protocol/handshake.rs +++ b/crates/holynet-sdk/src/protocol/handshake.rs @@ -1,6 +1,6 @@ -use crate::session::SessionId; use serde::{Deserialize, Serialize}; use std::net::IpAddr; +use super::session::SessionId; #[derive(Serialize, Deserialize)] pub enum HandshakeResponderBody { diff --git a/crates/shared/src/types.rs b/crates/holynet-sdk/src/protocol/primitives.rs similarity index 100% rename from crates/shared/src/types.rs rename to crates/holynet-sdk/src/protocol/primitives.rs diff --git a/crates/shared/src/session.rs b/crates/holynet-sdk/src/protocol/session.rs similarity index 60% rename from crates/shared/src/session.rs rename to crates/holynet-sdk/src/protocol/session.rs index dc4956f..244d14b 100644 --- a/crates/shared/src/session.rs +++ b/crates/holynet-sdk/src/protocol/session.rs @@ -7,3 +7,11 @@ pub enum Alg { Aes256, ChaCha20Poly1305 } + + +impl Default for Alg { + fn default() -> Self { + todo!("need default for system support!!!") + } + +} \ No newline at end of file diff --git a/crates/holynet-sdk/src/runtime/client.rs b/crates/holynet-sdk/src/runtime/client.rs new file mode 100644 index 0000000..d6cfc31 --- /dev/null +++ b/crates/holynet-sdk/src/runtime/client.rs @@ -0,0 +1,287 @@ +mod connector; +mod data; +mod network; +mod transport; + + +use crate::{ + gateway::{ + network::Network, + transport::Transport + }, + runtime::{ + state::RuntimeState, + error::RuntimeError, + worker::{ + data::{data_tun_executor, data_udp_executor, keepalive_sender}, + transport::{transport_listener, transport_sender}, + tun::{tun_listener, tun_sender}, + } + }, +}; +use crate::protocol::{Alg, EncryptedData, Packet}; +use std::time::Duration; +use std::{ + net::SocketAddr, + sync::Arc +}; +use tokio::sync::{mpsc, watch}; +use tokio::task::JoinSet; +use tracing::{debug, info, warn}; +use crate::error::{BuildError, RuntimeError}; +use crate::keys::handshake::{PublicKey, SecretKey}; +use crate::runtime::client::data::{data_tun_executor, data_udp_executor, keepalive_sender}; +use crate::runtime::client::transport::{transport_listener, transport_sender}; +use crate::runtime::client::network::{tun_listener, tun_sender}; +use crate::runtime::cred::Cred; + + +const AWAIT_STATE_DELAY: Duration = Duration::from_secs(1); +const MAX_PACKET_SIZE: usize = 65536; + +pub struct ClientBuilder { + transport: Option>, + network: Option>, + addr: Option, + alg: Option, + keepalive: Option, + handshake_timeout: Duration, + reconnect_delay: Duration, + cred: Option, + // Buffer sizes + out_transport_buf: usize, + out_network_buf: usize, + data_transport_buf: usize, + data_network_buf: usize, +} + +impl ClientBuilder { + pub fn new() -> Self { + Self { + transport: None, + network: None, + addr: None, + alg: None, + keepalive: Some(Duration::from_secs(15)), + handshake_timeout: Duration::from_secs(5), + reconnect_delay: Duration::from_secs(3), + cred: None, + out_transport_buf: usize::MAX, + out_network_buf: usize::MAX, + data_transport_buf: usize::MAX, + data_network_buf: usize::MAX, + } + } + + pub fn transport(mut self, value: T) -> Self { + self.transport = Some(Box::new(value)); + self + } + + pub fn network + 'static>(mut self, value: N) -> Self { + self.network = Some(Box::new(value.into())); + self + } + + /// Set server address + pub fn addr(mut self, value: SocketAddr) -> Self { + self.addr = Some(value); + self + } + + /// Set encryption algorithm + /// If not set, the default algorithm will be used + /// + /// # Arguments + /// * `value` - Encryption algorithm + /// + /// # Default + /// The default algorithm is calculated based on the processor's capabilities + /// + pub fn alg(mut self, value: Alg) -> Self { + self.alg = Some(value); + self + } + + /// Set keepalive interval + /// If not set, keepalive is disabled + /// + /// This is useful if the client is behind nat! + /// + /// # Note + /// If the interval is too short, it may cause unnecessary network traffic. + /// If the interval is too long, it may cause the connection to be dropped by nat + /// + /// # Default + /// The default value is 15 seconds + /// + /// # Arguments + /// * `value` - Keepalive interval + pub fn keepalive(mut self, value: Duration) -> Self { + self.keepalive = Some(value); + self + } + + /// Set handshake timeout + /// + /// # Arguments + /// * `value` - Handshake timeout duration + /// # Default + /// The default value is 5 seconds + pub fn handshake_timeout(mut self, value: Duration) -> Self { + self.handshake_timeout = value; + self + } + + /// Set reconnect delay + /// If not set, the default value is 3 seconds + /// + /// # Arguments + /// * `value` - Reconnect delay duration + /// # Default + /// The default value is 3 seconds + pub fn reconnect_delay(mut self, value: Duration) -> Self { + self.reconnect_delay = value; + self + } + + /// Set credentials + /// + /// # Arguments + /// * `sk` - Client's secret key + /// * `psk` - Pre-shared key + /// * `spk` - Server's public key + pub fn cred(mut self, sk: [u8; 32], psk: [u8; 32], spk: [u8; 32]) -> Self { + self.cred = Some(Cred { sk, psk, spk }); + self + } + + pub fn out_transport_buf(mut self, value: usize) -> Self { + self.out_transport_buf = value; + self + } + pub fn out_network_buf(mut self, value: usize) -> Self { + self.out_network_buf = value; + self + } + pub fn data_transport_buf(mut self, value: usize) -> Self { + self.data_transport_buf = value; + self + } + pub fn data_network_buf(mut self, value: usize) -> Self { + self.data_network_buf = value; + self + } + + pub fn build(self) -> Result { + let (state, _) = watch::channel(RuntimeState::Connecting); + + Ok(Client { + transport: Arc::from(self.transport.ok_or(BuildError::MissingRequiredField("missing transport"))?), + network: Arc::from(self.network.ok_or(BuildError::MissingRequiredField("missing network"))?), + addr: self.addr.ok_or(BuildError::MissingRequiredField)?, + alg: self.alg.unwrap_or_default(), + keepalive: self.keepalive, + handshake_timeout: self.handshake_timeout, + cred: self.cred.ok_or(BuildError::MissingRequiredField("missing credentials"))?, + out_transport_buf: self.out_transport_buf, + out_network_buf: self.out_network_buf, + data_transport_buf: self.data_transport_buf, + data_network_buf: self.data_network_buf, + state + }) + } +} + + +pub struct Client { + transport: Arc, + network: Arc, // TODO : support multiple network types + addr: SocketAddr, + alg: Alg, + keepalive: Option, + handshake_timeout: Duration, + cred: Cred, + // Buffer sizes + out_transport_buf: usize, + out_network_buf: usize, + data_transport_buf: usize, + data_network_buf: usize, + // Internal state + state: watch::Sender +} + +impl Client { + pub async fn run(&mut self) -> Result { + let (transport_sender_tx, transport_sender_rx) = mpsc::channel::(self.out_transport_buf); + let (network_sender_tx, network_sender_rx) = mpsc::channel::>(self.out_network_buf); + let (data_transport_tx, data_transport_rx) = mpsc::channel::(self.data_transport_buf); + let (data_network_tx, data_network_rx) = mpsc::channel::>(self.data_network_buf); + + + let mut set = JoinSet::new(); + + // Handle incoming transport packets + set.spawn(transport_listener(self.state.clone(), self.transport.clone(), data_transport_tx)); + // Handle outgoing transport packets + set.spawn(transport_sender(self.state.clone(), self.transport.clone(), transport_sender_rx)); + // Handle incoming net packets + set.spawn(tun_listener( + self.state.clone(), + self.network.clone(), + data_network_tx.clone() + )); + // Handle outgoing net packets + set.spawn(tun_sender( + self.state.clone(), + self.network.clone(), + network_sender_rx + )); + + // Executors + set.spawn(data_tun_executor( + self.state.clone(), + data_network_rx, + transport_sender_tx.clone(), + )); + + set.spawn(data_udp_executor( + self.state.clone(), + data_transport_rx, + network_sender_tx + )); + + + match self.keepalive { + Some(duration) => { + debug!("starting keepalive with interval {:?}", duration); + set.spawn(keepalive_sender( + self.state.clone(), + transport_sender_tx, + duration, + )); + }, + None => debug!("keepalive is disabled") + } + + // connector + set.spawn(connector::executor( + self.transport, + self.state.clone(), + self.cred, + self.alg, + self.handshake_timeout + )); + + while let Some(res) = set.join_all().await { + match res { + Ok(val) => debug!("task exited: {val:?}"), + Err(e) => debug!("task exited with error: {e}"), + } + } + Err(match self.state { + RuntimeState::Error(err) => err, + _ => RuntimeError::Unexpected("all tasks exited unexpectedly".into()) + }) + } +} diff --git a/crates/client/src/runtime/worker/connector.rs b/crates/holynet-sdk/src/runtime/client/connector.rs similarity index 81% rename from crates/client/src/runtime/worker/connector.rs rename to crates/holynet-sdk/src/runtime/client/connector.rs index 53c1c09..85e4bf7 100644 --- a/crates/client/src/runtime/worker/connector.rs +++ b/crates/holynet-sdk/src/runtime/client/connector.rs @@ -1,9 +1,14 @@ use std::sync::Arc; use std::time::Duration; +use futures::SinkExt; use tokio::sync::watch; use tracing::{debug, error}; use shared::connection_config::CredentialsConfig; use shared::session::Alg; +use crate::gateway::transport::Transport; +use crate::keys::handshake::{PublicKey, SecretKey}; +use crate::protocol::Alg; +use crate::runtime::cred::Cred; use crate::runtime::handshake::handshake_step; use crate::runtime::state::RuntimeState; use crate::runtime::transport::Transport; @@ -12,25 +17,22 @@ use super::super::{ }; -const RECONNECT_DELAY: Duration = Duration::from_secs(3); - pub(crate) async fn executor( + state: watch::Sender, transport: Arc, - state_tx: watch::Sender, - // for handshake step: - cred: CredentialsConfig, + cred: Cred, alg: Alg, + reconnect_delay: Duration, timeout: Duration -) { - let mut state_rx = state_tx.subscribe(); +) -> ! { + let mut state_rx = state.subscribe(); state_rx.mark_changed(); - let mut ticker = tokio::time::interval(RECONNECT_DELAY); + let mut ticker = tokio::time::interval(reconnect_delay); let mut is_reconnect = false; - loop { match state_rx.changed().await { Ok(_) => { - let state = state_rx.borrow().clone(); + let mut state = state_rx.borrow().clone(); match state { RuntimeState::Connecting => match transport.connect().await { Ok(_) => match handshake_step( @@ -41,14 +43,14 @@ pub(crate) async fn executor( ).await { Ok((payload, transport_state)) => { is_reconnect = true; - state_tx.send(RuntimeState::Connected((payload, Arc::new(transport_state)))) + state.send(RuntimeState::Connected((payload, Arc::new(transport_state)))) .expect("broken runtime state pipe"); continue }, // if conn is ok, but handshake no :( Err(err) => match is_reconnect { false => { - state_tx.send(RuntimeState::Error(err)) + state.send(RuntimeState::Error(err)) .expect("broken runtime state pipe"); return; }, @@ -62,7 +64,7 @@ pub(crate) async fn executor( // if connecting err Err(err) => match is_reconnect { false => { - state_tx.send(RuntimeState::Error( + state.send(RuntimeState::Error( RuntimeError::IO(format!("connecting error: {}", err)) )).expect( "broken runtime state pipe" diff --git a/crates/client/src/runtime/worker/data.rs b/crates/holynet-sdk/src/runtime/client/data.rs similarity index 98% rename from crates/client/src/runtime/worker/data.rs rename to crates/holynet-sdk/src/runtime/client/data.rs index f0bc25d..d9fdf9d 100644 --- a/crates/client/src/runtime/worker/data.rs +++ b/crates/holynet-sdk/src/runtime/client/data.rs @@ -8,6 +8,8 @@ use std::time::Duration; use tokio::sync::watch::Sender; use tokio::sync::mpsc; use tracing::{info, warn}; +use crate::error::RuntimeError; +use crate::protocol::{DataClientBody, SessionId}; use crate::runtime::state::RuntimeState; fn decrypt_body( diff --git a/crates/client/src/runtime/worker/tun.rs b/crates/holynet-sdk/src/runtime/client/network.rs similarity index 68% rename from crates/client/src/runtime/worker/tun.rs rename to crates/holynet-sdk/src/runtime/client/network.rs index 7a022eb..2e3b780 100644 --- a/crates/client/src/runtime/worker/tun.rs +++ b/crates/holynet-sdk/src/runtime/client/network.rs @@ -1,16 +1,24 @@ -use std::ops::Deref; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::watch::Sender; -use tokio::sync::mpsc; +use std::{ + ops::Deref, + sync::Arc, +}; +use tokio::sync::{ + watch::Sender, + mpsc +}; use tracing::{error, warn}; -use tun_rs::AsyncDevice; -use crate::runtime::error::RuntimeError; -use crate::runtime::state::RuntimeState; +use crate::{ + gateway::network::Network, + error::RuntimeError, + runtime::{ + state::RuntimeState, + client::{AWAIT_STATE_DELAY, MAX_PACKET_SIZE} + } +}; -pub async fn tun_sender( +pub async fn network_sender( state_tx: Sender, - tun: Arc, + network: Arc, mut queue: mpsc::Receiver> ) { let mut state_rx = state_tx.subscribe(); @@ -24,9 +32,9 @@ pub async fn tun_sender( }, result = queue.recv() => match result { Some(packet) => { - if let Err(err) = tun.send(&packet).await { + if let Err(err) = network.send(&packet).await { state_tx.send(RuntimeState::Error( - RuntimeError::IO(format!("failed to send tun: {}", err)) + RuntimeError::IO(format!("failed to send network: {}", err)) )).unwrap(); } }, @@ -36,16 +44,17 @@ pub async fn tun_sender( } } -pub async fn tun_listener( +pub async fn network_receiver( state_tx: Sender, - tun: Arc, + network: Arc, queue: mpsc::Sender> ) { - let mut state_wait_timer = tokio::time::interval(Duration::from_secs(1)); + let mut state_wait_timer = tokio::time::interval(AWAIT_STATE_DELAY); let mut state_rx = state_tx.subscribe(); let mut is_connected = false; - let mut buffer = [0u8; 65536]; + + let mut buffer = [0u8; MAX_PACKET_SIZE]; loop { if !is_connected && !state_rx.has_changed().unwrap() { state_wait_timer.tick().await; @@ -62,16 +71,17 @@ pub async fn tun_listener( RuntimeState::Connected(_) => { is_connected = true; } + _ => {} } }, - result = tun.recv(&mut buffer) => match result { + result = network.recv(&mut buffer) => match result { Ok(n) => { if n == 0 { - warn!("received tun packet with 0 bytes, dropping it"); + warn!("received network packet with 0 bytes, dropping it"); continue; } - if n > 65536 { - warn!("received tun packet larger than 65536 bytes, dropping it (check ur mtu)"); + if n > MAX_PACKET_SIZE { + warn!("received network packet larger than 65536 bytes, dropping it (check your mtu)"); continue; } if let Err(err) = queue.send(buffer[..n].to_vec()).await { @@ -80,11 +90,10 @@ pub async fn tun_listener( } Err(err) => { state_tx.send(RuntimeState::Error( - RuntimeError::IO(format!("failed to receive tun: {}",err)) + RuntimeError::IO(format!("failed to receive network: {}",err)) )).unwrap(); } } } } } - diff --git a/crates/client/src/runtime/worker/transport.rs b/crates/holynet-sdk/src/runtime/client/transport.rs similarity index 65% rename from crates/client/src/runtime/worker/transport.rs rename to crates/holynet-sdk/src/runtime/client/transport.rs index cfa11e4..1eed7ec 100644 --- a/crates/client/src/runtime/worker/transport.rs +++ b/crates/holynet-sdk/src/runtime/client/transport.rs @@ -1,44 +1,59 @@ -use std::ops::Deref; -use std::sync::Arc; -use std::time::Duration; -use tokio::sync::watch::Sender; -use tokio::sync::mpsc; +use std::{ + ops::Deref, + sync::Arc, +}; + +use tokio::{ + time::interval, + sync::{ + watch::Sender, + mpsc + } +}; use tracing::{debug, error, warn}; -use shared::protocol::{EncryptedData, Packet}; -use crate::runtime::state::RuntimeState; -use crate::runtime::transport::{TransportReceiver, TransportSender}; +use crate::{ + gateway::transport::{TransportReceiver, TransportSender}, + protocol::{Packet, EncryptedData}, + runtime::{ + state::RuntimeState, + client::{AWAIT_STATE_DELAY, MAX_PACKET_SIZE} + } +}; pub async fn transport_sender( - state_tx: Sender, + state: Sender, transport: Arc, mut queue: mpsc::Receiver ) { - let mut state_wait_timer = tokio::time::interval(Duration::from_secs(1)); + let mut state_wait_timer = interval(AWAIT_STATE_DELAY); - let mut state_rx = state_tx.subscribe(); + let mut state_rx = state.subscribe(); let mut is_connected = false; loop { + // If the application's state has changed (the connection has been lost, etc.), + // it makes sense to stop and wait for everything to recover, rather than waste + // CPU on executing unnecessary tasks. if !is_connected && !state_rx.has_changed().unwrap() { state_wait_timer.tick().await; continue; } - + tokio::select! { _ = state_rx.changed() => match state_rx.borrow().deref() { RuntimeState::Error(_) => break, + RuntimeState::Listening | RuntimeState::Connected(_) => { + is_connected = true; + } RuntimeState::Connecting => { is_connected = false; - }, - RuntimeState::Connected(_) => { - is_connected = true; } }, result = queue.recv() => match result { Some(packet) => match transport.send(&packet.to_bytes()).await { Ok(n) => debug!("sent transport packet with {} bytes", n), - Err(_) => { - state_tx.send(RuntimeState::Connecting).unwrap(); // todo log + Err(_) => { // todo provide error and resolve it in higher level + state.send(RuntimeState::Connecting).unwrap(); } }, None => break @@ -47,17 +62,20 @@ pub async fn transport_sender( } } -pub async fn transport_listener( - state_tx: Sender, +pub async fn transport_receiver( + state: Sender, transport: Arc, data_receiver: mpsc::Sender ) { - let mut state_wait_timer = tokio::time::interval(Duration::from_secs(1)); + let mut state_wait_timer = interval(AWAIT_STATE_DELAY); - let mut state_rx = state_tx.subscribe(); + let mut state_rx = state.subscribe(); let mut is_connected = false; - let mut transport_buffer = [0u8; 65536]; + let mut transport_buffer = [0u8; MAX_PACKET_SIZE]; loop { + // If the application's state has changed (the connection has been lost, etc.), + // it makes sense to stop and wait for everything to recover, rather than waste + // CPU on executing unnecessary tasks. if !is_connected && !state_rx.has_changed().unwrap() { state_wait_timer.tick().await; continue; @@ -66,11 +84,11 @@ pub async fn transport_listener( tokio::select! { _ = state_rx.changed() => match state_rx.borrow().deref() { RuntimeState::Error(_) => break, + RuntimeState::Listening | RuntimeState::Connected(_) => { + is_connected = true; + }, RuntimeState::Connecting => { is_connected = false; - }, - RuntimeState::Connected(_) => { - is_connected = true; } }, result = transport.recv(&mut transport_buffer) => match result { @@ -80,7 +98,7 @@ pub async fn transport_listener( warn!("received transport packet with 0 bytes, dropping it"); continue; } - if n > 65536 { + if n > MAX_PACKET_SIZE { warn!("received transport packet larger than 65536 bytes, dropping it"); continue; } @@ -106,8 +124,8 @@ pub async fn transport_listener( } } } - Err(_) => state_tx.send(RuntimeState::Connecting).unwrap() // todo log + Err(_) => state.send(RuntimeState::Connecting).unwrap() // todo provide error and resolve it in higher level } } } -} \ No newline at end of file +} diff --git a/crates/holynet-sdk/src/runtime/cred.rs b/crates/holynet-sdk/src/runtime/cred.rs new file mode 100644 index 0000000..09eacee --- /dev/null +++ b/crates/holynet-sdk/src/runtime/cred.rs @@ -0,0 +1,9 @@ + + + + +pub(crate) struct Cred { + pub sk: [u8; 32], + pub psk: [u8; 32], + pub spk: [u8; 32], +} \ No newline at end of file diff --git a/crates/holynet-sdk/src/runtime/mod.rs b/crates/holynet-sdk/src/runtime/mod.rs new file mode 100644 index 0000000..e78e3ba --- /dev/null +++ b/crates/holynet-sdk/src/runtime/mod.rs @@ -0,0 +1,3 @@ +mod state; +pub mod client; +mod cred; diff --git a/crates/client/src/runtime/state.rs b/crates/holynet-sdk/src/runtime/state.rs similarity index 61% rename from crates/client/src/runtime/state.rs rename to crates/holynet-sdk/src/runtime/state.rs index 3c3bae7..f4ace89 100644 --- a/crates/client/src/runtime/state.rs +++ b/crates/holynet-sdk/src/runtime/state.rs @@ -1,11 +1,12 @@ use std::sync::Arc; use snow::StatelessTransportState; -use shared::protocol::HandshakeResponderPayload; -use crate::runtime::error::RuntimeError; +use crate::protocol::HandshakeResponderPayload; +use crate::error::RuntimeError; #[derive(Debug, Clone)] pub enum RuntimeState { Connecting, Connected((HandshakeResponderPayload, Arc)), - Error(RuntimeError) + Error(RuntimeError), + Listening, } \ No newline at end of file diff --git a/crates/server/src/runtime/error.rs b/crates/server/src/runtime/error.rs deleted file mode 100644 index 8751f51..0000000 --- a/crates/server/src/runtime/error.rs +++ /dev/null @@ -1,20 +0,0 @@ -use thiserror::Error; - - -#[derive(Error, Debug, Clone)] -pub enum RuntimeError { - #[error("Tun: {0}")] - Tun(String), - #[error("IO: {0}")] - IO(String), - #[error("Unexpected: {0}")] - Unexpected(String), - #[error("StopSignal")] - StopSignal -} - -impl From for RuntimeError { - fn from(err: std::io::Error) -> Self { - RuntimeError::IO(err.to_string()) - } -} diff --git a/crates/shared/src/keys/handshake.rs b/crates/shared/src/keys/handshake.rs deleted file mode 100644 index 541b6d5..0000000 --- a/crates/shared/src/keys/handshake.rs +++ /dev/null @@ -1,42 +0,0 @@ -use rand_core::OsRng; -use super::Key; - - -pub type SecretKey = Key<32>; -pub type PublicKey = Key<32>; - -impl PublicKey { - pub fn derive_from(secret: SecretKey) -> Self { - Self::from(x25519_dalek::PublicKey::from( - &x25519_dalek::StaticSecret::from(Into::<[u8; 32]>::into(secret)) - ).to_bytes()) - } -} - -impl SecretKey { - pub fn generate_x25519() -> Self { x25519_dalek::StaticSecret::random_from_rng(OsRng).into() } -} - -impl Into for PublicKey { - fn into(self) -> x25519_dalek::PublicKey { - x25519_dalek::PublicKey::from(self.0) - } -} - -impl Into for SecretKey { - fn into(self) -> x25519_dalek::StaticSecret { - x25519_dalek::StaticSecret::from(self.0) - } -} - -impl From for PublicKey { - fn from(key: x25519_dalek::PublicKey) -> Self { - Self::from(key.to_bytes()) - } -} - -impl From for SecretKey { - fn from(key: x25519_dalek::StaticSecret) -> Self { - Self::from(key.to_bytes()) - } -} diff --git a/crates/shared/src/keys/mod.rs b/crates/shared/src/keys/mod.rs deleted file mode 100644 index 6490ef5..0000000 --- a/crates/shared/src/keys/mod.rs +++ /dev/null @@ -1,105 +0,0 @@ -use base64::engine::general_purpose::STANDARD_NO_PAD; -use base64::Engine; -use rand_core::{OsRng, RngCore}; -use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; -use std::fmt; -use std::fmt::{Display, Formatter}; -use std::ops::Deref; - -pub mod handshake; - -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct Key(pub [u8; SIZE]); - -impl Key { - pub const SIZE: usize = SIZE; - - pub fn generate() -> Self { - let mut key = [0u8; SIZE]; - OsRng.fill_bytes(&mut key); - Self(key) - } -} - -impl Deref for Key { - type Target = [u8; SIZE]; - - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl TryFrom<&[u8]> for Key { - type Error = anyhow::Error; - - fn try_from(slice: &[u8]) -> Result { - if slice.len() != SIZE { - return Err(anyhow::anyhow!("invalid key size, expected {}", SIZE)); - } - let mut key = [0u8; SIZE]; - key.copy_from_slice(slice); - Ok(Self(key)) - } -} - -impl TryFrom<&str> for Key { - type Error = anyhow::Error; - - fn try_from(value: &str) -> Result { - let bytes = STANDARD_NO_PAD.decode(value)?; - Self::try_from(bytes.as_slice()) - } -} - -impl Display for Key { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", STANDARD_NO_PAD.encode(self.0)) - } -} - - -impl Into<[u8; SIZE]> for Key { - fn into(self) -> [u8; SIZE] { - self.0 - } -} - -impl From<[u8; SIZE]> for Key { - fn from(key: [u8; SIZE]) -> Self { - Self(key) - } -} - -impl Serialize for Key { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - if serializer.is_human_readable() { - let s = STANDARD_NO_PAD.encode(self.0); - serializer.serialize_str(&s) - } else { - serializer.serialize_bytes(&self.0) - } - } -} - -impl<'de, const SIZE: usize> Deserialize<'de> for Key { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let bytes = if deserializer.is_human_readable() { - let s = String::deserialize(deserializer)?; - STANDARD_NO_PAD.decode(&s).map_err(de::Error::custom)? - } else { - Vec::::deserialize(deserializer)? - }; - - if bytes.len() != SIZE { - return Err(de::Error::custom(format!("key must be {} bytes", SIZE))); - } - - let mut key = [0u8; SIZE]; - key.copy_from_slice(&bytes); // todo: use array, not vec - Ok(Self(key)) - } -} diff --git a/crates/shared/src/lib.rs b/crates/shared/src/lib.rs index d0328ae..24f1d1c 100644 --- a/crates/shared/src/lib.rs +++ b/crates/shared/src/lib.rs @@ -1,11 +1,7 @@ pub mod connection_config; -pub mod keys; -pub mod session; pub mod handshake; pub mod credential; pub mod network; pub mod time; -pub mod protocol; pub mod tun; -pub mod types; pub mod style;