diff --git a/src/net.rs b/src/net.rs index dafb319..e7fc0ce 100644 --- a/src/net.rs +++ b/src/net.rs @@ -960,24 +960,13 @@ impl Ipv4Network { /// ); /// ``` pub fn supernet_for(&self, nets: &[Ipv4Network]) -> Ipv4Network { - let (addr, mask) = self.to_bits(); - - let mut min = addr; - let mut max = addr; + let (addr, mut mask) = self.to_bits(); for net in nets { - let (addr, ..) = net.to_bits(); - if addr < min { - min = addr; - } else if addr > max { - max = addr; - } + let (a, m) = net.to_bits(); + mask &= m & !(addr ^ a); } - let common_addr = min ^ max; - let common_addr_len = common_addr.leading_zeros(); - let mask = mask & !(u32::MAX.checked_shr(common_addr_len).unwrap_or_default()); - Self::new(addr.into(), mask.into()) } @@ -2034,16 +2023,16 @@ impl Ipv6Network { /// /// use netip::Ipv6Network; /// - /// // 2013:db8:1::1/64 + /// // 2013:db8:1::/48 /// let net0 = Ipv6Network::new( - /// Ipv6Addr::new(0x2013, 0xdb8, 0x1, 0, 0, 0, 0, 0x1), - /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0), + /// Ipv6Addr::new(0x2013, 0xdb8, 0x1, 0, 0, 0, 0, 0), + /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0), /// ); /// - /// // 2013:db8:2::1/64 + /// // 2013:db8:2::/48 /// let net1 = Ipv6Network::new( - /// Ipv6Addr::new(0x2013, 0xdb8, 0x2, 0, 0, 0, 0, 0x1), - /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0), + /// Ipv6Addr::new(0x2013, 0xdb8, 0x2, 0, 0, 0, 0, 0), + /// Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0), /// ); /// /// // 2013:db8::/46 @@ -2054,24 +2043,13 @@ impl Ipv6Network { /// assert_eq!(expected, net0.supernet_for(&[net1])); /// ``` pub fn supernet_for(&self, nets: &[Ipv6Network]) -> Ipv6Network { - let (addr, mask) = self.to_bits(); - - let mut min = addr; - let mut max = addr; + let (addr, mut mask) = self.to_bits(); for net in nets { - let (addr, ..) = net.to_bits(); - if addr < min { - min = addr; - } else if addr > max { - max = addr; - } + let (a, m) = net.to_bits(); + mask &= m & !(addr ^ a); } - let common_addr = min ^ max; - let common_addr_len = common_addr.leading_zeros(); - let mask = mask & !(u128::MAX.checked_shr(common_addr_len).unwrap_or_default()); - Self::new(addr.into(), mask.into()) } @@ -2752,12 +2730,23 @@ where // Window size. let size = nets.len().div_ceil(2); + // Fold supernet_for over a window: equivalent to calling supernet_for + // with all elements at once (mask intersection is associative under fold + // because cleared mask positions stay cleared regardless of addr bits). + let window_supernet = |window: &[T]| -> U { + let mut supernet = *window[0].as_ref(); + for elem in &window[1..] { + supernet = supernet.supernet_for(&[*elem.as_ref()]); + } + supernet + }; + // We start with supernet that covers all given networks. - let mut range = 0..nets.len() - 1; - let mut candidate = nets[range.start].as_ref().supernet_for(&[*nets[range.end].as_ref()]); + let mut range = 0..nets.len(); + let mut candidate = window_supernet(nets); for (idx, window) in nets.windows(size).enumerate() { - let supernet = window[0].as_ref().supernet_for(&[*window[size - 1].as_ref()]); + let supernet = window_supernet(window); if supernet.mask() > candidate.mask() { range = idx..(idx + size); @@ -3443,6 +3432,24 @@ mod test { assert_eq!(expected, nets[0].supernet_for(&nets[1..]).to_contiguous()); } + #[test] + fn ipv4_supernet_for_different_prefix_lengths() { + let narrow = Ipv4Network::parse("10.0.0.0/24").unwrap(); + let wide = Ipv4Network::parse("10.0.0.0/16").unwrap(); + + assert_eq!(wide, narrow.supernet_for(&[wide])); + assert_eq!(wide, wide.supernet_for(&[narrow])); + } + + #[test] + fn ipv6_supernet_for_different_prefix_lengths() { + let narrow = Ipv6Network::parse("2001:db8::/48").unwrap(); + let wide = Ipv6Network::parse("2001:db8::/32").unwrap(); + + assert_eq!(wide, narrow.supernet_for(&[wide])); + assert_eq!(wide, wide.supernet_for(&[narrow])); + } + #[test] fn ipv4_last_addr() { // Contiguous networks (broadcast addresses). @@ -3596,18 +3603,36 @@ mod test { Ipv6Network::parse("2013:db8:2::1/64").unwrap(), ]; - let expected = Ipv6Network::parse("2013:db8::/46").unwrap(); + let expected = Ipv6Network::new( + Ipv6Addr::new(0x2013, 0xdb8, 0, 0, 0, 0, 0, 0), + Ipv6Addr::new(0xffff, 0xffff, 0xfffc, 0xffff, 0, 0, 0, 0), + ); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } #[test] fn ipv6_supernet_for_unspecified() { + // Alternating bit patterns: every bit disagrees, so mask becomes ::/0. + let nets = &[ + Ipv6Network::parse("aaaa:aaaa:aaaa:aaaa:aaaa:aaaa:aaaa:aaaa/128").unwrap(), + Ipv6Network::parse("5555:5555:5555:5555:5555:5555:5555:5555/128").unwrap(), + ]; + + let expected = Ipv6Network::parse("::/0").unwrap(); + assert_eq!(expected, nets[0].supernet_for(&nets[1..])); + } + + #[test] + fn ipv6_supernet_for_partial_agreement() { let nets = &[ Ipv6Network::parse("8001:db8:1::/34").unwrap(), Ipv6Network::parse("2013:db8:2::/32").unwrap(), ]; - let expected = Ipv6Network::parse("::/0").unwrap(); + let expected = Ipv6Network::new( + Ipv6Addr::new(0x0001, 0x0db8, 0, 0, 0, 0, 0, 0), + Ipv6Addr::new(0x5fed, 0xffff, 0, 0, 0, 0, 0, 0), + ); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } @@ -3618,7 +3643,7 @@ mod test { Ipv6Network::parse("2a02:6b8:c00::4707:0:0/ffff:ffff:ff00::ffff:ffff:0:0").unwrap(), ]; - let expected = Ipv6Network::parse("2a02:6b8:c00::4000:0:0/ffff:ffff:ff00::ffff:f000:0:0").unwrap(); + let expected = Ipv6Network::parse("2a02:6b8:c00::4002:0:0/ffff:ffff:ff00::ffff:f052:0:0").unwrap(); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } @@ -3629,7 +3654,7 @@ mod test { Ipv6Network::parse("2a02:6b8:fc00::4707:0:0/ffff:ffff:ff00::ffff:ffff:0:0").unwrap(), ]; - let expected = Ipv6Network::parse("2a02:6b8::/32").unwrap(); + let expected = Ipv6Network::parse("2a02:6b8:c00::4002:0:0/ffff:ffff:f00::ffff:f052:0:0").unwrap(); assert_eq!(expected, nets[0].supernet_for(&nets[1..])); } @@ -5059,5 +5084,69 @@ mod test { "a={}, b={}, inter={:?}", a, b, a.intersection(&b) ); } + + #[test] + fn prop_ipv4_supernet_for_contains_all( + a in arb_ipv4_network(), + b in arb_ipv4_network(), + ) { + let supernet = a.supernet_for(&[b]); + + prop_assert!(supernet.contains(&a), "supernet {supernet} must contain {a}"); + prop_assert!(supernet.contains(&b), "supernet {supernet} must contain {b}"); + } + + #[test] + fn prop_ipv6_supernet_for_contains_all( + a in arb_ipv6_network(), + b in arb_ipv6_network(), + ) { + let supernet = a.supernet_for(&[b]); + + prop_assert!(supernet.contains(&a), "supernet {supernet} must contain {a}"); + prop_assert!(supernet.contains(&b), "supernet {supernet} must contain {b}"); + } + + #[test] + fn prop_ipv4_supernet_for_diff_intersection( + a in arb_ipv4_network(), + b in arb_ipv4_network(), + ) { + let diff: Vec<_> = a.difference(&b).collect(); + let intersected = a.intersection(&b); + + if diff.is_empty() && intersected.is_none() { + return Ok(()); + } + + let mut decomposed = diff; + if let Some(c) = intersected { + decomposed.push(c); + } + + let supernet = decomposed[0].supernet_for(&decomposed[1..]); + prop_assert_eq!(supernet, a, "supernet_for(A\\B ∪ A∩B) must equal A"); + } + + #[test] + fn prop_ipv6_supernet_for_diff_intersection( + a in arb_ipv6_network(), + b in arb_ipv6_network(), + ) { + let diff: Vec<_> = a.difference(&b).collect(); + let intersected = a.intersection(&b); + + if diff.is_empty() && intersected.is_none() { + return Ok(()); + } + + let mut decomposed = diff; + if let Some(c) = intersected { + decomposed.push(c); + } + + let supernet = decomposed[0].supernet_for(&decomposed[1..]); + prop_assert_eq!(supernet, a, "supernet_for(A\\B ∪ A∩B) must equal A"); + } } }