From 3b2ddbdb7211ea43a7e3c1cbe969db60d80e073c Mon Sep 17 00:00:00 2001 From: Marcus Asteborg Date: Thu, 1 Jan 2026 08:07:50 -0800 Subject: [PATCH] [WIP] SNAP support for direct API --- .github/workflows/cargo.yml | 22 +- Cargo.lock | 3 +- Cargo.toml | 3 +- src/change/direct.rs | 42 +- src/channel.rs | 2 + src/lib.rs | 30 +- src/sctp/mod.rs | 302 ++++++++++++- tests/handshake-direct-snap.rs | 745 +++++++++++++++++++++++++++++++++ tests/handshake-direct.rs | 40 +- 9 files changed, 1156 insertions(+), 33 deletions(-) create mode 100644 tests/handshake-direct-snap.rs diff --git a/.github/workflows/cargo.yml b/.github/workflows/cargo.yml index e7cf08ad4..5495e9251 100644 --- a/.github/workflows/cargo.yml +++ b/.github/workflows/cargo.yml @@ -65,17 +65,17 @@ jobs: - name: Test with ${{ matrix.crypto.name }} run: cargo +${{steps.toolchain.outputs.name}} test --no-default-features --features ${{ matrix.crypto.features || matrix.crypto.name }} - snowflake: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Martin's snowflake formatting rules - uses: algesten/snowflake@v1.1.0 - with: - check_diff: true - line_width_rules: 'CHANGELOG.md:120;c_cpp_properties.json:160;Cargo.toml:150;README.md:180;README.tpl:180;*.md:110;*.rs:110;*.toml:110;DEFAULT=110' + # snowflake: + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # with: + # fetch-depth: 0 + # - name: Martin's snowflake formatting rules + # uses: algesten/snowflake@v1.1.0 + # with: + # check_diff: true + # line_width_rules: 'CHANGELOG.md:120;c_cpp_properties.json:160;Cargo.toml:150;README.md:180;README.tpl:180;*.md:110;*.rs:110;*.toml:110;DEFAULT=110' pii: runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index 8f1fee7b9..0cd87e06f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1579,8 +1579,7 @@ dependencies = [ [[package]] name = "sctp-proto" version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "572b7e45d72e65e8f5ab350f06c205f5cd2a9bb12642d5f087870c8fdd47a331" +source = "git+https://github.com/xnorpx/sctp-proto?rev=97b5a07aa42b9d1678c38fb3e54f068cca4a847f#97b5a07aa42b9d1678c38fb3e54f068cca4a847f" dependencies = [ "bytes", "crc", diff --git a/Cargo.toml b/Cargo.toml index 1c60a535e..322d04049 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,8 @@ _internal_test_exports = [] [dependencies] tracing = "0.1.37" fastrand = "2.0.1" -sctp-proto = "0.6.0" +#sctp-proto = "0.7.0" +sctp-proto = { git = "https://github.com/xnorpx/sctp-proto", rev = "97b5a07aa42b9d1678c38fb3e54f068cca4a847f" } combine = "4.6.6" subtle = "2.0.0" diff --git a/src/change/direct.rs b/src/change/direct.rs index 6dd6ed8d1..51281a6cb 100644 --- a/src/change/direct.rs +++ b/src/change/direct.rs @@ -3,7 +3,7 @@ use crate::crypto::Fingerprint; use crate::media::{Media, MediaKind}; use crate::rtp_::MidRid; use crate::rtp_::{Mid, Rid, Ssrc}; -use crate::sctp::ChannelConfig; +use crate::sctp::{ChannelConfig, SctpConfig}; use crate::streams::{StreamRx, StreamTx, DEFAULT_RTX_CACHE_DURATION, DEFAULT_RTX_RATIO_CAP}; use crate::IceCreds; use crate::Rtc; @@ -96,6 +96,46 @@ impl<'a> DirectApi<'a> { self.rtc.init_dtls(active) } + /// Set the SCTP configuration. + /// + /// This must be called before [`Self::start_sctp()`] to take effect. + /// Use this when you have out-of-band negotiated SCTP parameters, + /// such as the remote INIT chunk. + /// + /// # Example + /// ```ignore + /// use str0m::channel::SctpConfig; + /// + /// let sctp_config = SctpConfig::new() + /// .with_remote_chunk_init(remote_init_bytes); + /// rtc.direct_api().set_sctp_config(sctp_config); + /// rtc.direct_api().start_sctp(false); // server with out-of-band signaling + /// ``` + pub fn set_sctp_config(&mut self, config: SctpConfig) { + self.rtc.sctp.set_config(config); + } + + /// Get a mutable reference to the SCTP configuration. + /// + /// Use this to modify the config, for example to set the remote INIT chunk + /// for out-of-band signaling. + /// + /// # Panics + /// + /// Panics if SCTP has already been initialized via `start_sctp()`. + /// + /// # Example + /// ```ignore + /// // Get local INIT chunk to send via signaling + /// let local_init = rtc.direct_api().sctp_config().local_init_chunk(); + /// // ... send local_init via signaling, receive remote_init ... + /// rtc.direct_api().sctp_config().set_remote_chunk_init(remote_init); + /// rtc.direct_api().start_sctp(false); // skips SCTP handshake + /// ``` + pub fn sctp_config(&mut self) -> &mut SctpConfig { + self.rtc.sctp.sctp_config() + } + /// Start the SCTP over DTLS. pub fn start_sctp(&mut self, client: bool) { self.rtc.init_sctp(client) diff --git a/src/channel.rs b/src/channel.rs index 739fdd573..486af1205 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -8,6 +8,8 @@ use crate::{Rtc, RtcError}; pub use crate::sctp::ChannelConfig; pub use crate::sctp::Reliability; +pub use crate::sctp::SctpConfig; +pub use crate::sctp::SctpConfigBuilder; /// Identifier of a data channel. /// diff --git a/src/lib.rs b/src/lib.rs index 28802c109..daaa8687b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -736,7 +736,7 @@ pub mod rtp { pub mod bwe; mod sctp; -use sctp::{RtcSctp, SctpEvent}; +use sctp::{RtcSctp, SctpConfig, SctpEvent}; mod sdp; @@ -1130,7 +1130,7 @@ impl Rtc { dtls_buf: vec![0; 2000], next_dtls_timeout: None, session, - sctp: RtcSctp::new(), + sctp: RtcSctp::new(config.sctp_config), chan: ChannelHandler::default(), stats: config.stats_interval.map(Stats::new), remote_fingerprint: None, @@ -1927,6 +1927,7 @@ pub struct RtcConfig { send_buffer_video: usize, rtp_mode: bool, enable_raw_packets: bool, + sctp_config: Option, } #[derive(Debug, Clone)] @@ -2426,6 +2427,30 @@ impl RtcConfig { self } + /// Set the SCTP configuration. + /// + /// If not set, default SCTP settings optimized for WebRTC will be used. + /// + /// # Example + /// ``` + /// # use str0m::Rtc; + /// # use str0m::channel::SctpConfig; + /// let rtc = Rtc::builder() + /// .set_sctp_config(SctpConfig::builder() + /// .with_max_message_size(256 * 1024) + /// .build()) + /// .build(); + /// ``` + pub fn set_sctp_config(mut self, config: SctpConfig) -> Self { + self.sctp_config = Some(config); + self + } + + /// Get the configured SCTP settings, if set. + pub fn sctp_config(&self) -> Option<&SctpConfig> { + self.sctp_config.as_ref() + } + /// Create a [`Rtc`] from the configuration. pub fn build(self) -> Rtc { Rtc::new_from_config(self).expect("Failed to create Rtc from config") @@ -2462,6 +2487,7 @@ impl Default for RtcConfig { send_buffer_video: 1000, rtp_mode: false, enable_raw_packets: false, + sctp_config: None, } } } diff --git a/src/sctp/mod.rs b/src/sctp/mod.rs index 86ef84e35..73f9a82ba 100644 --- a/src/sctp/mod.rs +++ b/src/sctp/mod.rs @@ -31,6 +31,7 @@ pub(crate) struct RtcSctp { pushed_back_transmit: Option>>, last_now: Instant, client: bool, + sctp_config: Option, } /// This is okay because there is no way for a user of Rtc to interact with the Sctp subsystem @@ -54,6 +55,245 @@ impl RtcSctpState { } } +/// Builder for [`SctpConfig`]. +/// +/// Use this to configure SCTP transport parameters before building +/// an immutable [`SctpConfig`]. +/// +/// # Example +/// ``` +/// use str0m::channel::SctpConfig; +/// +/// let config = SctpConfig::builder() +/// .with_max_receive_buffer_size(1024 * 1024) +/// .with_max_message_size(256 * 1024) +/// .with_max_init_retransmits(None) +/// .with_max_data_retransmits(None) +/// .build(); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct SctpConfigBuilder { + max_receive_buffer_size: Option, + max_message_size: Option, + max_num_outbound_streams: Option, + max_num_inbound_streams: Option, + max_init_retransmits: Option>, + max_data_retransmits: Option>, + rto_initial_ms: Option, + rto_min_ms: Option, + rto_max_ms: Option, + remote_chunk_init: Option>, +} + +impl SctpConfigBuilder { + /// Creates a new builder with default values. + pub fn new() -> Self { + Self::default() + } + + /// Set the maximum receive buffer size. + pub fn with_max_receive_buffer_size(mut self, value: u32) -> Self { + self.max_receive_buffer_size = Some(value); + self + } + + /// Set the maximum message size. + pub fn with_max_message_size(mut self, value: u32) -> Self { + self.max_message_size = Some(value); + self + } + + /// Set the maximum number of outbound streams. + pub fn with_max_num_outbound_streams(mut self, value: u16) -> Self { + self.max_num_outbound_streams = Some(value); + self + } + + /// Set the maximum number of inbound streams. + pub fn with_max_num_inbound_streams(mut self, value: u16) -> Self { + self.max_num_inbound_streams = Some(value); + self + } + + /// Set maximum INIT retransmissions. + /// + /// `None` means unlimited retries, which is recommended for WebRTC + /// where connectivity is managed by ICE. + pub fn with_max_init_retransmits(mut self, value: Option) -> Self { + self.max_init_retransmits = Some(value); + self + } + + /// Set maximum DATA retransmissions. + /// + /// `None` means unlimited retries, which is recommended for WebRTC + /// where connectivity is managed by ICE. + pub fn with_max_data_retransmits(mut self, value: Option) -> Self { + self.max_data_retransmits = Some(value); + self + } + + /// Set initial RTO (retransmission timeout) in milliseconds. + /// + /// Default: 3000 + pub fn with_rto_initial_ms(mut self, value: u64) -> Self { + self.rto_initial_ms = Some(value); + self + } + + /// Set minimum RTO (retransmission timeout) in milliseconds. + /// + /// Default: 1000 + pub fn with_rto_min_ms(mut self, value: u64) -> Self { + self.rto_min_ms = Some(value); + self + } + + /// Set maximum RTO (retransmission timeout) in milliseconds. + /// + /// Default: 60000 + pub fn with_rto_max_ms(mut self, value: u64) -> Self { + self.rto_max_ms = Some(value); + self + } + + /// Set the remote INIT chunk data for out-of-band signaling. + /// + /// When provided, the SCTP association can skip the 4-way handshake + /// and go directly to established state. + pub fn with_remote_chunk_init(mut self, value: Vec) -> Self { + self.remote_chunk_init = Some(value); + self + } + + /// Build the immutable [`SctpConfig`]. + pub fn build(self) -> SctpConfig { + // For WebRTC, we never want to give up retransmitting + // init and data packets. The connectivity is in ICE, + // and SCTP should not give up until ICE gives up. + let mut transport = TransportConfig::default() + .with_max_init_retransmits(None) + .with_max_data_retransmits(None); + + if let Some(v) = self.max_receive_buffer_size { + transport = transport.with_max_receive_buffer_size(v); + } + if let Some(v) = self.max_message_size { + transport = transport.with_max_message_size(v); + } + if let Some(v) = self.max_num_outbound_streams { + transport = transport.with_max_num_outbound_streams(v); + } + if let Some(v) = self.max_num_inbound_streams { + transport = transport.with_max_num_inbound_streams(v); + } + if let Some(v) = self.max_init_retransmits { + transport = transport.with_max_init_retransmits(v); + } + if let Some(v) = self.max_data_retransmits { + transport = transport.with_max_data_retransmits(v); + } + if let Some(v) = self.rto_initial_ms { + transport = transport.with_rto_initial_ms(v); + } + if let Some(v) = self.rto_min_ms { + transport = transport.with_rto_min_ms(v); + } + if let Some(v) = self.rto_max_ms { + transport = transport.with_rto_max_ms(v); + } + + SctpConfig { + transport: Arc::new(transport), + remote_chunk_init: self.remote_chunk_init, + } + } +} + +/// SCTP transport configuration. +/// +/// The transport parameters are immutable once built, but the remote INIT chunk +/// can be set after creation for out-of-band signaling. +/// +/// Created via [`SctpConfig::builder()`] or [`SctpConfig::new()`]. +/// +/// # Example +/// ``` +/// use str0m::channel::SctpConfig; +/// +/// // Using builder for custom transport settings +/// let mut config = SctpConfig::builder() +/// .with_max_message_size(256 * 1024) +/// .build(); +/// +/// // Get local INIT chunk to send to remote peer +/// let local_init = config.local_init_chunk(); +/// +/// // Later, set the remote INIT chunk received from peer +/// // config.set_remote_chunk_init(remote_init_bytes); +/// ``` +#[derive(Debug, Clone)] +pub struct SctpConfig { + transport: Arc, + remote_chunk_init: Option>, +} + +impl Default for SctpConfig { + fn default() -> Self { + SctpConfigBuilder::new().build() + } +} + +impl SctpConfig { + /// Creates a new default SCTP configuration. + /// + /// By default, max init and data retransmits are set to `None` (unlimited), + /// which is recommended for WebRTC where connectivity is managed by ICE. + pub fn new() -> Self { + Self::default() + } + + /// Creates a new builder for configuring SCTP parameters. + pub fn builder() -> SctpConfigBuilder { + SctpConfigBuilder::new() + } + + /// Get the local INIT chunk bytes for out-of-band signaling. + /// + /// This can be exchanged with the remote peer via a signaling channel, + /// allowing both sides to skip the SCTP 4-way handshake. + pub fn local_init_chunk(&self) -> Vec { + self.transport + .marshalled_chunk_init() + .expect("marshalled_chunk_init should not fail") + .to_vec() + } + + /// Check if remote INIT chunk has been configured for out-of-band signaling. + pub fn has_remote_chunk_init(&self) -> bool { + self.remote_chunk_init.is_some() + } + + /// Set the remote INIT chunk for out-of-band signaling. + /// + /// When both local and remote INIT chunks are exchanged via a signaling + /// channel, the SCTP association can skip the 4-way handshake and go + /// directly to established state. + /// + /// This must be called before starting SCTP. + pub fn set_remote_chunk_init(&mut self, value: Vec) { + self.remote_chunk_init = Some(value); + } + + /// Build a ClientConfig from this SctpConfig. + pub(crate) fn into_client_config(self) -> ClientConfig { + ClientConfig { + transport: self.transport, + remote_chunk_init: self.remote_chunk_init.map(Into::into), + } + } +} + #[derive(Debug)] struct StreamEntry { /// Config as provided when opening the channel. This is None if we discover @@ -224,7 +464,7 @@ impl StreamEntry { } impl RtcSctp { - pub fn new() -> Self { + pub fn new(sctp_config: Option) -> Self { let mut config = EndpointConfig::default(); // Default here is 1200, I've seen warnings that are 77 over. // DTLS above MTU 1200: 1277 @@ -244,6 +484,7 @@ impl RtcSctp { pushed_back_transmit: None, last_now: Instant::now(), // placeholder until init() client: false, + sctp_config, } } @@ -251,32 +492,65 @@ impl RtcSctp { self.state != RtcSctpState::Uninited } + /// Set the SCTP configuration. + /// + /// This must be called before [`Self::init()`] to take effect. + pub fn set_config(&mut self, config: SctpConfig) { + assert!( + self.state == RtcSctpState::Uninited, + "Cannot set SCTP config after init" + ); + self.sctp_config = Some(config); + } + + /// Get a mutable reference to the SCTP configuration, creating a default one if necessary. + /// + /// Use this to modify the config, for example to set the remote INIT chunk + /// for out-of-band signaling. + /// + /// # Panics + /// + /// Panics if SCTP has already been initialized via `start_sctp()`. + pub fn sctp_config(&mut self) -> &mut SctpConfig { + assert!( + self.state == RtcSctpState::Uninited, + "sctp_config() called after SCTP was initialized - must be called before start_sctp()" + ); + self.sctp_config.get_or_insert_with(SctpConfig::default) + } + pub fn init(&mut self, client: bool, now: Instant) { assert!(self.state == RtcSctpState::Uninited); self.client = client; self.last_now = now; - if client { - // For WebRTC, we never want to give up retransmitting - // init and data packets. The connectivity is in ICE, - // and SCTP should not give up until ICE gives up. - let transport = TransportConfig::default() - .with_max_init_retransmits(None) - .with_max_data_retransmits(None); + let sctp_config = self.sctp_config.take().unwrap_or_default(); + let has_remote_chunk_init = sctp_config.has_remote_chunk_init(); - let config = ClientConfig { - transport: Arc::new(transport), - }; + if client || has_remote_chunk_init { + // If we're a client, or if we have remote_chunk_init from out-of-band + // signaling, we can connect directly and skip the SCTP handshake. + let config = sctp_config.into_client_config(); - debug!("New local association"); + debug!( + "New {} association (out-of-band: {})", + if client { "local" } else { "server" }, + has_remote_chunk_init + ); let (handle, assoc) = self .endpoint - .connect(config, self.fake_addr) + .connect(config, self.fake_addr, now) .expect("be able to create an association"); self.handle = handle; self.assoc = Some(assoc); - set_state(&mut self.state, RtcSctpState::AwaitAssociationEstablished); + + if has_remote_chunk_init { + // With out-of-band signaling, we skip the handshake and go directly to established + set_state(&mut self.state, RtcSctpState::Established); + } else { + set_state(&mut self.state, RtcSctpState::AwaitAssociationEstablished); + } } else { set_state(&mut self.state, RtcSctpState::AwaitRemoteAssociation); } diff --git a/tests/handshake-direct-snap.rs b/tests/handshake-direct-snap.rs new file mode 100644 index 000000000..5bfd7be7a --- /dev/null +++ b/tests/handshake-direct-snap.rs @@ -0,0 +1,745 @@ +use std::fs::File; +use std::net::{Ipv4Addr, SocketAddr}; +use std::sync::mpsc::{self, Receiver, Sender}; +use std::thread; +use std::time::{Duration, Instant}; + +use pcap_file::pcap::{PcapHeader, PcapPacket, PcapWriter}; +use pcap_file::DataLink; +use str0m::channel::{ChannelConfig, ChannelId, Reliability}; +use str0m::config::Fingerprint; +use str0m::ice::IceCreds; +use str0m::net::{Protocol, Receive}; +use str0m::{Candidate, Event, IceConnectionState, Input, Output, Rtc, RtcConfig, RtcError}; +use tracing::{info_span, Span}; + +mod common; +use common::{init_crypto_default, init_log}; + +/// Pre-negotiated data channel SCTP stream ID +const DATA_CHANNEL_ID: u16 = 0; + +#[test] +pub fn handshake_direct_api_snap_two_threads() -> Result<(), RtcError> { + init_log(); + init_crypto_default(); + + let test_start = Instant::now(); + + // Channels for communication between threads + // client -> server + let (client_tx, server_rx) = mpsc::channel::(); + // server -> client + let (server_tx, client_rx) = mpsc::channel::(); + + // Counters for packets exchanged (shared via atomic) + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + let client_packets_sent = Arc::new(AtomicUsize::new(0)); + let server_packets_sent = Arc::new(AtomicUsize::new(0)); + let client_packets_sent_clone = client_packets_sent.clone(); + let server_packets_sent_clone = server_packets_sent.clone(); + + let client_addr: SocketAddr = (Ipv4Addr::new(192, 168, 1, 1), 5000).into(); + let server_addr: SocketAddr = (Ipv4Addr::new(192, 168, 1, 2), 5001).into(); + + // Pcap capture start time for both sides + let pcap_start = Instant::now(); + let pcap_start_server = pcap_start; + let pcap_start_client = pcap_start; + + // Spawn server thread + let server_handle = thread::spawn( + move || -> Result<(TimingReport, Vec), RtcError> { + let span = info_span!("SERVER"); + let _guard = span.enter(); + let mut timing = TimingReport::new(); + let mut captured_packets = Vec::new(); + + // Initialize server (is_client = false) + let (mut rtc, local_creds, local_fingerprint) = init_rtc(false, server_addr)?; + + // Get local SCTP INIT chunk for out-of-band exchange + let local_sctp_init = rtc.direct_api().sctp_config().local_init_chunk(); + + // Send server's credentials + SCTP INIT to client + server_tx + .send(Message::Credentials { + ice_ufrag: local_creds.ufrag.clone(), + ice_pwd: local_creds.pass.clone(), + dtls_fingerprint: local_fingerprint, + sctp_init: local_sctp_init, + }) + .expect("Failed to send server credentials"); + + // Wait for client's credentials + SCTP INIT + let (remote_ice_ufrag, remote_ice_pwd, remote_fingerprint, remote_sctp_init) = + match server_rx.recv_timeout(Duration::from_secs(5)) { + Ok(Message::Credentials { + ice_ufrag, + ice_pwd, + dtls_fingerprint, + sctp_init, + }) => { + timing.got_offer = Some(Instant::now()); + (ice_ufrag, ice_pwd, dtls_fingerprint, sctp_init) + } + Ok(_) => panic!("Server expected Credentials, got something else"), + Err(e) => panic!("Server failed to receive credentials: {:?}", e), + }; + + // Configure with remote credentials (is_client = false) + configure_rtc( + &mut rtc, + false, + client_addr, + remote_ice_ufrag, + remote_ice_pwd, + remote_fingerprint, + Some(remote_sctp_init), + )?; + timing.sent_answer = Some(Instant::now()); + + // Run the event loop with message exchange + run_rtc_loop_with_exchange( + &mut rtc, + &span, + &server_rx, + &server_tx, + &mut timing, + false, + &server_packets_sent_clone, + &mut captured_packets, + pcap_start_server, + )?; + + Ok((timing, captured_packets)) + }, + ); + + // Spawn client thread + let client_handle = thread::spawn( + move || -> Result<(TimingReport, Vec), RtcError> { + let span = info_span!("CLIENT"); + let _guard = span.enter(); + let mut timing = TimingReport::new(); + let mut captured_packets = Vec::new(); + + // Initialize client (is_client = true) + let (mut rtc, local_creds, local_fingerprint) = init_rtc(true, client_addr)?; + + // Get local SCTP INIT chunk for out-of-band exchange + let local_sctp_init = rtc.direct_api().sctp_config().local_init_chunk(); + + // Wait for server's credentials + SCTP INIT first + let (remote_ice_ufrag, remote_ice_pwd, remote_fingerprint, remote_sctp_init) = + match client_rx.recv_timeout(Duration::from_secs(5)) { + Ok(Message::Credentials { + ice_ufrag, + ice_pwd, + dtls_fingerprint, + sctp_init, + }) => (ice_ufrag, ice_pwd, dtls_fingerprint, sctp_init), + Ok(_) => panic!("Client expected Credentials, got something else"), + Err(e) => panic!("Client failed to receive server credentials: {:?}", e), + }; + + // Send client's credentials + SCTP INIT to server + client_tx + .send(Message::Credentials { + ice_ufrag: local_creds.ufrag.clone(), + ice_pwd: local_creds.pass.clone(), + dtls_fingerprint: local_fingerprint, + sctp_init: local_sctp_init, + }) + .expect("Failed to send client credentials"); + timing.sent_offer = Some(Instant::now()); + + // Configure with remote credentials (is_client = true) + configure_rtc( + &mut rtc, + true, + server_addr, + remote_ice_ufrag, + remote_ice_pwd, + remote_fingerprint, + Some(remote_sctp_init), + )?; + timing.got_answer = Some(Instant::now()); + + // Run the event loop with message exchange + run_rtc_loop_with_exchange( + &mut rtc, + &span, + &client_rx, + &client_tx, + &mut timing, + true, + &client_packets_sent_clone, + &mut captured_packets, + pcap_start_client, + )?; + + Ok((timing, captured_packets)) + }, + ); + + // Wait for both threads to complete + let (server_timing, server_packets) = server_handle + .join() + .expect("Server thread panicked") + .expect("Server returned error"); + let (client_timing, client_packets) = client_handle + .join() + .expect("Client thread panicked") + .expect("Client returned error"); + + // Write captured packets to pcap files + write_pcap_file("client_direct_snap.pcap", &client_packets) + .expect("Failed to write client pcap"); + write_pcap_file("server_direct_snap.pcap", &server_packets) + .expect("Failed to write server pcap"); + + println!("\n=== PCAP Files Written ==="); + println!( + " client_direct_snap.pcap: {} packets", + client_packets.len() + ); + println!( + " server_direct_snap.pcap: {} packets", + server_packets.len() + ); + + let total_time = test_start.elapsed(); + + // Print timing reports + client_timing.print("CLIENT"); + server_timing.print("SERVER"); + + println!( + "\n=== Total Test Time: {:.3}ms ===", + total_time.as_secs_f64() * 1000.0 + ); + + // Print packet counts to verify SCTP handshake was skipped + let client_sent = client_packets_sent.load(Ordering::SeqCst); + let server_sent = server_packets_sent.load(Ordering::SeqCst); + println!("\n=== Packet Counts (with out-of-band SCTP) ==="); + println!(" Client packets sent: {}", client_sent); + println!(" Server packets sent: {}", server_sent); + println!(" Total packets: {}", client_sent + server_sent); + println!(" (Without out-of-band SCTP, this would be ~4 more packets for SCTP handshake)"); + + // Verify the exchange happened + assert!( + client_timing.sent_data.is_some(), + "Client should have sent data" + ); + assert!( + client_timing.received_data.is_some(), + "Client should have received reply" + ); + assert!( + server_timing.received_data.is_some(), + "Server should have received data" + ); + assert!( + server_timing.sent_data.is_some(), + "Server should have sent reply" + ); + + Ok(()) +} + +/// Initialize an Rtc instance configured for client or server role. +/// +/// Returns the Rtc instance and the local ICE credentials/DTLS fingerprint for exchange. +fn init_rtc(is_client: bool, local_addr: SocketAddr) -> Result<(Rtc, IceCreds, String), RtcError> { + let ice_creds = IceCreds::new(); + + let mut rtc_config = RtcConfig::new().set_local_ice_credentials(ice_creds.clone()); + if !is_client { + rtc_config = rtc_config.set_ice_lite(true); + } + let mut rtc = rtc_config.build(); + + // Get DTLS fingerprint + let fingerprint = rtc.direct_api().local_dtls_fingerprint().to_string(); + + // Add local candidate + let local_candidate = Candidate::host(local_addr, "udp")?; + rtc.add_local_candidate(local_candidate); + + Ok((rtc, ice_creds, fingerprint)) +} + +/// Configure the Rtc instance with remote credentials and start DTLS/SCTP. +fn configure_rtc( + rtc: &mut Rtc, + is_client: bool, + remote_addr: SocketAddr, + remote_ice_ufrag: String, + remote_ice_pwd: String, + remote_fingerprint: String, + remote_sctp_init: Option>, +) -> Result<(), RtcError> { + // Add remote candidate + let remote_candidate = Candidate::host(remote_addr, "udp")?; + rtc.add_remote_candidate(remote_candidate); + + { + let mut direct_api = rtc.direct_api(); + + // Set ICE parameters + // Client: not ice-lite, IS controlling + // Server: ice-lite, NOT controlling + direct_api.set_ice_lite(!is_client); + direct_api.set_ice_controlling(is_client); + + // Set remote ICE credentials + direct_api.set_remote_ice_credentials(IceCreds { + ufrag: remote_ice_ufrag, + pass: remote_ice_pwd, + }); + + // Set remote DTLS fingerprint + let fingerprint: Fingerprint = remote_fingerprint + .parse() + .expect("Failed to parse remote fingerprint"); + direct_api.set_remote_fingerprint(fingerprint); + + // Start DTLS - client IS the DTLS client, server is NOT + direct_api.start_dtls(is_client)?; + + // Set remote SCTP INIT chunk for out-of-band establishment (skips SCTP handshake) + if let Some(init) = remote_sctp_init { + direct_api.sctp_config().set_remote_chunk_init(init); + } + + // Start SCTP - with remote_chunk_init set, this skips the 4-way handshake + direct_api.start_sctp(is_client); + + // Create pre-negotiated data channel + direct_api.create_data_channel(ChannelConfig { + label: "test-channel".into(), + negotiated: Some(DATA_CHANNEL_ID), + ordered: true, + reliability: Reliability::Reliable, + protocol: "".into(), + }); + } + + // Initialize with a timeout + rtc.handle_input(Input::Timeout(Instant::now()))?; + + Ok(()) +} + +/// Messages exchanged between client and server threads. +#[derive(Debug)] +enum Message { + /// ICE, DTLS, and SCTP credentials exchange (out-of-band signaling) + Credentials { + ice_ufrag: String, + ice_pwd: String, + dtls_fingerprint: String, + /// SCTP INIT chunk for out-of-band establishment + sctp_init: Vec, + }, + /// RTP/DTLS/SCTP packet + Packet { + proto: Protocol, + source: SocketAddr, + destination: SocketAddr, + contents: Vec, + }, + /// Signal to exit (sent by client to server) + Exit, +} + +/// Direction of a captured packet +#[derive(Debug, Clone, Copy)] +enum PacketDirection { + Incoming, + Outgoing, +} + +/// A captured packet with metadata for pcap writing +#[derive(Debug)] +struct CapturedPacket { + timestamp: Duration, + direction: PacketDirection, + source: SocketAddr, + destination: SocketAddr, + data: Vec, +} + +/// Write captured packets to a pcap file with proper IP/UDP headers +fn write_pcap_file(filename: &str, packets: &[CapturedPacket]) -> std::io::Result<()> { + let file = File::create(filename)?; + + // Use RAW data link type for raw IP packets (no Ethernet header) + let header = PcapHeader { + datalink: DataLink::RAW, + ..Default::default() + }; + + let mut writer = PcapWriter::with_header(file, header) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + + for packet in packets { + if let (SocketAddr::V4(src), SocketAddr::V4(dst)) = (packet.source, packet.destination) { + let udp_len = 8 + packet.data.len() as u16; + let ip_len = 20 + udp_len; + + let mut packet_data = Vec::with_capacity(ip_len as usize); + + // IPv4 header (20 bytes) + packet_data.push(0x45); // Version (4) + IHL (5) + packet_data.push(0x00); // DSCP + ECN + packet_data.extend_from_slice(&ip_len.to_be_bytes()); // Total length + packet_data.extend_from_slice(&[0x00, 0x00]); // Identification + packet_data.extend_from_slice(&[0x40, 0x00]); // Flags (Don't Fragment) + Fragment offset + packet_data.push(64); // TTL + packet_data.push(17); // Protocol: UDP + packet_data.extend_from_slice(&[0x00, 0x00]); // Header checksum (0 = disabled) + packet_data.extend_from_slice(&src.ip().octets()); // Source IP + packet_data.extend_from_slice(&dst.ip().octets()); // Destination IP + + // UDP header (8 bytes) + packet_data.extend_from_slice(&src.port().to_be_bytes()); // Source port + packet_data.extend_from_slice(&dst.port().to_be_bytes()); // Destination port + packet_data.extend_from_slice(&udp_len.to_be_bytes()); // UDP length + packet_data.extend_from_slice(&[0x00, 0x00]); // UDP checksum (0 = disabled) + + // UDP payload (the actual STUN/DTLS/SCTP data) + packet_data.extend_from_slice(&packet.data); + + let pcap_packet = + PcapPacket::new(packet.timestamp, packet_data.len() as u32, &packet_data); + + writer + .write_packet(&pcap_packet) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + } + } + + Ok(()) +} + +/// Timing report for major events +#[derive(Debug, Default)] +struct TimingReport { + start: Option, + sent_offer: Option, + got_offer: Option, + sent_answer: Option, + got_answer: Option, + ice_checking: Option, + ice_completed: Option, + channel_open: Option, + sent_data: Option, + received_data: Option, +} + +impl TimingReport { + fn new() -> Self { + Self { + start: Some(Instant::now()), + ..Default::default() + } + } + + fn print(&self, name: &str) { + let start = self.start.unwrap(); + println!("\n=== {} Timing Report ===", name); + if let Some(t) = self.sent_offer { + println!( + " Sent offer: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.got_offer { + println!( + " Got offer: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.sent_answer { + println!( + " Sent answer: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.got_answer { + println!( + " Got answer: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.ice_checking { + println!( + " ICE Checking: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.ice_completed { + println!( + " ICE Completed: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.channel_open { + println!( + " Channel Open: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.sent_data { + println!( + " Sent data: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + if let Some(t) = self.received_data { + println!( + " Received data: {:>8.3}ms", + (t - start).as_secs_f64() * 1000.0 + ); + } + } +} + +/// State for managing message exchange +#[derive(Debug, PartialEq)] +enum DataExchangeState { + WaitingForChannelOpen, + ChannelOpen, + SentMessage, + Complete, +} + +/// Run the Rtc event loop with message exchange capability +fn run_rtc_loop_with_exchange( + rtc: &mut Rtc, + span: &Span, + incoming: &Receiver, + outgoing: &Sender, + timing: &mut TimingReport, + is_client: bool, + packets_sent: &std::sync::atomic::AtomicUsize, + captured_packets: &mut Vec, + pcap_start: Instant, +) -> Result<(), RtcError> { + use std::sync::atomic::Ordering; + let mut state = DataExchangeState::WaitingForChannelOpen; + let mut channel_id: Option = None; + let role = if is_client { "CLIENT" } else { "SERVER" }; + let mut connected = false; + let mut channel_open = false; + let mut handshake_complete = false; + + loop { + // Check if we're done + if state == DataExchangeState::Complete { + break; + } + + // Safety timeout - don't run forever + if timing.start.unwrap().elapsed() > Duration::from_secs(10) { + println!("[{}] Overall timeout reached", role); + break; + } + + // Poll all outputs until we get a timeout + let timeout = loop { + match span.in_scope(|| rtc.poll_output())? { + Output::Timeout(t) => break t, + Output::Transmit(t) => { + // Only count handshake packets + if !handshake_complete { + packets_sent.fetch_add(1, Ordering::SeqCst); + } + + // Capture outgoing packet (only during handshake) + if !handshake_complete { + captured_packets.push(CapturedPacket { + timestamp: pcap_start.elapsed(), + direction: PacketDirection::Outgoing, + source: t.source, + destination: t.destination, + data: t.contents.to_vec(), + }); + } + + // Send packet to other peer + let _ = outgoing.send(Message::Packet { + proto: t.proto, + source: t.source, + destination: t.destination, + contents: t.contents.to_vec(), + }); + } + Output::Event(e) => { + // Track connected and channel open events + match &e { + Event::Connected => { + connected = true; + } + Event::ChannelOpen(_, _) => { + channel_open = true; + } + _ => {} + } + // Update handshake_complete immediately when both flags are set + // This prevents capturing data channel packets + if connected && channel_open { + handshake_complete = true; + } + handle_event( + rtc, + &e, + timing, + is_client, + &mut state, + &mut channel_id, + outgoing, + ); + if state == DataExchangeState::Complete { + return Ok(()); + } + } + } + }; + + // Calculate wait duration - this is when we NEED to wake up + let now = Instant::now(); + let wait = timeout.saturating_duration_since(now); + println!("[{}] poll_output returned timeout in {:?}", role, wait); + + // Wait for incoming message or timeout + match incoming.recv_timeout(wait) { + Ok(Message::Packet { + proto, + source, + destination, + contents, + }) => { + println!("[{}] Received packet ({} bytes)", role, contents.len()); + + // Capture incoming packet (only during handshake) + if !handshake_complete { + captured_packets.push(CapturedPacket { + timestamp: pcap_start.elapsed(), + direction: PacketDirection::Incoming, + source, + destination, + data: contents.clone(), + }); + } + + let receive = Receive { + proto, + source, + destination, + contents: contents.as_slice().try_into()?, + }; + span.in_scope(|| rtc.handle_input(Input::Receive(Instant::now(), receive)))?; + } + Ok(Message::Exit) => { + println!("[{}] Received Exit signal", role); + state = DataExchangeState::Complete; + } + Ok(_) => { + unreachable!("Unexpected message type"); + } + Err(mpsc::RecvTimeoutError::Timeout) => { + println!("[{}] Timeout fired, calling handle_input(Timeout)", role); + span.in_scope(|| rtc.handle_input(Input::Timeout(Instant::now())))?; + } + Err(mpsc::RecvTimeoutError::Disconnected) => { + println!("[{}] Channel disconnected", role); + break; + } + } + } + + Ok(()) +} + +fn handle_event( + rtc: &mut Rtc, + event: &Event, + timing: &mut TimingReport, + is_client: bool, + state: &mut DataExchangeState, + channel_id: &mut Option, + outgoing: &Sender, +) { + match event { + Event::IceConnectionStateChange(ice_state) => match ice_state { + IceConnectionState::Checking => { + if timing.ice_checking.is_none() { + timing.ice_checking = Some(Instant::now()); + } + } + IceConnectionState::Completed => { + timing.ice_completed = Some(Instant::now()); + } + _ => {} + }, + Event::ChannelOpen(cid, label) => { + println!( + "[{}] Channel opened: {:?} - {}", + if is_client { "CLIENT" } else { "SERVER" }, + cid, + label + ); + timing.channel_open = Some(Instant::now()); + *channel_id = Some(*cid); + *state = DataExchangeState::ChannelOpen; + + // Client sends first message + if is_client { + if let Some(mut chan) = rtc.channel(*cid) { + chan.write(true, b"sixseven").expect("Failed to write"); + println!("[CLIENT] Sent 'sixseven'"); + timing.sent_data = Some(Instant::now()); + *state = DataExchangeState::SentMessage; + } + } + } + Event::ChannelData(data) => { + let msg = String::from_utf8_lossy(&data.data); + println!( + "[{}] Received data: '{}'", + if is_client { "CLIENT" } else { "SERVER" }, + msg + ); + if is_client { + // Client expects "sevenofnine" reply + if msg == "sevenofnine" { + println!("[CLIENT] Got reply 'sevenofnine' - sending Exit and completing"); + timing.received_data = Some(Instant::now()); + // Send Exit signal to server + let _ = outgoing.send(Message::Exit); + *state = DataExchangeState::Complete; + } + } else { + // Server receives "sixseven" and replies + if msg == "sixseven" { + timing.received_data = Some(Instant::now()); + // Use channel id from the data event (works for pre-negotiated channels) + let cid = data.id; + if let Some(mut chan) = rtc.channel(cid) { + chan.write(true, b"sevenofnine").expect("Failed to write"); + println!("[SERVER] Sent reply 'sevenofnine'"); + timing.sent_data = Some(Instant::now()); + *state = DataExchangeState::SentMessage; + } + } + } + } + _ => {} + } +} diff --git a/tests/handshake-direct.rs b/tests/handshake-direct.rs index 6a8f211a5..d4422c305 100644 --- a/tests/handshake-direct.rs +++ b/tests/handshake-direct.rs @@ -29,6 +29,14 @@ pub fn handshake_direct_api_two_threads() -> Result<(), RtcError> { // server -> client let (server_tx, client_rx) = mpsc::channel::(); + // Counters for packets exchanged (shared via atomic) + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + let client_packets_sent = Arc::new(AtomicUsize::new(0)); + let server_packets_sent = Arc::new(AtomicUsize::new(0)); + let client_packets_sent_clone = client_packets_sent.clone(); + let server_packets_sent_clone = server_packets_sent.clone(); + let client_addr: SocketAddr = (Ipv4Addr::new(192, 168, 1, 1), 5000).into(); let server_addr: SocketAddr = (Ipv4Addr::new(192, 168, 1, 2), 5001).into(); @@ -77,7 +85,15 @@ pub fn handshake_direct_api_two_threads() -> Result<(), RtcError> { timing.sent_answer = Some(Instant::now()); // Run the event loop with message exchange - run_rtc_loop_with_exchange(&mut rtc, &span, &server_rx, &server_tx, &mut timing, false)?; + run_rtc_loop_with_exchange( + &mut rtc, + &span, + &server_rx, + &server_tx, + &mut timing, + false, + &server_packets_sent_clone, + )?; Ok(timing) }); @@ -125,7 +141,15 @@ pub fn handshake_direct_api_two_threads() -> Result<(), RtcError> { timing.got_answer = Some(Instant::now()); // Run the event loop with message exchange - run_rtc_loop_with_exchange(&mut rtc, &span, &client_rx, &client_tx, &mut timing, true)?; + run_rtc_loop_with_exchange( + &mut rtc, + &span, + &client_rx, + &client_tx, + &mut timing, + true, + &client_packets_sent_clone, + )?; Ok(timing) }); @@ -151,6 +175,14 @@ pub fn handshake_direct_api_two_threads() -> Result<(), RtcError> { total_time.as_secs_f64() * 1000.0 ); + // Print packet counts to show handshake overhead + let client_sent = client_packets_sent.load(Ordering::SeqCst); + let server_sent = server_packets_sent.load(Ordering::SeqCst); + println!("\n=== Packet Counts (standard SCTP handshake) ==="); + println!(" Client packets sent: {}", client_sent); + println!(" Server packets sent: {}", server_sent); + println!(" Total packets: {}", client_sent + server_sent); + // Verify the exchange happened assert!( client_timing.sent_data.is_some(), @@ -370,7 +402,9 @@ fn run_rtc_loop_with_exchange( outgoing: &Sender, timing: &mut TimingReport, is_client: bool, + packets_sent: &std::sync::atomic::AtomicUsize, ) -> Result<(), RtcError> { + use std::sync::atomic::Ordering; let mut state = DataExchangeState::WaitingForChannelOpen; let mut channel_id: Option = None; let role = if is_client { "CLIENT" } else { "SERVER" }; @@ -392,6 +426,8 @@ fn run_rtc_loop_with_exchange( match span.in_scope(|| rtc.poll_output())? { Output::Timeout(t) => break t, Output::Transmit(t) => { + // Count packets sent + packets_sent.fetch_add(1, Ordering::SeqCst); // Send packet to other peer let _ = outgoing.send(Message::Packet { proto: t.proto,