|  | 
|  | 1 | +/* SPDX-License-Identifier: MIT OR Apache-2.0 */ | 
|  | 2 | +use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256}; | 
|  | 3 | + | 
|  | 4 | +/// Trait for unsigned division of a double-wide integer | 
|  | 5 | +/// when the quotient doesn't overflow. | 
|  | 6 | +/// | 
|  | 7 | +/// This is the inverse of widening multiplication: | 
|  | 8 | +///  - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`, | 
|  | 9 | +///  - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`, | 
|  | 10 | +#[allow(dead_code)] | 
|  | 11 | +pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> { | 
|  | 12 | +    /// Computes `(self / n, self % n))` | 
|  | 13 | +    /// | 
|  | 14 | +    /// # Safety | 
|  | 15 | +    /// The caller must ensure that `self.hi() < n`, or equivalently, | 
|  | 16 | +    /// that the quotient does not overflow. | 
|  | 17 | +    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H); | 
|  | 18 | + | 
|  | 19 | +    /// Returns `Some((self / n, self % n))` when `self.hi() < n`. | 
|  | 20 | +    fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> { | 
|  | 21 | +        if self.hi() < n { | 
|  | 22 | +            Some(unsafe { self.unchecked_narrowing_div_rem(n) }) | 
|  | 23 | +        } else { | 
|  | 24 | +            None | 
|  | 25 | +        } | 
|  | 26 | +    } | 
|  | 27 | +} | 
|  | 28 | + | 
|  | 29 | +// For primitive types we can just use the standard | 
|  | 30 | +// division operators in the double-wide type. | 
|  | 31 | +macro_rules! impl_narrowing_div_primitive { | 
|  | 32 | +    ($D:ident) => { | 
|  | 33 | +        impl NarrowingDiv for $D { | 
|  | 34 | +            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) { | 
|  | 35 | +                if self.hi() >= n { | 
|  | 36 | +                    unsafe { core::hint::unreachable_unchecked() } | 
|  | 37 | +                } | 
|  | 38 | +                ((self / n.widen()).cast(), (self % n.widen()).cast()) | 
|  | 39 | +            } | 
|  | 40 | +        } | 
|  | 41 | +    }; | 
|  | 42 | +} | 
|  | 43 | + | 
|  | 44 | +// Extend division from `u2N / uN` to `u4N / u2N` | 
|  | 45 | +// This is not the most efficient algorithm, but it is | 
|  | 46 | +// relatively simple. | 
|  | 47 | +macro_rules! impl_narrowing_div_recurse { | 
|  | 48 | +    ($D:ident) => { | 
|  | 49 | +        impl NarrowingDiv for $D { | 
|  | 50 | +            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) { | 
|  | 51 | +                if self.hi() >= n { | 
|  | 52 | +                    unsafe { core::hint::unreachable_unchecked() } | 
|  | 53 | +                } | 
|  | 54 | + | 
|  | 55 | +                // Normalize the divisor by shifting the most significant one | 
|  | 56 | +                // to the leading position. `n != 0` is implied by `self.hi() < n` | 
|  | 57 | +                let lz = n.leading_zeros(); | 
|  | 58 | +                let a = self << lz; | 
|  | 59 | +                let b = n << lz; | 
|  | 60 | + | 
|  | 61 | +                let ah = a.hi(); | 
|  | 62 | +                let (a0, a1) = a.lo().lo_hi(); | 
|  | 63 | +                // SAFETY: For both calls, `b.leading_zeros() == 0` by the above shift. | 
|  | 64 | +                // SAFETY: `ah < b` follows from `self.hi() < n` | 
|  | 65 | +                let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) }; | 
|  | 66 | +                // SAFETY: `r < b` is given as the postcondition of the previous call | 
|  | 67 | +                let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) }; | 
|  | 68 | + | 
|  | 69 | +                // Undo the earlier normalization for the remainder | 
|  | 70 | +                (Self::H::from_lo_hi(q0, q1), r >> lz) | 
|  | 71 | +            } | 
|  | 72 | +        } | 
|  | 73 | +    }; | 
|  | 74 | +} | 
|  | 75 | + | 
|  | 76 | +impl_narrowing_div_primitive!(u16); | 
|  | 77 | +impl_narrowing_div_primitive!(u32); | 
|  | 78 | +impl_narrowing_div_primitive!(u64); | 
|  | 79 | +impl_narrowing_div_primitive!(u128); | 
|  | 80 | +impl_narrowing_div_recurse!(u256); | 
|  | 81 | + | 
|  | 82 | +/// Implement `u3N / u2N`-division on top of `u2N / uN`-division. | 
|  | 83 | +/// | 
|  | 84 | +/// Returns the quotient and remainder of `(a * R + a0) / n`, | 
|  | 85 | +/// where `R = (1 << U::BITS)` is the digit size. | 
|  | 86 | +/// | 
|  | 87 | +/// # Safety | 
|  | 88 | +/// Requires that `n.leading_zeros() == 0` and `a < n`. | 
|  | 89 | +unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D) | 
|  | 90 | +where | 
|  | 91 | +    U: HInt, | 
|  | 92 | +    U::D: Int + NarrowingDiv, | 
|  | 93 | +{ | 
|  | 94 | +    if n.leading_zeros() > 0 || a >= n { | 
|  | 95 | +        unsafe { core::hint::unreachable_unchecked() } | 
|  | 96 | +    } | 
|  | 97 | + | 
|  | 98 | +    // n = n1R + n0 | 
|  | 99 | +    let (n0, n1) = n.lo_hi(); | 
|  | 100 | +    // a = a2R + a1 | 
|  | 101 | +    let (a1, a2) = a.lo_hi(); | 
|  | 102 | + | 
|  | 103 | +    let mut q; | 
|  | 104 | +    let mut r; | 
|  | 105 | +    let mut wrap; | 
|  | 106 | +    // `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible | 
|  | 107 | +    if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) { | 
|  | 108 | +        q = q0; | 
|  | 109 | +        // a = qn1 + r1, where 0 <= r1 < n1 | 
|  | 110 | + | 
|  | 111 | +        // Include the remainder with the low bits: | 
|  | 112 | +        // r = a0 + r1R | 
|  | 113 | +        r = U::D::from_lo_hi(a0, r1); | 
|  | 114 | + | 
|  | 115 | +        // Subtract the contribution of the divisor low bits with the estimated quotient | 
|  | 116 | +        let d = q.widen_mul(n0); | 
|  | 117 | +        (r, wrap) = r.overflowing_sub(d); | 
|  | 118 | + | 
|  | 119 | +        // Since `q` is the quotient of dividing with a slightly smaller divisor, | 
|  | 120 | +        // it may be an overapproximation, but is never too small, and similarly, | 
|  | 121 | +        // `r` is now either the correct remainder ... | 
|  | 122 | +        if !wrap { | 
|  | 123 | +            return (q, r); | 
|  | 124 | +        } | 
|  | 125 | +        // ... or the remainder went "negative" (by as much as `d = qn0 < RR`) | 
|  | 126 | +        // and we have to adjust. | 
|  | 127 | +        q -= U::ONE; | 
|  | 128 | +    } else { | 
|  | 129 | +        debug_assert!(a2 == n1 && a1 < n0); | 
|  | 130 | +        // Otherwise, `a2 == n1`, and the estimated quotient would be | 
|  | 131 | +        // `R + (a1 % n1)`, but the correct quotient can't overflow. | 
|  | 132 | +        // We'll start from `q = R = (1 << U::BITS)`, | 
|  | 133 | +        // so `r = aR + a0 - qn = (a - n)R + a0` | 
|  | 134 | +        r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0)); | 
|  | 135 | +        // Since `a < n`, the first decrement is always needed: | 
|  | 136 | +        q = U::MAX; /* R - 1 */ | 
|  | 137 | +    } | 
|  | 138 | + | 
|  | 139 | +    (r, wrap) = r.overflowing_add(n); | 
|  | 140 | +    if wrap { | 
|  | 141 | +        return (q, r); | 
|  | 142 | +    } | 
|  | 143 | + | 
|  | 144 | +    // If the remainder still didn't wrap, we need another step. | 
|  | 145 | +    q -= U::ONE; | 
|  | 146 | +    (r, wrap) = r.overflowing_add(n); | 
|  | 147 | +    // Since `n >= RR/2`, at least one of the two `r += n` must have wrapped. | 
|  | 148 | +    debug_assert!(wrap, "estimated quotient should be off by at most two"); | 
|  | 149 | +    (q, r) | 
|  | 150 | +} | 
|  | 151 | + | 
|  | 152 | +#[cfg(test)] | 
|  | 153 | +mod test { | 
|  | 154 | +    use super::{HInt, NarrowingDiv}; | 
|  | 155 | + | 
|  | 156 | +    #[test] | 
|  | 157 | +    fn inverse_mul() { | 
|  | 158 | +        for x in 0..=u8::MAX { | 
|  | 159 | +            for y in 1..=u8::MAX { | 
|  | 160 | +                let xy = x.widen_mul(y); | 
|  | 161 | +                assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0))); | 
|  | 162 | +                assert_eq!( | 
|  | 163 | +                    (xy + (y - 1) as u16).checked_narrowing_div_rem(y), | 
|  | 164 | +                    Some((x, y - 1)) | 
|  | 165 | +                ); | 
|  | 166 | +                if y > 1 { | 
|  | 167 | +                    assert_eq!((xy + 1).checked_narrowing_div_rem(y), Some((x, 1))); | 
|  | 168 | +                    assert_eq!( | 
|  | 169 | +                        (xy + (y - 2) as u16).checked_narrowing_div_rem(y), | 
|  | 170 | +                        Some((x, y - 2)) | 
|  | 171 | +                    ); | 
|  | 172 | +                } | 
|  | 173 | +            } | 
|  | 174 | +        } | 
|  | 175 | +    } | 
|  | 176 | +} | 
0 commit comments