From 9f186a033d86daccacf93cbb15146e35bdf14599 Mon Sep 17 00:00:00 2001 From: JKearnsl Date: Fri, 2 May 2025 16:11:27 +0300 Subject: [PATCH 1/3] refactor: implement UDP transport layer as DI --- crates/client/Cargo.toml | 1 + crates/client/src/bin/command/connect.rs | 2 +- crates/client/src/runtime/mod.rs | 1 + crates/client/src/runtime/transport.rs | 16 ++ crates/client/src/runtime/transport/udp.rs | 50 +++++ crates/client/src/runtime/worker/handshake.rs | 12 +- crates/client/src/runtime/worker/mod.rs | 191 +++--------------- crates/client/src/runtime/worker/transport.rs | 78 +++++++ crates/client/src/runtime/worker/tun.rs | 60 ++++++ crates/server/src/runtime/mod.rs | 28 +-- crates/server/src/runtime/transport.rs | 17 ++ crates/server/src/runtime/transport/udp.rs | 63 ++++++ crates/server/src/runtime/worker/mod.rs | 10 +- crates/server/src/runtime/worker/udp.rs | 21 +- 14 files changed, 355 insertions(+), 195 deletions(-) create mode 100644 crates/client/src/runtime/transport.rs create mode 100644 crates/client/src/runtime/transport/udp.rs create mode 100644 crates/client/src/runtime/worker/transport.rs create mode 100644 crates/client/src/runtime/worker/tun.rs create mode 100644 crates/server/src/runtime/transport.rs create mode 100644 crates/server/src/runtime/transport/udp.rs diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index eb5d711..f259493 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -39,3 +39,4 @@ bincode = { workspace = true } tracing-subscriber = { workspace = true } tracing = { workspace = true } tracing-appender = { workspace = true } +async-trait = "0.1.88" diff --git a/crates/client/src/bin/command/connect.rs b/crates/client/src/bin/command/connect.rs index ba69d99..b5db588 100644 --- a/crates/client/src/bin/command/connect.rs +++ b/crates/client/src/bin/command/connect.rs @@ -88,7 +88,7 @@ impl ConnectCmd { debug!("stop signal not sent from Ctrl-C handler: {}", err); } } - thread::sleep(Duration::from_secs(1)); + thread::sleep(Duration::from_secs(2)); process::exit(0); }).expect("error setting Ctrl-C handler"); diff --git a/crates/client/src/runtime/mod.rs b/crates/client/src/runtime/mod.rs index d41de3d..eeb7f70 100644 --- a/crates/client/src/runtime/mod.rs +++ b/crates/client/src/runtime/mod.rs @@ -1,5 +1,6 @@ pub mod error; mod worker; +mod transport; use std::net::{IpAddr, SocketAddr}; use std::time::Duration; diff --git a/crates/client/src/runtime/transport.rs b/crates/client/src/runtime/transport.rs new file mode 100644 index 0000000..d217b09 --- /dev/null +++ b/crates/client/src/runtime/transport.rs @@ -0,0 +1,16 @@ +pub mod udp; + +use async_trait::async_trait; +use std::io; + +#[async_trait] +pub trait TransportSender: Send + Sync { + async fn send(&self, data: &[u8]) -> io::Result; +} + +#[async_trait] +pub trait TransportReceiver: Send + Sync { + async fn recv(&self, buffer: &mut [u8]) -> io::Result; +} + +pub trait Transport: TransportSender + TransportReceiver{} diff --git a/crates/client/src/runtime/transport/udp.rs b/crates/client/src/runtime/transport/udp.rs new file mode 100644 index 0000000..a63ddf6 --- /dev/null +++ b/crates/client/src/runtime/transport/udp.rs @@ -0,0 +1,50 @@ +use crate::runtime::error::RuntimeError; +use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; +use async_trait::async_trait; +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use tokio::net::UdpSocket; + +pub struct UdpTransport { + socket: UdpSocket +} + +impl UdpTransport { + pub fn new( + addr: SocketAddr, + so_rcvbuf: usize, + so_sndbuf: usize, + ) -> Result { + let socket = Socket::new( + Domain::for_address(addr), + Type::DGRAM, + Some(Protocol::UDP) + )?; + socket.set_nonblocking(true)?; + socket.set_recv_buffer_size(so_rcvbuf)?; + socket.set_send_buffer_size(so_sndbuf)?; + socket.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), 0).into())?; + socket.connect(&addr.into())?; + + Ok(Self { socket: UdpSocket::from_std(socket.into())? }) + } +} + +#[async_trait] +impl TransportReceiver for UdpTransport { + + #[inline(always)] + async fn recv(&self, buffer: &mut [u8]) -> std::io::Result { + self.socket.recv(buffer).await + } +} + +#[async_trait] +impl TransportSender for UdpTransport { + #[inline(always)] + async fn send(&self, data: &[u8]) -> std::io::Result { + self.socket.send(data).await + } +} + +impl Transport for UdpTransport{} diff --git a/crates/client/src/runtime/worker/handshake.rs b/crates/client/src/runtime/worker/handshake.rs index 1b6b288..e741512 100644 --- a/crates/client/src/runtime/worker/handshake.rs +++ b/crates/client/src/runtime/worker/handshake.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use std::time::Duration; use snow::{Builder, HandshakeState, StatelessTransportState}; -use tokio::net::UdpSocket; use tracing::warn; use shared::connection_config::CredentialsConfig; use shared::handshake::{ @@ -16,6 +15,7 @@ use shared::protocol::{ Packet }; use shared::session::Alg; +use crate::runtime::transport::Transport; use super::super::{ error::RuntimeError }; @@ -54,7 +54,7 @@ fn complete( } pub(super) async fn handshake_step( - socket: Arc, + transport: Arc, cred: CredentialsConfig, alg: Alg, timeout: Duration @@ -65,7 +65,7 @@ pub(super) async fn handshake_step( &cred )?; - socket.send(&Packet::HandshakeInitial(handshake).to_bytes()).await?; + transport.send(&Packet::HandshakeInitial(handshake).to_bytes()).await?; // [step 2] Server complete let mut buffer = [0u8; 65536]; @@ -74,11 +74,13 @@ pub(super) async fn handshake_step( format!("server timeout ({:?})", timeout) )), handshake = async { loop { - let size = socket.recv(&mut buffer).await?; + let size = transport.recv(&mut buffer).await.map_err( + |err| RuntimeError::IO(format!("receive handshake: {}", err)) + )?; match Packet::try_from(&buffer[..size]) { Ok(Packet::HandshakeResponder(handshake)) => break Ok(handshake), Err(err) => { - warn!("failed to parse handshake packet: {}", err); + warn!("parse handshake packet: {}", err); continue; }, _ => { diff --git a/crates/client/src/runtime/worker/mod.rs b/crates/client/src/runtime/worker/mod.rs index c3aee01..2b85b8f 100644 --- a/crates/client/src/runtime/worker/mod.rs +++ b/crates/client/src/runtime/worker/mod.rs @@ -1,26 +1,33 @@ mod handshake; mod data; - +mod tun; +mod transport; + +use crate::{ + network::DefaultGateway, + runtime::{ + error::RuntimeError, + transport::udp::UdpTransport, + worker::{ + data::{data_tun_executor, data_udp_executor, keepalive_sender}, + handshake::handshake_step, + transport::{transport_listener, transport_sender}, + tun::{tun_listener, tun_sender}, + } + }, +}; +use shared::connection_config::{CredentialsConfig, InterfaceConfig, RuntimeConfig}; +use shared::protocol::{EncryptedData, Packet}; +use shared::session::Alg; +use shared::tun::setup_tun; +use std::time::Duration; use std::{ net::SocketAddr, sync::Arc }; -use std::net::{IpAddr, Ipv4Addr}; -use std::time::Duration; -use socket2::{Domain, Protocol, Socket, Type}; -use tokio::net::UdpSocket; +use tokio::sync::broadcast::{Sender}; use tokio::sync::mpsc; -use tokio::sync::broadcast::{Receiver, Sender}; -use tracing::{error, info, warn}; -use tun_rs::AsyncDevice; -use shared::protocol::{EncryptedData, Packet}; -use shared::connection_config::{CredentialsConfig, InterfaceConfig, RuntimeConfig}; -use shared::session::Alg; -use shared::tun::setup_tun; -use crate::network::DefaultGateway; -use crate::runtime::worker::data::{data_tun_executor, data_udp_executor, keepalive_sender}; -use crate::runtime::worker::handshake::handshake_step; -use super::error::RuntimeError; +use tracing::{info}; pub(crate) async fn create( @@ -31,19 +38,12 @@ pub(crate) async fn create( runtime_config: RuntimeConfig, iface_config: InterfaceConfig, ) -> Result<(), RuntimeError> { - let socket = Socket::new( - Domain::for_address(addr), - Type::DGRAM, - Some(Protocol::UDP) - )?; - socket.set_nonblocking(true)?; - // socket.set_reuse_port(true)?; - socket.set_recv_buffer_size(runtime_config.so_rcvbuf)?; - socket.set_send_buffer_size(runtime_config.so_sndbuf)?; - socket.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0,0,0,0)), 0).into())?; - socket.connect(&addr.into())?; - let socket = Arc::new(UdpSocket::from_std(socket.into())?); + let transport = Arc::new(UdpTransport::new( + addr, + runtime_config.so_rcvbuf, + runtime_config.so_sndbuf, + )?); let (udp_sender_tx, udp_sender_rx) = mpsc::channel::(runtime_config.out_udp_buf); let (tun_sender_tx, tun_sender_rx) = mpsc::channel::>(runtime_config.out_tun_buf); let (data_udp_tx, data_udp_rx) = mpsc::channel::(runtime_config.data_udp_buf); @@ -51,7 +51,7 @@ pub(crate) async fn create( // Handshake step let (handshake_payload, state) = match tokio::spawn(handshake_step( - socket.clone(), + transport.clone(), cred, alg, Duration::from_millis(runtime_config.handshake_timeout) @@ -64,10 +64,10 @@ pub(crate) async fn create( }; // Handle incoming UDP packets - tokio::spawn(udp_listener(stop_tx.clone(), stop_tx.subscribe(), socket.clone(), data_udp_tx)); + tokio::spawn(transport_listener(stop_tx.clone(), stop_tx.subscribe(), transport.clone(), data_udp_tx)); // Handle outgoing UDP packets - tokio::spawn(udp_sender(stop_tx.clone(), stop_tx.subscribe(), socket.clone(), udp_sender_rx)); + tokio::spawn(transport_sender(stop_tx.clone(), stop_tx.subscribe(), transport.clone(), udp_sender_rx)); // Executors @@ -121,7 +121,7 @@ pub(crate) async fn create( match runtime_config.keepalive { Some(duration) => { - info!("starting keepalive sender with interval {:?}", duration); + info!("starting keepalive transport with interval {:?}", duration); tokio::spawn(keepalive_sender( stop_tx.clone(), stop_tx.subscribe(), @@ -131,7 +131,7 @@ pub(crate) async fn create( handshake_payload.sid, )); }, - None => info!("keepalive sender is disabled") + None => info!("keepalive transport is disabled") } let mut stop_rx = stop_tx.subscribe(); @@ -144,128 +144,3 @@ pub(crate) async fn create( Ok(()) } - -async fn tun_sender( - stop_sender: Sender, - mut stop: Receiver, - tun: Arc, - mut queue: mpsc::Receiver> -) { - loop { - tokio::select! { - _ = stop.recv() => break, - result = queue.recv() => match result { - Some(packet) => { - if let Err(err) = tun.send(&packet).await { - stop_sender.send(RuntimeError::IO(format!("failed to send tun: {}", err))).unwrap(); - } - }, - None => break - } - } - } -} - -async fn tun_listener( - stop_sender: Sender, - mut stop: Receiver, - tun: Arc, - queue: mpsc::Sender> -) { - let mut buffer = [0u8; 65536]; - loop { - tokio::select! { - _ = stop.recv() => break, - result = tun.recv(&mut buffer) => match result { - Ok(n) => { - if n == 0 { - warn!("received tun packet with 0 bytes, dropping it"); - continue; - } - if n > 65536 { - warn!("received tun packet larger than 65536 bytes, dropping it (check ur mtu)"); - continue; - } - if let Err(err) = queue.send(buffer[..n].to_vec()).await { - error!("failed to send data to data_receiver: {}", err); - } - } - Err(err) => { - stop_sender.send(RuntimeError::IO(format!("failed to receive tun: {}",err))).unwrap(); - } - } - } - } -} - - -async fn udp_sender( - stop_sender: Sender, - mut stop: Receiver, - socket: Arc, - mut queue: mpsc::Receiver -) { - loop { - tokio::select! { - _ = stop.recv() => break, - result = queue.recv() => match result { - Some(packet) => { - if let Err(err) = socket.send(&packet.to_bytes()).await { - stop_sender.send(RuntimeError::IO(format!("failed to send udp: {}", err))).unwrap(); - } - }, - None => break - } - } - } -} - -async fn udp_listener( - stop_sender: Sender, - mut stop: Receiver, - socket: Arc, - data_receiver: mpsc::Sender -) { - let mut udp_buffer = [0u8; 65536]; - loop { - tokio::select! { - _ = stop.recv() => break, - result = socket.recv(&mut udp_buffer) => match result { - Ok(n) => { - if n == 0 { - warn!("received UDP packet with 0 bytes, dropping it"); - continue; - } - if n > 65536 { - warn!("received UDP packet larger than 65536 bytes, dropping it"); - continue; - } - match Packet::try_from(&udp_buffer[..n]) { - Ok(packet) => match packet { - Packet::DataServer(data) => { - if let Err(err) = data_receiver.send(data).await { - error!("failed to send data to data_receiver: {}", err); - } - }, - Packet::HandshakeResponder(_) => { - warn!("received handshake packet, but expected data packet"); - continue; - }, - _ => { - warn!("received unexpected packet type"); - continue; - } - }, - Err(err) => { - warn!("failed to parse UDP packet: {}", err); - continue; - } - } - } - Err(err) => { - stop_sender.send(RuntimeError::IO(format!("failed to receive udp: {}", err))).unwrap(); - } - } - } - } -} diff --git a/crates/client/src/runtime/worker/transport.rs b/crates/client/src/runtime/worker/transport.rs new file mode 100644 index 0000000..d368708 --- /dev/null +++ b/crates/client/src/runtime/worker/transport.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; +use tokio::sync::broadcast::{Receiver, Sender}; +use tokio::sync::mpsc; +use tracing::{error, warn}; +use shared::protocol::{EncryptedData, Packet}; +use crate::runtime::error::RuntimeError; +use crate::runtime::transport::{TransportReceiver, TransportSender}; + +pub async fn transport_sender( + stop_sender: Sender, + mut stop: Receiver, + transport: Arc, + mut queue: mpsc::Receiver +) { + loop { + tokio::select! { + _ = stop.recv() => break, + result = queue.recv() => match result { + Some(packet) => { + if let Err(err) = transport.send(&packet.to_bytes()).await { + stop_sender.send(RuntimeError::IO(format!("failed to send udp: {}", err))).unwrap(); + } + }, + None => break + } + } + } +} + +pub async fn transport_listener( + stop_sender: Sender, + mut stop: Receiver, + transport: Arc, + data_receiver: mpsc::Sender +) { + let mut udp_buffer = [0u8; 65536]; + loop { + tokio::select! { + _ = stop.recv() => break, + result = transport.recv(&mut udp_buffer) => match result { + Ok(n) => { + if n == 0 { + warn!("received UDP packet with 0 bytes, dropping it"); + continue; + } + if n > 65536 { + warn!("received UDP packet larger than 65536 bytes, dropping it"); + continue; + } + match Packet::try_from(&udp_buffer[..n]) { + Ok(packet) => match packet { + Packet::DataServer(data) => { + if let Err(err) = data_receiver.send(data).await { + error!("failed to send data to data_receiver: {}", err); + } + }, + Packet::HandshakeResponder(_) => { + warn!("received handshake packet, but expected data packet"); + continue; + }, + _ => { + warn!("received unexpected packet type"); + continue; + } + }, + Err(err) => { + warn!("failed to parse UDP packet: {}", err); + continue; + } + } + } + Err(err) => { + stop_sender.send(RuntimeError::IO(format!("failed to receive udp: {}", err))).unwrap(); + } + } + } + } +} \ No newline at end of file diff --git a/crates/client/src/runtime/worker/tun.rs b/crates/client/src/runtime/worker/tun.rs new file mode 100644 index 0000000..16d3e79 --- /dev/null +++ b/crates/client/src/runtime/worker/tun.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; +use tokio::sync::broadcast::{Receiver, Sender}; +use tokio::sync::mpsc; +use tracing::{error, warn}; +use tun_rs::AsyncDevice; +use crate::runtime::error::RuntimeError; + +pub async fn tun_sender( + stop_sender: Sender, + mut stop: Receiver, + tun: Arc, + mut queue: mpsc::Receiver> +) { + loop { + tokio::select! { + _ = stop.recv() => break, + result = queue.recv() => match result { + Some(packet) => { + if let Err(err) = tun.send(&packet).await { + stop_sender.send(RuntimeError::IO(format!("failed to send tun: {}", err))).unwrap(); + } + }, + None => break + } + } + } +} + +pub async fn tun_listener( + stop_sender: Sender, + mut stop: Receiver, + tun: Arc, + queue: mpsc::Sender> +) { + let mut buffer = [0u8; 65536]; + loop { + tokio::select! { + _ = stop.recv() => break, + result = tun.recv(&mut buffer) => match result { + Ok(n) => { + if n == 0 { + warn!("received tun packet with 0 bytes, dropping it"); + continue; + } + if n > 65536 { + warn!("received tun packet larger than 65536 bytes, dropping it (check ur mtu)"); + continue; + } + if let Err(err) = queue.send(buffer[..n].to_vec()).await { + error!("failed to send data to data_receiver: {}", err); + } + } + Err(err) => { + stop_sender.send(RuntimeError::IO(format!("failed to receive tun: {}",err))).unwrap(); + } + } + } + } +} + diff --git a/crates/server/src/runtime/mod.rs b/crates/server/src/runtime/mod.rs index a345631..e9d64ff 100644 --- a/crates/server/src/runtime/mod.rs +++ b/crates/server/src/runtime/mod.rs @@ -2,6 +2,7 @@ pub mod session; mod worker; pub mod error; mod network; +mod transport; use std::{ net::{IpAddr, SocketAddr}, @@ -10,10 +11,9 @@ use std::{ use std::sync::Arc; use std::time::Duration; use dashmap::DashMap; -use socket2::{Domain, Protocol, Socket, Type}; use self::{ error::RuntimeError, - session::{ Sessions} + session::Sessions }; use tokio::runtime::Builder; @@ -24,6 +24,7 @@ use shared::{ network::set_ipv4_forwarding }; use shared::tun::setup_tun; +use crate::runtime::transport::udp::UdpTransport; pub struct Runtime { sock: SocketAddr, @@ -90,20 +91,13 @@ impl Runtime { true ).await.map_err(|err| vec![RuntimeError::from(err)])?; - let socket = Socket::new( - Domain::for_address(self.sock), - Type::DGRAM, - Some(Protocol::UDP) + let mut transports = UdpTransport::new_pool( + self.sock, + self.config.so_rcvbuf, + self.config.so_sndbuf, + workers ).map_err(|err| vec![RuntimeError::from(err)])?; - socket.set_nonblocking(true).map_err(|err| vec![RuntimeError::from(err)])?; - socket.set_reuse_port(true).map_err(|err| vec![RuntimeError::from(err)])?; - socket.set_reuse_address(true).map_err(|err| vec![RuntimeError::from(err)])?; - socket.set_recv_buffer_size(self.config.so_rcvbuf).map_err(|err| vec![RuntimeError::from(err)])?; - socket.set_send_buffer_size(self.config.so_sndbuf).map_err(|err| vec![RuntimeError::from(err)])?; - socket.set_tos(0b101110 << 2).map_err(|err| vec![RuntimeError::from(err)])?; - socket.bind(&self.sock.into()).map_err(|err| vec![RuntimeError::from(err)])?; - let rt = Builder::new_multi_thread() .worker_threads(workers) .enable_all() @@ -120,16 +114,14 @@ impl Runtime { let tun = tun.try_clone().map_err(|err| vec![RuntimeError::Tun( format!("failed to clone tun device: {}", err) )])?; - let socket = socket.try_clone().map_err(|err| vec![RuntimeError::IO( - format!("failed to clone socket: {}", err) - )])?; + let transport = Arc::new(transports.pop().unwrap()); // unwrap is safe here let config = self.config.clone(); let handle = rt.spawn(async move { tracing::debug!("worker {} started", worker_id); if let Err(err) = worker::create( - socket, + transport, stop_tx, sessions, known_clients, diff --git a/crates/server/src/runtime/transport.rs b/crates/server/src/runtime/transport.rs new file mode 100644 index 0000000..6baeec4 --- /dev/null +++ b/crates/server/src/runtime/transport.rs @@ -0,0 +1,17 @@ +pub mod udp; + +use std::io; +use std::net::SocketAddr; +use async_trait::async_trait; + +#[async_trait] +pub trait TransportSender: Send + Sync { + async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> io::Result; +} + +#[async_trait] +pub trait TransportReceiver: Send + Sync { + async fn recv_from(&self, buffer: &mut [u8]) -> io::Result<(usize, SocketAddr)>; +} + +pub trait Transport: TransportSender + TransportReceiver{} \ No newline at end of file diff --git a/crates/server/src/runtime/transport/udp.rs b/crates/server/src/runtime/transport/udp.rs new file mode 100644 index 0000000..7f36d1f --- /dev/null +++ b/crates/server/src/runtime/transport/udp.rs @@ -0,0 +1,63 @@ +use crate::runtime::error::RuntimeError; +use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; +use async_trait::async_trait; +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::SocketAddr; +use tokio::net::UdpSocket; + +pub struct UdpTransport { + socket: UdpSocket +} + +impl UdpTransport { + pub fn new_pool( + addr: SocketAddr, + so_rcvbuf: usize, + so_sndbuf: usize, + count: usize + ) -> Result, RuntimeError> { + let socket = Socket::new( + Domain::for_address(addr), + Type::DGRAM, + Some(Protocol::UDP) + )?; + + socket.set_nonblocking(true)?; + socket.set_reuse_port(true)?; + socket.set_reuse_address(true)?; + socket.set_recv_buffer_size(so_rcvbuf)?; + socket.set_send_buffer_size(so_sndbuf)?; + socket.set_tos(0b101110 << 2)?; + socket.bind(&addr.into())?; + + let mut sockets = Vec::with_capacity(count); + for i in 0..count { + let cloned_raw_socket = socket.try_clone().map_err(|err| { + RuntimeError::IO(format!("clone socket #{}: {}", i + 1, err)) + })?.into(); + + sockets.push(Self { socket: UdpSocket::from_std(cloned_raw_socket)? }); + } + + Ok(sockets) + } +} + +#[async_trait] +impl TransportReceiver for UdpTransport { + + #[inline(always)] + async fn recv_from(&self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + self.socket.recv_from(buffer).await + } +} + +#[async_trait] +impl TransportSender for UdpTransport { + #[inline(always)] + async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> std::io::Result { + self.socket.send_to(data, addr).await + } +} + +impl Transport for UdpTransport{} diff --git a/crates/server/src/runtime/worker/mod.rs b/crates/server/src/runtime/worker/mod.rs index e8a03d0..803ea89 100644 --- a/crates/server/src/runtime/worker/mod.rs +++ b/crates/server/src/runtime/worker/mod.rs @@ -19,19 +19,18 @@ use dashmap::DashMap; use shared::keys::handshake::{PublicKey, SecretKey}; use shared::protocol::{EncryptedData, EncryptedHandshake, Packet}; use shared::session::SessionId; -use socket2::Socket; use std::{ net::SocketAddr, sync::Arc }; -use tokio::net::UdpSocket; use tokio::sync::{broadcast, mpsc}; use tracing::info; use tun_rs::AsyncDevice; +use crate::runtime::transport::Transport; #[allow(clippy::too_many_arguments)] pub(crate) async fn create( - socket: Socket, + transport: Arc, stop_tx: broadcast::Sender, sessions: Sessions, known_clients: Arc>, @@ -41,7 +40,6 @@ pub(crate) async fn create( config: RuntimeConfig ) -> Result<(), RuntimeError> { - let socket = Arc::new(UdpSocket::from_std(socket.into())?); let tun = Arc::new(tun); let (out_udp_tx, out_udp_rx) = mpsc::channel::<(Packet, SocketAddr)>(config.out_udp_buf); @@ -52,10 +50,10 @@ pub(crate) async fn create( // Handle incoming UDP packets - tokio::spawn(udp_listener(stop_tx.subscribe(), socket.clone(), handshake_tx, data_udp_tx)); + tokio::spawn(udp_listener(stop_tx.subscribe(), transport.clone(), handshake_tx, data_udp_tx)); // Handle outgoing UDP packets - tokio::spawn(udp_sender(stop_tx.subscribe(), socket.clone(), out_udp_rx)); + tokio::spawn(udp_sender(stop_tx.subscribe(), transport.clone(), out_udp_rx)); // Handle incoming TUN packets tokio::spawn(tun_listener(stop_tx.subscribe(), tun.clone(), data_tun_tx)); diff --git a/crates/server/src/runtime/worker/udp.rs b/crates/server/src/runtime/worker/udp.rs index 5dc73c1..27e93a2 100644 --- a/crates/server/src/runtime/worker/udp.rs +++ b/crates/server/src/runtime/worker/udp.rs @@ -1,16 +1,16 @@ use std::net::SocketAddr; use std::sync::Arc; -use tokio::net::UdpSocket; use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc; -use tracing::{error, warn}; +use tracing::{debug, error, warn}; use shared::protocol::{EncryptedData, EncryptedHandshake, Packet}; use shared::session::SessionId; use crate::runtime::error::RuntimeError; +use crate::runtime::transport::{TransportReceiver, TransportSender}; pub async fn udp_sender( mut stop: Receiver, - socket: Arc, + transport: Arc, mut out_udp_rx: mpsc::Receiver<(Packet, SocketAddr)> ) { loop { @@ -18,8 +18,15 @@ pub async fn udp_sender( _ = stop.recv() => break, result = out_udp_rx.recv() => match result { Some((data, client_addr)) => { - if let Err(e) = socket.send_to(&data.to_bytes(), &client_addr).await { - warn!("failed to send data to {}: {}", client_addr, e); + match transport.send_to(&data.to_bytes(), &client_addr).await { + Ok(len) => { + debug!("sent packet to {}: len: {}", client_addr, len); + + }, + Err(e) => { + error!("failed to send data to {}: {}", client_addr, e); + continue; + } } }, None => break @@ -30,7 +37,7 @@ pub async fn udp_sender( pub async fn udp_listener( mut stop: Receiver, - socket: Arc, + transport: Arc, handshake_tx: mpsc::Sender<(EncryptedHandshake, SocketAddr)>, data_tx: mpsc::Sender<(SessionId, EncryptedData, SocketAddr)> ) { @@ -38,7 +45,7 @@ pub async fn udp_listener( loop { tokio::select! { _ = stop.recv() => break, - result = socket.recv_from(&mut udp_buffer) => { + result = transport.recv_from(&mut udp_buffer) => { match result { Ok((n, client_addr)) => { if n == 0 { From 0c2f6c9c82fa8376740f30fb77f92ed35bcc4415 Mon Sep 17 00:00:00 2001 From: JKearnsl Date: Sun, 4 May 2025 16:28:49 +0300 Subject: [PATCH 2/3] add dependencies for tokio-tungstenite and futures --- Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 7aacf41..bd5926a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ snow = "0.9" tokio = { version = "1.44", features = ["rt", "rt-multi-thread", "macros", "sync", "time", "socket2"] } tun-rs = { version = "2.0.9", features = ["async_tokio"] } socket2 = "0.5.9" +tokio-tungstenite = "0.26.2" +futures = "0.3.31" ctrlc = "3.4" # data From fb953304d99793e5745c8a40017da8ffd7a20a02 Mon Sep 17 00:00:00 2001 From: JKearnsl Date: Sun, 4 May 2025 16:36:09 +0300 Subject: [PATCH 3/3] feat: add WebSocket transport implementation; fix gateway -> route state --- crates/client/Cargo.toml | 10 +- crates/client/src/lib.rs | 10 + crates/client/src/network.rs | 237 +++++++++++------- crates/client/src/runtime/mod.rs | 2 - crates/client/src/runtime/transport.rs | 3 + crates/client/src/runtime/transport/udp.rs | 1 + crates/client/src/runtime/transport/ws.rs | 71 ++++++ crates/client/src/runtime/worker/mod.rs | 50 ++-- crates/client/src/runtime/worker/transport.rs | 22 +- crates/server/Cargo.toml | 11 +- crates/server/src/bin/command/start.rs | 7 +- crates/server/src/runtime/mod.rs | 51 ++-- crates/server/src/runtime/transport.rs | 9 +- crates/server/src/runtime/transport/udp.rs | 18 +- crates/server/src/runtime/transport/ws.rs | 151 +++++++++++ crates/server/src/runtime/worker/data.rs | 16 +- crates/server/src/runtime/worker/mod.rs | 48 ++-- .../runtime/worker/{udp.rs => transport.rs} | 23 +- 18 files changed, 560 insertions(+), 180 deletions(-) create mode 100644 crates/client/src/runtime/transport/ws.rs create mode 100644 crates/server/src/runtime/transport/ws.rs rename crates/server/src/runtime/worker/{udp.rs => transport.rs} (77%) diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index f259493..cea6d71 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -6,6 +6,8 @@ autobins = false [features] default = ["cli"] +udp = ["socket2"] +ws = ["tokio-tungstenite", "futures"] cli = ["clap", "ctrlc", "anstyle"] [[bin]] @@ -20,12 +22,13 @@ shared = { workspace = true } # IO tokio = { workspace = true } tun-rs = { workspace = true } -socket2 = { workspace = true} -pnet = "0.35.*" +socket2 = { workspace = true, optional = true} ctrlc = { workspace = true, optional = true} clap = { version = "4.5.23", features = ["derive"], optional = true} anstyle = { version = "1.0", optional = true } - +tokio-tungstenite = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +ipnetwork = "0.21.1" # Crypto snow = "0.9" @@ -40,3 +43,4 @@ tracing-subscriber = { workspace = true } tracing = { workspace = true } tracing-appender = { workspace = true } async-trait = "0.1.88" + diff --git a/crates/client/src/lib.rs b/crates/client/src/lib.rs index a189645..d4e89be 100644 --- a/crates/client/src/lib.rs +++ b/crates/client/src/lib.rs @@ -6,5 +6,15 @@ //! //! +#[cfg(not(any( + feature = "udp", + feature = "ws", +)))] +compile_error!( + "please enable one of the following transport backends with cargo's --features argument: \ + udp, ws (e.g. --features=udp)" +); + + pub mod network; pub mod runtime; \ No newline at end of file diff --git a/crates/client/src/network.rs b/crates/client/src/network.rs index bf587a9..31ca48f 100644 --- a/crates/client/src/network.rs +++ b/crates/client/src/network.rs @@ -1,60 +1,73 @@ use std::net::{IpAddr}; use tracing::{info, warn}; use std::process::Command; -use crate::runtime::error::RuntimeError; +use anyhow::format_err; +use std::fmt::Write; +use std::str::FromStr; -pub enum RouteType { - Net, - Host, +pub struct RouteState { + dev: String, + default_gateway: Option, + exclude: Vec, } +use ipnetwork::{IpNetwork, NetworkSize}; -pub struct DefaultGateway { - origin: IpAddr, - remote: IpAddr, - default: bool, -} +impl RouteState { + pub fn new(remote: IpAddr, dev: String) -> Self { + Self { + dev, + default_gateway: None, + exclude: vec![IpNetwork::from(remote)] + } + } + + pub fn exclude(mut self, addr: IpNetwork) { + self.exclude.push(addr); + } -impl DefaultGateway { - pub fn create(new_gateway: &IpAddr, remote: &IpAddr, default: bool) -> Result { - let origin = default_gateway() - .map_err(|e| RuntimeError::Network(format!("getting default gateway: {}", e)))?; - info!("original default gateway: {}.", origin); - add_route(RouteType::Host, &remote.to_string(), &origin.to_string()) - .map_err(|error| RuntimeError::Network(format!( - "failed to add route: {} -> {}: {}", - remote, - origin, - error - )))?; + pub fn build(mut self) -> anyhow::Result { + let (default_gateway, default_dev_name) = default_device().map_err(|e| + format_err!("failed to get default device: {}", e) + )?; + self.default_gateway = Some(default_gateway); + info!("default gateway: {} from dev {}", default_gateway, default_dev_name); + add_route( + &IpNetwork::from_str("0.0.0.0/1")?, + None, + &self.dev, + Some(1), + )?; + add_route( + &IpNetwork::from_str("128.0.0.0/1")?, + None, + &self.dev, + Some(1), + )?; - if default { - delete_route(RouteType::Net, "default"); - add_route(RouteType::Net, "default", &new_gateway.to_string()) - .map_err(|error| RuntimeError::Network(format!( - "failed to add new default route: {} (new) -> {} (old): {}", - new_gateway, - origin, - error - )))?; + for addr in self.exclude.iter() { + add_route( + addr, + Some(default_gateway), + &default_dev_name, + None + )?; } - Ok(DefaultGateway { - origin, - remote: *remote, - default, - }) + Ok(self) } pub fn restore(&mut self) { - if self.default { - delete_route(RouteType::Net, "default"); - if let Err(e) = add_route(RouteType::Net, "default", &self.origin.to_string()) { - warn!("failed to restore default route: {}", e); - } else { - info!("restored default route: {}.", self.origin); + for addr in self.exclude.iter() { + match delete_route( + addr, + &self.default_gateway.expect( + "default gateway not set, cannot restore route (are you sure you called build?)" + ), + ) { + Ok(_) => {}, + Err(e) => warn!("failed to restore route: {} via {}: {}", addr, self.default_gateway.unwrap(), e), } } - delete_route(RouteType::Host, &self.remote.to_string()); } } @@ -64,46 +77,64 @@ impl DefaultGateway { // } // } -pub fn delete_route(route_type: RouteType, route: &str) { - let mode = match route_type { - RouteType::Net => "-net", - RouteType::Host => "-host", +pub fn delete_route(route: &IpNetwork, via: &IpAddr,) -> anyhow::Result<()> { + info!("deleting route: {} via {}", route, via); + + let (formated_route, _) = match route.size() { + NetworkSize::V4(32) | NetworkSize::V6(128) => (route.ip().to_string(), false), + _ => (route.to_string(), true), }; - info!("deleting route: {} {}.", mode, route); + let status = if cfg!(target_os = "linux") { + let check = Command::new("ip") + .arg("route") + .arg("show") + .arg(formated_route.clone()) + .output()?; + + if check.stdout.is_empty() { + warn!("route already deleted"); + return Ok(()); + } + Command::new("ip") .arg("route") .arg("del") - .arg(route) - .status() - .unwrap() - } else if cfg!(target_os = "macos") { - Command::new("route") - .arg("-n") - .arg("delete") - .arg(mode) - .arg(route) - .status() - .unwrap() + .arg(formated_route) + .arg("via") + .arg(via.to_string()) + .status()? } else { unimplemented!("Unsupported OS"); }; if !status.success() { - warn!("failed to delete route: {}", status); + warn!("cant delete route: {}", status); } + Ok(()) } -pub fn add_route(route_type: RouteType, route: &str, gateway: &str) -> anyhow::Result<()> { - let mode = match route_type { - RouteType::Net => "-net", - RouteType::Host => "-host", + +fn add_route(route: &IpNetwork, via: Option, dev: &str, metric: Option) -> anyhow::Result<()> { + let mut buffer = format!("adding route: {} ", route); + if let Some(via) = via { + write!(buffer, "via {} ", via)?; + } + write!(buffer, "dev {} ", dev)?; + if let Some(metric) = metric { + write!(buffer, "metric {}", metric)?; + } + info!("{}", buffer); + + let (formated_route, _) = match route.size() { + NetworkSize::V4(32) | NetworkSize::V6(128) => (route.ip().to_string(), false), + _ => (route.to_string(), true), }; - info!("adding route: {} {} gateway {}.", mode, route, gateway); + let status = if cfg!(target_os = "linux") { let check = Command::new("ip") .arg("route") .arg("show") - .arg(route) + .arg(formated_route.clone()) .output()?; if !check.stdout.is_empty() { @@ -111,21 +142,30 @@ pub fn add_route(route_type: RouteType, route: &str, gateway: &str) -> anyhow::R return Ok(()); } - Command::new("ip") - .arg("route") - .arg("add") - .arg(route) - .arg("via") - .arg(gateway) - .status()? - } else if cfg!(target_os = "macos") { - Command::new("route") - .arg("-n") - .arg("add") - .arg(mode) - .arg(route) - .arg(gateway) - .status()? + let mut cmd = Command::new("ip"); + + cmd.arg("route").arg("add").arg(formated_route); + + if let Some(via) = via { + cmd.arg("via").arg(via.to_string()); + }; + + cmd.arg("dev").arg(dev); + + if let Some(metric) = metric { + cmd.arg("metric").arg(metric.to_string()); + }; + + cmd.status()? + // } else if cfg!(target_os = "macos") { + // + // Command::new("route") + // .arg("-n") + // .arg("add") + // .arg(if is_net { "-net" } else { "-host" }) + // .arg(formated_route) + // .arg(gateway) + // .status()? } else { unimplemented!("Unsupported OS"); }; @@ -136,17 +176,46 @@ pub fn add_route(route_type: RouteType, route: &str, gateway: &str) -> anyhow::R } } -pub fn default_gateway() -> anyhow::Result { +pub fn default_device() -> anyhow::Result<(IpAddr, String)> { let cmd = if cfg!(target_os = "linux") { - "ip -4 route list 0/0 | awk '{print $3}'" + "ip -4 route list 0/0" } else if cfg!(target_os = "macos") { - "route -n get default | grep gateway | awk '{print $2}'" + "route -n get default" } else { unimplemented!("Unsupported OS"); }; + let output = Command::new("bash").arg("-c").arg(cmd).output()?; + if output.status.success() { - Ok(String::from_utf8(output.stdout)?.trim_end().parse()?) + let output_str = String::from_utf8(output.stdout)?; + + if cfg!(target_os = "linux") { + for line in output_str.lines() { + if line.contains("default") { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 4 { + let ip: IpAddr = parts[2].parse()?; + let interface = parts[4].to_string(); + return Ok((ip, interface)); + } + } + } + } + + // if cfg!(target_os = "macos") { + // for line in output_str.lines() { + // if line.contains("gateway") { + // let parts: Vec<&str> = line.split_whitespace().collect(); + // if parts.len() >= 2 { + // let ip: IpAddr = parts[1].parse()?; + // return Ok((ip, String::from("unknown"))); + // } + // } + // } + // } + + Err(anyhow::anyhow!("Failed to parse output")) } else { Err(anyhow::anyhow!(String::from_utf8(output.stderr)?)) } diff --git a/crates/client/src/runtime/mod.rs b/crates/client/src/runtime/mod.rs index eeb7f70..b54a556 100644 --- a/crates/client/src/runtime/mod.rs +++ b/crates/client/src/runtime/mod.rs @@ -43,8 +43,6 @@ impl Runtime { } pub async fn run(&self) -> Result<(), RuntimeError> { - tracing::info!("Connecting to udp://{}", self.sock); - let worker = worker::create( self.sock, self.stop_tx.clone(), diff --git a/crates/client/src/runtime/transport.rs b/crates/client/src/runtime/transport.rs index d217b09..350eb9e 100644 --- a/crates/client/src/runtime/transport.rs +++ b/crates/client/src/runtime/transport.rs @@ -1,4 +1,7 @@ +#[cfg(feature = "udp")] pub mod udp; +#[cfg(feature = "ws")] +pub mod ws; use async_trait::async_trait; use std::io; diff --git a/crates/client/src/runtime/transport/udp.rs b/crates/client/src/runtime/transport/udp.rs index a63ddf6..aedf940 100644 --- a/crates/client/src/runtime/transport/udp.rs +++ b/crates/client/src/runtime/transport/udp.rs @@ -15,6 +15,7 @@ impl UdpTransport { so_rcvbuf: usize, so_sndbuf: usize, ) -> Result { + tracing::info!("Connecting to udp://{}", addr); let socket = Socket::new( Domain::for_address(addr), Type::DGRAM, diff --git a/crates/client/src/runtime/transport/ws.rs b/crates/client/src/runtime/transport/ws.rs new file mode 100644 index 0000000..721b135 --- /dev/null +++ b/crates/client/src/runtime/transport/ws.rs @@ -0,0 +1,71 @@ +use crate::runtime::error::RuntimeError; +use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; +use async_trait::async_trait; +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; + + +pub struct WsTransport { + write: Arc>, Message>>>, + read: Arc>>>>, +} + +impl WsTransport { + pub async fn connect(addr: SocketAddr) -> Result { + tracing::info!("connecting to ws://{}", addr); + let request = format!("ws://{addr}").into_client_request().unwrap(); + let (ws_stream, _) = connect_async(request) + .await + .map_err(|e| RuntimeError::IO(format!( + "Failed to connect to WebSocket server: {}", e + )))?; + + let (write, read) = ws_stream.split(); + + Ok(Self { write: Arc::new(Mutex::new(write)) , read: Arc::new(Mutex::new(read)) }) + } +} + +#[async_trait] +impl TransportReceiver for WsTransport { + + #[inline(always)] + async fn recv(&self, buffer: &mut [u8]) -> std::io::Result { + let mut read = self.read.lock().await; + while let Some(Ok(msg)) = read.next().await { + if let Message::Binary(data) = msg { + let len = data.len().min(buffer.len()); + buffer[..len].copy_from_slice(&data[..len]); + return Ok(len); + } + } + Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "WebSocket connection closed" + )) + } +} + +#[async_trait] +impl TransportSender for WsTransport { + #[inline(always)] + async fn send(&self, data: &[u8]) -> std::io::Result { + self.write.lock().await + .send(Message::Binary(data.to_vec().into())) + .await + .map(|_| data.len()) + .map_err(|e| std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + e.to_string() + )) + } +} + +impl Transport for WsTransport{} diff --git a/crates/client/src/runtime/worker/mod.rs b/crates/client/src/runtime/worker/mod.rs index 2b85b8f..14ac64e 100644 --- a/crates/client/src/runtime/worker/mod.rs +++ b/crates/client/src/runtime/worker/mod.rs @@ -3,11 +3,16 @@ mod data; mod tun; mod transport; +#[cfg(feature = "udp")] +pub use crate::runtime::transport::udp::UdpTransport; + +#[cfg(feature = "ws")] +pub use crate::runtime::transport::ws::WsTransport; + use crate::{ - network::DefaultGateway, + network::RouteState, runtime::{ error::RuntimeError, - transport::udp::UdpTransport, worker::{ data::{data_tun_executor, data_udp_executor, keepalive_sender}, handshake::handshake_step, @@ -28,7 +33,7 @@ use std::{ use tokio::sync::broadcast::{Sender}; use tokio::sync::mpsc; use tracing::{info}; - +use crate::runtime::transport::Transport; pub(crate) async fn create( addr: SocketAddr, @@ -39,11 +44,24 @@ pub(crate) async fn create( iface_config: InterfaceConfig, ) -> Result<(), RuntimeError> { - let transport = Arc::new(UdpTransport::new( - addr, - runtime_config.so_rcvbuf, - runtime_config.so_sndbuf, - )?); + let transport: Arc = match () { + #[cfg(feature = "udp")] + _ if cfg!(feature = "udp") => { + info!("using UDP transport"); + Arc::new(UdpTransport::new( + addr, + runtime_config.so_rcvbuf, + runtime_config.so_sndbuf, + )?) + } + #[cfg(feature = "ws")] + _ if cfg!(feature = "ws") => { + info!("using WebSocket transport"); + Arc::new(WsTransport::connect(addr).await?) + } + _ => unreachable!("transport is not enabled, please enable transport features") + }; + let (udp_sender_tx, udp_sender_rx) = mpsc::channel::(runtime_config.out_udp_buf); let (tun_sender_tx, tun_sender_rx) = mpsc::channel::>(runtime_config.out_tun_buf); let (data_udp_tx, data_udp_rx) = mpsc::channel::(runtime_config.data_udp_buf); @@ -89,19 +107,17 @@ pub(crate) async fn create( )); let tun = Arc::new(setup_tun( - iface_config.name, + iface_config.name.clone(), iface_config.mtu, handshake_payload.ipaddr, 32, false ).await?); - - let mut gw = DefaultGateway::create( - &handshake_payload.ipaddr, - &addr.ip(), - true - )?; - + + // move from runtime + let mut routes = RouteState::new(addr.ip(), iface_config.name) + .build()?; + // Handle incoming TUN packets tokio::spawn(tun_listener( stop_tx.clone(), @@ -137,7 +153,7 @@ pub(crate) async fn create( let mut stop_rx = stop_tx.subscribe(); tokio::select! { _ = stop_rx.recv() => { - gw.restore(); + routes.restore(); info!("listener stopped") } } diff --git a/crates/client/src/runtime/worker/transport.rs b/crates/client/src/runtime/worker/transport.rs index d368708..cb63fc8 100644 --- a/crates/client/src/runtime/worker/transport.rs +++ b/crates/client/src/runtime/worker/transport.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use tokio::sync::broadcast::{Receiver, Sender}; use tokio::sync::mpsc; -use tracing::{error, warn}; +use tracing::{debug, error, warn}; use shared::protocol::{EncryptedData, Packet}; use crate::runtime::error::RuntimeError; use crate::runtime::transport::{TransportReceiver, TransportSender}; @@ -16,8 +16,9 @@ pub async fn transport_sender( tokio::select! { _ = stop.recv() => break, result = queue.recv() => match result { - Some(packet) => { - if let Err(err) = transport.send(&packet.to_bytes()).await { + Some(packet) => match transport.send(&packet.to_bytes()).await { + Ok(n) => debug!("sent transport packet with {} bytes", n), + Err(err) => { stop_sender.send(RuntimeError::IO(format!("failed to send udp: {}", err))).unwrap(); } }, @@ -33,21 +34,22 @@ pub async fn transport_listener( transport: Arc, data_receiver: mpsc::Sender ) { - let mut udp_buffer = [0u8; 65536]; + let mut transport_buffer = [0u8; 65536]; loop { tokio::select! { _ = stop.recv() => break, - result = transport.recv(&mut udp_buffer) => match result { + result = transport.recv(&mut transport_buffer) => match result { Ok(n) => { + debug!("received transport packet with {} bytes", n); if n == 0 { - warn!("received UDP packet with 0 bytes, dropping it"); + warn!("received transport packet with 0 bytes, dropping it"); continue; } if n > 65536 { - warn!("received UDP packet larger than 65536 bytes, dropping it"); + warn!("received transport packet larger than 65536 bytes, dropping it"); continue; } - match Packet::try_from(&udp_buffer[..n]) { + match Packet::try_from(&transport_buffer[..n]) { Ok(packet) => match packet { Packet::DataServer(data) => { if let Err(err) = data_receiver.send(data).await { @@ -64,13 +66,13 @@ pub async fn transport_listener( } }, Err(err) => { - warn!("failed to parse UDP packet: {}", err); + warn!("failed to parse transport packet: {}", err); continue; } } } Err(err) => { - stop_sender.send(RuntimeError::IO(format!("failed to receive udp: {}", err))).unwrap(); + stop_sender.send(RuntimeError::IO(format!("failed to receive transport: {}", err))).unwrap(); } } } diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index 8afb0c6..e3d13d2 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -6,7 +6,9 @@ autobins = false [features] default = ["cli"] -cli = ["clap", "inquire", "anstyle", "ctrlc"] +udp = ["socket2"] +ws = ["tokio-tungstenite", "futures", "socket2"] +cli = ["clap", "inquire", "anstyle", "ctrlc", "rocksdb"] [[bin]] name = "server" @@ -40,10 +42,13 @@ tracing-appender = { workspace = true } snow = { workspace = true } # IO -rocksdb = { version = "0.23", features = ["multi-threaded-cf"] } +rocksdb = { version = "0.23", features = ["multi-threaded-cf"], optional = true } tun-rs = { workspace = true } tokio = { workspace = true } -socket2 = "0.5.8" +socket2 = { workspace = true, optional = true } async-trait = "0.1" hex = "0.4.3" rand = "0.9.1" +tokio-tungstenite = { workspace = true, optional = true } +futures = { workspace = true, optional = true } + diff --git a/crates/server/src/bin/command/start.rs b/crates/server/src/bin/command/start.rs index fc71f77..a8c0850 100644 --- a/crates/server/src/bin/command/start.rs +++ b/crates/server/src/bin/command/start.rs @@ -1,5 +1,4 @@ -use std::{process, thread}; -use std::time::Duration; +use std::process; use clap::Parser; use tracing::{debug, error, info, warn}; use server::config::Config; @@ -70,8 +69,8 @@ impl StartCmd { debug!("stop signal not sent from Ctrl-C handler: {}", err); } } - thread::sleep(Duration::from_secs(1)); - process::exit(0); + // thread::sleep(Duration::from_secs(1)); + // process::exit(0); }).expect("error setting Ctrl-C handler"); if let Err(errors) = runtime.run().await { diff --git a/crates/server/src/runtime/mod.rs b/crates/server/src/runtime/mod.rs index e9d64ff..0ae82ed 100644 --- a/crates/server/src/runtime/mod.rs +++ b/crates/server/src/runtime/mod.rs @@ -18,14 +18,20 @@ use self::{ use tokio::runtime::Builder; use tokio::sync::{broadcast}; +use tracing::info; use crate::config::{Config, RuntimeConfig}; use shared::{ keys::handshake::{PublicKey, SecretKey}, network::set_ipv4_forwarding }; use shared::tun::setup_tun; +use crate::runtime::transport::Transport; +#[cfg(feature = "udp")] use crate::runtime::transport::udp::UdpTransport; +#[cfg(feature = "ws")] +use crate::runtime::transport::ws::WsTransport; + pub struct Runtime { sock: SocketAddr, sk: SecretKey, @@ -75,12 +81,6 @@ impl Runtime { false => self.config.workers }; - tracing::info!( - "Runtime running on udp://{} with {} workers", - self.sock, - workers - ); - set_ipv4_forwarding(true).map_err(|err| vec![RuntimeError::from(err)])?; let tun = setup_tun( @@ -91,12 +91,33 @@ impl Runtime { true ).await.map_err(|err| vec![RuntimeError::from(err)])?; - let mut transports = UdpTransport::new_pool( - self.sock, - self.config.so_rcvbuf, - self.config.so_sndbuf, - workers - ).map_err(|err| vec![RuntimeError::from(err)])?; + let mut transports: Vec> = match () { + #[cfg(feature = "udp")] + _ if cfg!(feature = "udp") => { + UdpTransport::new_pool( + self.sock, + self.config.so_rcvbuf, + self.config.so_sndbuf, + workers + ).map_err(|err| vec![err])? + .into_iter() + .map(|t| Arc::new(t) as Arc) + .collect() + } + #[cfg(feature = "ws")] + _ if cfg!(feature = "ws") => { + WsTransport::new_pool( + self.sock, + self.config.so_rcvbuf, + self.config.so_sndbuf, + workers + ).map_err(|err| vec![err])? + .into_iter() + .map(|t| Arc::new(t) as Arc) + .collect() + } + _ => unreachable!("transport is not selected, please enable one of transport features") + }; let rt = Builder::new_multi_thread() .worker_threads(workers) @@ -114,7 +135,7 @@ impl Runtime { let tun = tun.try_clone().map_err(|err| vec![RuntimeError::Tun( format!("failed to clone tun device: {}", err) )])?; - let transport = Arc::new(transports.pop().unwrap()); // unwrap is safe here + let transport = transports.pop().unwrap(); // unwrap is safe here let config = self.config.clone(); let handle = rt.spawn(async move { @@ -144,7 +165,7 @@ impl Runtime { // session cleanup let session = self.config.session.clone().unwrap_or_default(); if session.timeout != 0 { - tracing::info!("session cleanup worker started"); + info!("session cleanup worker started"); tokio::spawn(session::worker::run( self.stop_tx.clone(), self.sessions.clone(), @@ -152,7 +173,7 @@ impl Runtime { Duration::from_secs(session.cleanup_interval as u64), )); } else { - tracing::info!("session cleanup worker disabled"); + info!("session cleanup worker disabled"); } let mut errors = Vec::new(); diff --git a/crates/server/src/runtime/transport.rs b/crates/server/src/runtime/transport.rs index 6baeec4..f32aead 100644 --- a/crates/server/src/runtime/transport.rs +++ b/crates/server/src/runtime/transport.rs @@ -1,5 +1,10 @@ +#[cfg(feature = "udp")] pub mod udp; +#[cfg(feature = "ws")] +pub mod ws; + +use std::any::Any; use std::io; use std::net::SocketAddr; use async_trait::async_trait; @@ -14,4 +19,6 @@ pub trait TransportReceiver: Send + Sync { async fn recv_from(&self, buffer: &mut [u8]) -> io::Result<(usize, SocketAddr)>; } -pub trait Transport: TransportSender + TransportReceiver{} \ No newline at end of file +pub trait Transport: TransportSender + TransportReceiver{ + fn as_any(&self) -> &dyn Any; +} diff --git a/crates/server/src/runtime/transport/udp.rs b/crates/server/src/runtime/transport/udp.rs index 7f36d1f..55a3531 100644 --- a/crates/server/src/runtime/transport/udp.rs +++ b/crates/server/src/runtime/transport/udp.rs @@ -1,9 +1,11 @@ +use std::any::Any; use crate::runtime::error::RuntimeError; use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; use async_trait::async_trait; use socket2::{Domain, Protocol, Socket, Type}; use std::net::SocketAddr; use tokio::net::UdpSocket; +use tracing::info; pub struct UdpTransport { socket: UdpSocket @@ -29,15 +31,23 @@ impl UdpTransport { socket.set_send_buffer_size(so_sndbuf)?; socket.set_tos(0b101110 << 2)?; socket.bind(&addr.into())?; + + info!( + "Runtime running on udp://{} with {} workers", + addr, + count + ); let mut sockets = Vec::with_capacity(count); - for i in 0..count { + for i in 0..count - 1 { let cloned_raw_socket = socket.try_clone().map_err(|err| { RuntimeError::IO(format!("clone socket #{}: {}", i + 1, err)) })?.into(); sockets.push(Self { socket: UdpSocket::from_std(cloned_raw_socket)? }); } + + sockets.push(Self { socket: UdpSocket::from_std(socket.into())? }); Ok(sockets) } @@ -60,4 +70,8 @@ impl TransportSender for UdpTransport { } } -impl Transport for UdpTransport{} +impl Transport for UdpTransport{ + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/server/src/runtime/transport/ws.rs b/crates/server/src/runtime/transport/ws.rs new file mode 100644 index 0000000..a573e37 --- /dev/null +++ b/crates/server/src/runtime/transport/ws.rs @@ -0,0 +1,151 @@ +use std::any::Any; +use crate::runtime::error::RuntimeError; +use crate::runtime::transport::{Transport, TransportReceiver, TransportSender}; +use async_trait::async_trait; +use dashmap::DashMap; +use futures::stream::SplitSink; +use futures::{SinkExt, StreamExt}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{mpsc, Mutex}; +use tokio_tungstenite::tungstenite::{Bytes, Message}; +use tokio_tungstenite::{accept_async, WebSocketStream}; +use tracing::info; + +pub struct WsTransport { + listener: TcpListener, + active_connections: Arc, Message>>>, + message_queue: Arc>>, + message_sender: mpsc::UnboundedSender<(Bytes, SocketAddr)> +} + +impl WsTransport { + pub fn new_pool( + addr: SocketAddr, + so_rcvbuf: usize, + so_sndbuf: usize, + count: usize, + ) -> Result, RuntimeError> { + let socket = Socket::new( + Domain::for_address(addr), + Type::STREAM, + Some(Protocol::TCP), + )?; + + socket.set_nonblocking(true)?; + socket.set_reuse_port(true)?; + socket.set_reuse_address(true)?; + socket.set_recv_buffer_size(so_rcvbuf)?; + socket.set_send_buffer_size(so_sndbuf)?; + socket.set_tos(0b101110 << 2)?; + socket.bind(&addr.into())?; + socket.listen(1024)?; + + info!( + "Runtime running on ws://{} with {} workers", + addr, + count + ); + + let mut listeners = Vec::with_capacity(count); + for _ in 0..count - 1 { + let cloned = socket.try_clone()?; + let listener = TcpListener::from_std(cloned.into())?; + let (sender, receiver) = mpsc::unbounded_channel(); + listeners.push(Self { + listener, + active_connections: Arc::new(DashMap::new()), + message_queue: Arc::new(Mutex::new(receiver)), + message_sender: sender, + }); + } + + let (sender, receiver) = mpsc::unbounded_channel(); + let listener = TcpListener::from_std(socket.into())?; + listeners.push(Self { + listener, + active_connections: Arc::new(DashMap::new()), + message_queue: Arc::new(Mutex::new(receiver)), + message_sender: sender, + }); + Ok(listeners) + } + + pub async fn start(&self) -> Result<(), RuntimeError> { + loop { + let (tcp_stream, addr) = self.listener.accept().await?; + let message_sender = self.message_sender.clone(); + let connections = self.active_connections.clone(); + tokio::spawn(async move { + let ws_stream = match accept_async(tcp_stream).await { + Ok(ws) => ws, + Err(e) => { + eprintln!("WebSocket handshake error: {}", e); + return; + } + }; + let (write, read) = ws_stream.split(); + connections.insert(addr, write); + // Обработка входящих сообщений + tokio::spawn(async move { + let mut read = read; + while let Some(Ok(msg)) = read.next().await { + if let Message::Binary(data) = msg { + let _ = message_sender.send((data, addr)); + } + } + connections.remove(&addr); + }); + }); + } + } +} + + +#[async_trait] +impl TransportReceiver for WsTransport { + #[inline(always)] + async fn recv_from(&self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + match self.message_queue.lock().await.recv().await { // todo mutex + Some((data, addr)) => { + let len = data.len().min(buffer.len()); + buffer[..len].copy_from_slice(&data[..len]); + Ok((len, addr)) + } + None => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel closed")), + } + } +} + +#[async_trait] +impl TransportSender for WsTransport { + #[inline(always)] + async fn send_to(&self, data: &[u8], addr: &SocketAddr) -> std::io::Result { + if let Some(mut writer) = self.active_connections.get_mut(addr) { + writer.value_mut().send(Message::Binary(data.to_vec().into())).await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(data.len()) + } else { + Err(std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, "Address not found")) + } + } +} + +impl Transport for WsTransport { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl dyn Transport { + pub fn downcast(self: Arc) -> Result, Arc> { + if self.as_any().is::() { + let ptr = Arc::into_raw(self); + Ok(unsafe { Arc::from_raw(ptr as *const T) }) + } else { + Err(self) + } + } +} \ No newline at end of file diff --git a/crates/server/src/runtime/worker/data.rs b/crates/server/src/runtime/worker/data.rs index 41ef65f..b9fafcf 100644 --- a/crates/server/src/runtime/worker/data.rs +++ b/crates/server/src/runtime/worker/data.rs @@ -43,10 +43,10 @@ fn encode_body(body: &DataServerBody, state: &StatelessTransportState) -> anyhow Ok(encrypted_buffer[..encrypted_len].to_vec().into()) } -pub(super) async fn data_udp_executor( +pub(super) async fn data_transport_executor( mut stop: Receiver, mut queue: mpsc::Receiver<(SessionId, EncryptedData, SocketAddr)>, - udp_tx: mpsc::Sender<(Packet, SocketAddr)>, + transport_tx: mpsc::Sender<(Packet, SocketAddr)>, tun_tx: mpsc::Sender>, sessions: Sessions, inf_sessions_timeout: bool, @@ -69,8 +69,8 @@ pub(super) async fn data_udp_executor( if !inf_sessions_timeout { sessions.touch(sid) } - if let Err(e) = udp_tx.send((Packet::DataServer(value), addr)).await { - error!("failed to send server data packet to udp queue: {}", e); + if let Err(e) = transport_tx.send((Packet::DataServer(value), addr)).await { + error!("failed to send server data packet to transport queue: {}", e); } }, Err(e) => { @@ -97,7 +97,7 @@ pub(super) async fn data_udp_executor( None => warn!("[{}] received data packet for unknown session {}", addr, sid) }, None => { - error!("data_udp_executor channel is closed"); + error!("data_transport_executor channel is closed"); break } } @@ -108,7 +108,7 @@ pub(super) async fn data_udp_executor( pub(super) async fn data_tun_executor( mut stop: Receiver, mut queue: mpsc::Receiver<(Vec, HolyIp)>, - udp_tx: mpsc::Sender<(Packet, SocketAddr)>, + transport_tx: mpsc::Sender<(Packet, SocketAddr)>, sessions: Sessions, ) { loop { @@ -118,8 +118,8 @@ pub(super) async fn data_tun_executor( Some((packet, holy_ip)) => match sessions.get(&holy_ip) { Some(session) => match encode_body(&DataServerBody::Payload(packet.into()), &session.state) { Ok(body) => { - if let Err(e) = udp_tx.send((Packet::DataServer(body), session.sock_addr())).await { - error!("failed to send server data packet to udp queue: {}", e); + if let Err(e) = transport_tx.send((Packet::DataServer(body), session.sock_addr())).await { + error!("failed to send server data packet to transport queue: {}", e); } }, Err(err) => warn!("[{}] failed to encode tun data packet (sid: {}): {}", session.sock_addr(), session.id, err) diff --git a/crates/server/src/runtime/worker/mod.rs b/crates/server/src/runtime/worker/mod.rs index 803ea89..10c2556 100644 --- a/crates/server/src/runtime/worker/mod.rs +++ b/crates/server/src/runtime/worker/mod.rs @@ -1,7 +1,7 @@ mod handshake; mod data; mod tun; -mod udp; +mod transport; use super::session::HolyIp; use super::{ @@ -10,10 +10,10 @@ use super::{ }; use crate::config::RuntimeConfig; use crate::runtime::worker::{ - data::{data_tun_executor, data_udp_executor}, + data::{data_tun_executor, data_transport_executor}, handshake::handshake_executor, tun::{tun_listener, tun_sender}, - udp::{udp_listener, udp_sender} + transport::{transport_listener, transport_sender} }; use dashmap::DashMap; use shared::keys::handshake::{PublicKey, SecretKey}; @@ -42,18 +42,26 @@ pub(crate) async fn create( let tun = Arc::new(tun); - let (out_udp_tx, out_udp_rx) = mpsc::channel::<(Packet, SocketAddr)>(config.out_udp_buf); + let (out_transport_tx, out_transport_rx) = mpsc::channel::<(Packet, SocketAddr)>(config.out_udp_buf); let (out_tun_tx, out_tun_rx) = mpsc::channel::>(config.out_tun_buf); let (handshake_tx, handshake_rx) = mpsc::channel::<(EncryptedHandshake, SocketAddr)>(config.handshake_buf); - let (data_udp_tx, data_udp_rx) = mpsc::channel::<(SessionId, EncryptedData, SocketAddr)>(config.data_udp_buf); + let (data_transport_tx, data_transport_rx) = mpsc::channel::<(SessionId, EncryptedData, SocketAddr)>(config.data_udp_buf); let (data_tun_tx, data_tun_rx) = mpsc::channel::<(Vec, HolyIp)>(config.data_tun_buf); + #[cfg(feature = "ws")] + { + use crate::runtime::transport::ws::WsTransport; + if let Ok(ws_transport) = transport.clone().downcast::() { + tokio::spawn(async move { + ws_transport.start().await + }); + } + } + // Handle incoming transport packets + tokio::spawn(transport_listener(stop_tx.subscribe(), transport.clone(), handshake_tx, data_transport_tx)); - // Handle incoming UDP packets - tokio::spawn(udp_listener(stop_tx.subscribe(), transport.clone(), handshake_tx, data_udp_tx)); - - // Handle outgoing UDP packets - tokio::spawn(udp_sender(stop_tx.subscribe(), transport.clone(), out_udp_rx)); + // Handle outgoing transport packets + tokio::spawn(transport_sender(stop_tx.subscribe(), transport.clone(), out_transport_rx)); // Handle incoming TUN packets tokio::spawn(tun_listener(stop_tx.subscribe(), tun.clone(), data_tun_tx)); @@ -65,15 +73,15 @@ pub(crate) async fn create( tokio::spawn(handshake_executor( stop_tx.subscribe(), handshake_rx, - out_udp_tx.clone(), + out_transport_tx.clone(), known_clients.clone(), sessions.clone(), sk )); - tokio::spawn(data_udp_executor( + tokio::spawn(data_transport_executor( stop_tx.subscribe(), - data_udp_rx, - out_udp_tx.clone(), + data_transport_rx, + out_transport_tx.clone(), out_tun_tx.clone(), sessions.clone(), config.session.unwrap_or_default().timeout == 0 @@ -82,7 +90,7 @@ pub(crate) async fn create( tokio::spawn(data_tun_executor( stop_tx.subscribe(), data_tun_rx, - out_udp_tx.clone(), + out_transport_tx.clone(), sessions.clone(), )); @@ -133,7 +141,7 @@ pub(crate) async fn create( // tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); // // let (stop_tx, _) = broadcast::channel::(1); -// let (out_udp_tx, mut out_udp_rx) = mpsc::channel::<(Packet, SocketAddr)>(1000); +// let (out_transport_tx, mut out_transport_rx) = mpsc::channel::<(Packet, SocketAddr)>(1000); // let (handshake_tx, handshake_rx) = mpsc::channel::<(storage::packet::Handshake, SocketAddr)>(1000); // let (data_tx, data_rx) = mpsc::channel::<(storage::packet::DataPacket, SocketAddr)>(1000); // @@ -147,7 +155,7 @@ pub(crate) async fn create( // tokio::spawn(handshake_executor( // stop_tx.subscribe(), // handshake_rx, -// out_udp_tx.clone(), +// out_transport_tx.clone(), // known_clients.clone(), // sessions.clone(), // server_sk @@ -155,7 +163,7 @@ pub(crate) async fn create( // tokio::spawn(data_executor( // stop_tx.subscribe(), // data_rx, -// out_udp_tx.clone(), +// out_transport_tx.clone(), // sessions.clone(), // Some( // std::sync::Arc::new(|req| { @@ -192,7 +200,7 @@ pub(crate) async fn create( // handshake_tx.send((handshake, client_sock)).await?; // // // [step 2] Server Complete -// let (packet, s) = out_udp_rx.recv().await.unwrap(); +// let (packet, s) = out_transport_rx.recv().await.unwrap(); // // // // [step 3] Client Complete @@ -214,7 +222,7 @@ pub(crate) async fn create( // data_tx.send((packet, client_sock)).await?; // // // Server -// let (packet, _) = out_udp_rx.recv().await.unwrap(); +// let (packet, _) = out_transport_rx.recv().await.unwrap(); // // // Client // let body = match packet { diff --git a/crates/server/src/runtime/worker/udp.rs b/crates/server/src/runtime/worker/transport.rs similarity index 77% rename from crates/server/src/runtime/worker/udp.rs rename to crates/server/src/runtime/worker/transport.rs index 27e93a2..3927b0a 100644 --- a/crates/server/src/runtime/worker/udp.rs +++ b/crates/server/src/runtime/worker/transport.rs @@ -8,15 +8,15 @@ use shared::session::SessionId; use crate::runtime::error::RuntimeError; use crate::runtime::transport::{TransportReceiver, TransportSender}; -pub async fn udp_sender( +pub async fn transport_sender( mut stop: Receiver, transport: Arc, - mut out_udp_rx: mpsc::Receiver<(Packet, SocketAddr)> + mut out_transport_rx: mpsc::Receiver<(Packet, SocketAddr)> ) { loop { tokio::select! { _ = stop.recv() => break, - result = out_udp_rx.recv() => match result { + result = out_transport_rx.recv() => match result { Some((data, client_addr)) => { match transport.send_to(&data.to_bytes(), &client_addr).await { Ok(len) => { @@ -35,28 +35,29 @@ pub async fn udp_sender( } } -pub async fn udp_listener( +pub async fn transport_listener( mut stop: Receiver, transport: Arc, handshake_tx: mpsc::Sender<(EncryptedHandshake, SocketAddr)>, data_tx: mpsc::Sender<(SessionId, EncryptedData, SocketAddr)> ) { - let mut udp_buffer = [0u8; 65536]; + let mut buffer = [0u8; 65536]; loop { tokio::select! { _ = stop.recv() => break, - result = transport.recv_from(&mut udp_buffer) => { + result = transport.recv_from(&mut buffer) => { match result { Ok((n, client_addr)) => { + debug!("received transport packet from {}: len: {}", client_addr, n); if n == 0 { - warn!("received UDP packet from {} with 0 bytes, dropping it", client_addr); + warn!("received transport packet from {} with 0 bytes, dropping it", client_addr); continue; } if n > 65536 { - warn!("received UDP packet from {} larger than 65536 bytes, dropping it", client_addr); + warn!("received transport packet from {} larger than 65536 bytes, dropping it", client_addr); continue; } - match Packet::try_from(&udp_buffer[..n]) { + match Packet::try_from(&buffer[..n]) { Ok(packet) => match packet { Packet::HandshakeInitial(handshake) => { if let Err(e) = handshake_tx.send((handshake, client_addr)).await { @@ -74,12 +75,12 @@ pub async fn udp_listener( } }, Err(e) => { - warn!("failed to parse UDP packet from {}: {}", client_addr, e); + warn!("failed to parse transport packet from {}: {}", client_addr, e); continue; } } } - Err(e) => warn!("failed to receive udp: {}", e) + Err(e) => warn!("failed to receive transport: {}", e) } } }