diff --git a/iroh-relay/src/dns.rs b/iroh-relay/src/dns.rs index d70559385d..3228aed1fd 100644 --- a/iroh-relay/src/dns.rs +++ b/iroh-relay/src/dns.rs @@ -28,6 +28,9 @@ pub const N0_DNS_NODE_ORIGIN_PROD: &str = "dns.iroh.link"; /// The n0 testing DNS node origin, for testing. pub const N0_DNS_NODE_ORIGIN_STAGING: &str = "staging-dns.iroh.link"; +/// Percent of total delay to jitter. 20 means +/- 20% of delay. +const MAX_JITTER_PERCENT: u64 = 20; + /// Potential errors related to dns. #[common_fields({ backtrace: Option, @@ -503,7 +506,7 @@ async fn stagger_call< // NOTE: we add the 0 delay here to have a uniform set of futures. This is more performant than // using alternatives that allow futures of different types. for delay in std::iter::once(&0u64).chain(delays_ms) { - let delay = Duration::from_millis(*delay); + let delay = add_jitter(delay); let fut = f(); let staggered_fut = async move { time::sleep(delay).await; @@ -523,6 +526,19 @@ async fn stagger_call< Err(StaggeredError::new(errors)) } +fn add_jitter(delay: &u64) -> Duration { + // If delay is 0, return 0 immediately. + if *delay == 0 { + return Duration::ZERO; + } + + // Calculate jitter as a random value in the range of +/- MAX_JITTER_PERCENT of the delay. + let max_jitter = delay.saturating_mul(MAX_JITTER_PERCENT * 2) / 100; + let jitter = rand::random::() % max_jitter; + + Duration::from_millis(delay.saturating_sub(max_jitter / 2).saturating_add(jitter)) +} + #[cfg(test)] pub(crate) mod tests { use std::sync::atomic::AtomicUsize; @@ -548,4 +564,30 @@ pub(crate) mod tests { let result = stagger_call(f, &delays).await.unwrap(); assert_eq!(result, 5) } + + #[test] + #[traced_test] + fn jitter_test_zero() { + let jittered_delay = add_jitter(&0); + assert_eq!(jittered_delay, Duration::from_secs(0)); + } + + //Sanity checks that I did the math right + #[test] + #[traced_test] + fn jitter_test_nonzero_lower_bound() { + let delay: u64 = 300; + for _ in 0..100 { + assert!(add_jitter(&delay) >= Duration::from_millis(delay * 8 / 10)); + } + } + + #[test] + #[traced_test] + fn jitter_test_nonzero_upper_bound() { + let delay: u64 = 300; + for _ in 0..100 { + assert!(add_jitter(&delay) < Duration::from_millis(delay * 12 / 10)); + } + } }