From 4854fb8c7bb8ea938c8847477433a8e0afac8c6a Mon Sep 17 00:00:00 2001 From: Evgeny Safronov Date: Tue, 7 Apr 2026 21:51:02 +0300 Subject: [PATCH] fix: supernet_for uses mask intersection instead of common prefix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old algorithm only compared base addresses to find a common prefix, ignoring input masks entirely. This produced incorrect results when inputs had different prefix lengths (e.g. /24 and /16 with the same base address returned /24 instead of /16). New algorithm — intersects masks and clears bits where addresses disagree, yielding the tightest valid supernet. Unfortunately, this fix resulted in many tests broken. Several existing tests expected the old (wider) supernet values. These were not catching the bug — they simply matched the previous lossy behavior and have been updated to reflect the correct tighter results. Also fixes binary_split to fold supernet_for over all window elements instead of just the endpoints, since the tighter supernet no longer guarantees contiguous address coverage. --- src/net.rs | 171 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 130 insertions(+), 41 deletions(-) diff --git a/src/net.rs b/src/net.rs index dafb319..e7fc0ce 100644 --- a/src/net.rs +++ b/src/net.rs @@ -960,24 +960,13 @@ impl Ipv4Network { /// ); /// ``` pub fn supernet_for(&self, nets: &[Ipv4Network]) -> Ipv4Network { - let (addr, mask) = self.to_bits(); - - let mut min = addr; - let mut max = addr; + let (addr, mut mask) = self.to_bits(); for net in nets { - let (addr, ..) = net.to_bits(); - if addr < min { - min = addr; - } else if addr > max { - max = addr; - } + let (a, m) = net.to_bits(); + mask &= m & !(addr ^ a); } - let common_addr = min ^ max; - let common_addr_len = common_addr.leading_zeros(); - let mask = mask & !(u32::MAX.checked_shr(common_addr_len).unwrap_or_default()); - Self::new(addr.into(), mask.into()) } @@ -2034,16 +2023,16 @@ impl Ipv6Network { /// /// use netip::Ipv6Network; /// - /// // 2013:db8:1::1/64 + /// // 2013:db8:1::/48 /// let net0 = Ipv6Network::new( - /// Ipv6Addr::new(0x2013, 0xdb8, 0x1, 0, 0, 0, 0, 0x1), - /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0), + /// Ipv6Addr::new(0x2013, 0xdb8, 0x1, 0, 0, 0, 0, 0), + /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0), /// ); /// - /// // 2013:db8:2::1/64 + /// // 2013:db8:2::/48 /// let net1 = Ipv6Network::new( - /// Ipv6Addr::new(0x2013, 0xdb8, 0x2, 0, 0, 0, 0, 0x1), - /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0), + /// Ipv6Addr::new(0x2013, 0xdb8, 0x2, 0, 0, 0, 0, 0), + /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0), /// ); /// /// // 2013:db8::/46 @@ -2054,24 +2043,13 @@ impl Ipv6Network { /// assert_eq!(expected, net0.supernet_for(&[net1])); /// ``` pub fn supernet_for(&self, nets: &[Ipv6Network]) -> Ipv6Network { - let (addr, mask) = self.to_bits(); - - let mut min = addr; - let mut max = addr; + let (addr, mut mask) = self.to_bits(); for net in nets { - let (addr, ..) = net.to_bits(); - if addr < min { - min = addr; - } else if addr > max { - max = addr; - } + let (a, m) = net.to_bits(); + mask &= m & !(addr ^ a); } - let common_addr = min ^ max; - let common_addr_len = common_addr.leading_zeros(); - let mask = mask & !(u128::MAX.checked_shr(common_addr_len).unwrap_or_default()); - Self::new(addr.into(), mask.into()) } @@ -2752,12 +2730,23 @@ where // Window size. let size = nets.len().div_ceil(2); + // Fold supernet_for over a window: equivalent to calling supernet_for + // with all elements at once (mask intersection is associative under fold + // because cleared mask positions stay cleared regardless of addr bits). + let window_supernet = |window: &[T]| -> U { + let mut supernet = *window[0].as_ref(); + for elem in &window[1..] { + supernet = supernet.supernet_for(&[*elem.as_ref()]); + } + supernet + }; + // We start with supernet that covers all given networks. - let mut range = 0..nets.len() - 1; - let mut candidate = nets[range.start].as_ref().supernet_for(&[*nets[range.end].as_ref()]); + let mut range = 0..nets.len(); + let mut candidate = window_supernet(nets); for (idx, window) in nets.windows(size).enumerate() { - let supernet = window[0].as_ref().supernet_for(&[*window[size - 1].as_ref()]); + let supernet = window_supernet(window); if supernet.mask() > candidate.mask() { range = idx..(idx + size); @@ -3443,6 +3432,24 @@ mod test { assert_eq!(expected, nets[0].supernet_for(&nets[1..]).to_contiguous()); } + #[test] + fn ipv4_supernet_for_different_prefix_lengths() { + let narrow = Ipv4Network::parse("10.0.0.0/24").unwrap(); + let wide = Ipv4Network::parse("10.0.0.0/16").unwrap(); + + assert_eq!(wide, narrow.supernet_for(&[wide])); + assert_eq!(wide, wide.supernet_for(&[narrow])); + } + + #[test] + fn ipv6_supernet_for_different_prefix_lengths() { + let narrow = Ipv6Network::parse("2001:db8::/48").unwrap(); + let wide = Ipv6Network::parse("2001:db8::/32").unwrap(); + + assert_eq!(wide, narrow.supernet_for(&[wide])); + assert_eq!(wide, wide.supernet_for(&[narrow])); + } + #[test] fn ipv4_last_addr() { // Contiguous networks (broadcast addresses). @@ -3596,18 +3603,36 @@ mod test { Ipv6Network::parse("2013:db8:2::1/64").unwrap(), ]; - let expected = Ipv6Network::parse("2013:db8::/46").unwrap(); + let expected = Ipv6Network::new( + Ipv6Addr::new(0x2013, 0xdb8, 0, 0, 0, 0, 0, 0), + Ipv6Addr::new(0xffff, 0xffff, 0xfffc, 0xffff, 0, 0, 0, 0), + ); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } #[test] fn ipv6_supernet_for_unspecified() { + // Alternating bit patterns: every bit disagrees, so mask becomes ::/0. + let nets = &[ + Ipv6Network::parse("aaaa:aaaa:aaaa:aaaa:aaaa:aaaa:aaaa:aaaa/128").unwrap(), + Ipv6Network::parse("5555:5555:5555:5555:5555:5555:5555:5555/128").unwrap(), + ]; + + let expected = Ipv6Network::parse("::/0").unwrap(); + assert_eq!(expected, nets[0].supernet_for(&nets[1..])); + } + + #[test] + fn ipv6_supernet_for_partial_agreement() { let nets = &[ Ipv6Network::parse("8001:db8:1::/34").unwrap(), Ipv6Network::parse("2013:db8:2::/32").unwrap(), ]; - let expected = Ipv6Network::parse("::/0").unwrap(); + let expected = Ipv6Network::new( + Ipv6Addr::new(0x0001, 0x0db8, 0, 0, 0, 0, 0, 0), + Ipv6Addr::new(0x5fed, 0xffff, 0, 0, 0, 0, 0, 0), + ); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } @@ -3618,7 +3643,7 @@ mod test { Ipv6Network::parse("2a02:6b8:c00::4707:0:0/ffff:ffff:ff00::ffff:ffff:0:0").unwrap(), ]; - let expected = Ipv6Network::parse("2a02:6b8:c00::4000:0:0/ffff:ffff:ff00::ffff:f000:0:0").unwrap(); + let expected = Ipv6Network::parse("2a02:6b8:c00::4002:0:0/ffff:ffff:ff00::ffff:f052:0:0").unwrap(); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } @@ -3629,7 +3654,7 @@ mod test { Ipv6Network::parse("2a02:6b8:fc00::4707:0:0/ffff:ffff:ff00::ffff:ffff:0:0").unwrap(), ]; - let expected = Ipv6Network::parse("2a02:6b8::/32").unwrap(); + let expected = Ipv6Network::parse("2a02:6b8:c00::4002:0:0/ffff:ffff:f00::ffff:f052:0:0").unwrap(); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } @@ -5059,5 +5084,69 @@ mod test { "a={}, b={}, inter={:?}", a, b, a.intersection(&b) ); } + + #[test] + fn prop_ipv4_supernet_for_contains_all( + a in arb_ipv4_network(), + b in arb_ipv4_network(), + ) { + let supernet = a.supernet_for(&[b]); + + prop_assert!(supernet.contains(&a), "supernet {supernet} must contain {a}"); + prop_assert!(supernet.contains(&b), "supernet {supernet} must contain {b}"); + } + + #[test] + fn prop_ipv6_supernet_for_contains_all( + a in arb_ipv6_network(), + b in arb_ipv6_network(), + ) { + let supernet = a.supernet_for(&[b]); + + prop_assert!(supernet.contains(&a), "supernet {supernet} must contain {a}"); + prop_assert!(supernet.contains(&b), "supernet {supernet} must contain {b}"); + } + + #[test] + fn prop_ipv4_supernet_for_diff_intersection( + a in arb_ipv4_network(), + b in arb_ipv4_network(), + ) { + let diff: Vec<_> = a.difference(&b).collect(); + let intersected = a.intersection(&b); + + if diff.is_empty() && intersected.is_none() { + return Ok(()); + } + + let mut decomposed = diff; + if let Some(c) = intersected { + decomposed.push(c); + } + + let supernet = decomposed[0].supernet_for(&decomposed[1..]); + prop_assert_eq!(supernet, a, "supernet_for(A\\B ∪ A∩B) must equal A"); + } + + #[test] + fn prop_ipv6_supernet_for_diff_intersection( + a in arb_ipv6_network(), + b in arb_ipv6_network(), + ) { + let diff: Vec<_> = a.difference(&b).collect(); + let intersected = a.intersection(&b); + + if diff.is_empty() && intersected.is_none() { + return Ok(()); + } + + let mut decomposed = diff; + if let Some(c) = intersected { + decomposed.push(c); + } + + let supernet = decomposed[0].supernet_for(&decomposed[1..]); + prop_assert_eq!(supernet, a, "supernet_for(A\\B ∪ A∩B) must equal A"); + } } }