diff --git a/Cargo.toml b/Cargo.toml index 7ea340f..f97dc32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,9 @@ [workspace] members = [ "hcpl_algebra", + "hcpl_complex", "hcpl_divide_and_conquer_dp", + "hcpl_fft", "hcpl_fwht", "hcpl_integer", "hcpl_io", diff --git a/hcpl_complex/Cargo.toml b/hcpl_complex/Cargo.toml new file mode 100644 index 0000000..c1d6d44 --- /dev/null +++ b/hcpl_complex/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "hcpl_complex" +version = "0.1.0" +edition = "2021" +repository = "https://github.com/THE-nio/hcpl" +license = "MIT" + +[dependencies] +hcpl_algebra = { version = "0.1.0", path = "../hcpl_algebra" } +hcpl_number_theory = { path = "../hcpl_number_theory" } diff --git a/hcpl_complex/src/lib.rs b/hcpl_complex/src/lib.rs new file mode 100644 index 0000000..ff395ec --- /dev/null +++ b/hcpl_complex/src/lib.rs @@ -0,0 +1,256 @@ +#[derive(Clone, Copy, Debug, Default)] +pub struct Complex { + pub re: f64, + pub im: f64, +} + +impl Complex { + pub const ZERO: Complex = Complex { re: 0., im: 0. }; + pub const ONE: Complex = Complex { re: 1., im: 0. }; + pub const I: Complex = Complex { re: 0., im: 1. }; + + pub fn powi(mut self, mut rhs: u32) -> Self { + let mut result = Self::ONE; + + while rhs > 0 { + if (rhs & 1) != 0 { + result *= self; + } + + self *= self; + rhs >>= 1; + } + + result + } + + pub fn cis(angle: f64) -> Self { + Self { + re: angle.cos(), + im: angle.sin(), + } + } + + pub fn conj(self) -> Self { + Self { + re: self.re, + im: -self.im, + } + } + + pub fn squared_norm(self) -> f64 { + self.re * self.re + self.im * self.im + } + + pub fn inv(self) -> Self { + Self { + re: self.re / self.squared_norm(), + im: -self.im / self.squared_norm(), + } + } +} + +impl std::fmt::Display for Complex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{} + i {}", self.re, self.im)) + } +} + +impl From for Complex { + fn from(value: f64) -> Self { + Self { re: value, im: 0. } + } +} + +impl From for Complex { + fn from(value: usize) -> Self { + Self { + re: value as f64, + im: 0., + } + } +} + +impl std::ops::Add for Complex { + type Output = Complex; + + fn add(self, rhs: Complex) -> Self::Output { + Self { + re: self.re + rhs.re, + im: self.im + rhs.im, + } + } +} + +impl std::ops::Add for Complex { + type Output = Complex; + + fn add(self, rhs: f64) -> Self::Output { + self + Complex::from(rhs) + } +} + +impl std::ops::Neg for Complex { + type Output = Complex; + + fn neg(self) -> Self::Output { + Self { + re: -self.re, + im: -self.im, + } + } +} + +impl std::ops::Sub for Complex { + type Output = Complex; + + fn sub(self, rhs: Complex) -> Self::Output { + self + -rhs + } +} + +impl std::ops::Sub for Complex { + type Output = Complex; + + fn sub(self, rhs: f64) -> Self::Output { + self + -rhs + } +} + +impl std::ops::Mul for Complex { + type Output = Complex; + + fn mul(self, rhs: Complex) -> Self::Output { + Self { + re: self.re * rhs.re - self.im * rhs.im, + im: self.re * rhs.im + self.im * rhs.re, + } + } +} + +impl std::ops::Mul for Complex { + type Output = Complex; + + fn mul(self, rhs: f64) -> Self::Output { + Self { + re: self.re * rhs, + im: self.im * rhs, + } + } +} + +impl std::ops::Div for Complex { + type Output = Complex; + + fn div(self, rhs: Complex) -> Self::Output { + self * rhs.inv() + } +} + +impl std::ops::Div for Complex { + type Output = Complex; + + fn div(self, rhs: f64) -> Self::Output { + Self { + re: self.re / rhs, + im: self.im / rhs, + } + } +} + +impl std::ops::AddAssign for Complex { + fn add_assign(&mut self, rhs: Complex) { + *self = *self + rhs; + } +} + +impl std::ops::AddAssign for Complex { + fn add_assign(&mut self, rhs: f64) { + *self = *self + rhs; + } +} + +impl std::ops::SubAssign for Complex { + fn sub_assign(&mut self, rhs: Complex) { + *self = *self - rhs; + } +} + +impl std::ops::SubAssign for Complex { + fn sub_assign(&mut self, rhs: f64) { + *self = *self - rhs; + } +} + +impl std::ops::MulAssign for Complex { + fn mul_assign(&mut self, rhs: Complex) { + *self = *self * rhs; + } +} + +impl std::ops::MulAssign for Complex { + fn mul_assign(&mut self, rhs: f64) { + *self = *self * rhs; + } +} + +impl std::ops::DivAssign for Complex { + fn div_assign(&mut self, rhs: Complex) { + *self = *self / rhs; + } +} + +impl std::ops::DivAssign for Complex { + fn div_assign(&mut self, rhs: f64) { + *self = *self / rhs; + } +} + +impl std::iter::Sum for Complex { + fn sum>(iter: I) -> Self { + let mut result = Complex::ZERO; + + for item in iter { + result += item; + } + + result + } +} + +impl std::iter::Product for Complex { + fn product>(iter: I) -> Self { + let mut result = Complex::ONE; + + for item in iter { + result *= item; + } + + result + } +} + +impl hcpl_algebra::monoid::AdditiveIdentity for Complex { + const VALUE: Self = Complex::ZERO; +} +impl hcpl_algebra::monoid::MultiplicativeIdentity for Complex { + const VALUE: Self = Complex::ONE; +} + +impl hcpl_number_theory::roots::TryNthRootOfUnity for Complex { + type Error = std::convert::Infallible; + + fn try_nth_root_of_unity(n: usize) -> Result + where + Self: Sized, + { + Ok(Self::cis(2. * std::f64::consts::PI / n as f64)) + } + + fn try_nth_root_of_unity_inv(n: usize) -> Result + where + Self: Sized, + { + Ok(Self::cis(-2. * std::f64::consts::PI / n as f64)) + } +} diff --git a/hcpl_fft/Cargo.toml b/hcpl_fft/Cargo.toml new file mode 100644 index 0000000..d9230f9 --- /dev/null +++ b/hcpl_fft/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "hcpl_fft" +version = "0.1.0" +edition = "2021" +repository = "https://github.com/THE-nio/hcpl" +license = "MIT" + +[dependencies] +hcpl_algebra = { path = "../hcpl_algebra" } +hcpl_number_theory = { path = "../hcpl_number_theory" } +hcpl_complex = { path = "../hcpl_complex", optional = true } + +[features] +real = [ "dep:hcpl_complex" ] diff --git a/hcpl_fft/src/lib.rs b/hcpl_fft/src/lib.rs new file mode 100644 index 0000000..2443687 --- /dev/null +++ b/hcpl_fft/src/lib.rs @@ -0,0 +1,98 @@ +#[cfg(feature = "real")] +pub mod real; + +mod stack_stack; + +use hcpl_algebra::monoid::MultiplicativeIdentity; +use hcpl_number_theory::roots::TryNthRootOfUnity; +use std::fmt::Debug; + +/// Performs the bit-reversal permutation on `vec` +pub fn bit_reversal(vec: &mut [T]) { + let n = vec.len(); + + let mut j = 0; + for i in 1..n { + let mut bit = n >> 1; + while (j & bit) != 0 { + j ^= bit; + bit >>= 1; + } + j ^= bit; + + if i < j { + vec.swap(i, j); + } + } +} + +fn get_roots(mut n: usize, inv: bool) -> stack_stack::StackStack +where + T: TryNthRootOfUnity, + ::Error: Debug, +{ + let mut roots = stack_stack::StackStack::new(); + + let mut last: T = if inv { + TryNthRootOfUnity::try_nth_root_of_unity_inv(n) + } else { + TryNthRootOfUnity::try_nth_root_of_unity(n) + } + .unwrap(); + while n >= 2 { + roots.push(last); + n /= 2; + last = last * last; + } + + roots +} + +/// Performs the in-place Fast Fourier Transform on the slice `vec`, whose lenght must be a power of two. DOES NOT +/// PERFORM NORMALISATION +pub fn fft(vec: &mut [T], inv: bool) +where + T: TryNthRootOfUnity, + ::Error: Debug, +{ + bit_reversal(vec); + + let n = vec.len(); + + let mut roots = get_roots(n, inv); + let mut width = 2; + while width <= n { + let w_d = roots.pop().unwrap(); + + for i in (0..n).step_by(width) { + let mut w = ::VALUE; + for j in 0..width / 2 { + let l = i + j; + let r = i + j + width / 2; + + (vec[l], vec[r]) = (vec[l] + w * vec[r], vec[l] - w * vec[r]); + w = w * w_d; + } + } + + width *= 2; + } +} + +/// Performs the in-place discrete convolution of the two vectors `a` and `b`. DOES NOT PERFORM +/// NORMALISATION +/// +/// `a` will contain the convolution of `a` and `b`, and `b` will contain the discrete +/// Fourier transform of `b`. +pub fn convolution(a: &mut [T], b: &mut [T]) +where + T: TryNthRootOfUnity, + ::Error: Debug, +{ + fft(a, false); + fft(b, false); + for (x, y) in a.iter_mut().zip(b.iter().copied()) { + *x = *x * y; + } + fft(a, true); +} diff --git a/hcpl_fft/src/real.rs b/hcpl_fft/src/real.rs new file mode 100644 index 0000000..6bda867 --- /dev/null +++ b/hcpl_fft/src/real.rs @@ -0,0 +1,113 @@ +//! FFTs of real-valued data +//! +//! Functions in this module should act the same as functions in the crate root but may assume that the input or output is real-valued. +//! +//! PLEASE DON'T RELY ON THIS AT THE MOMENT AS IT LACKS TESTING + +use std::f64::consts::PI; + +use hcpl_complex::Complex; + +use crate::fft; + +/// Performs both recursive calls in the Cooley-Tukey algorithm with only one call to `fft`. `vec` must be real-valued. +pub fn double_real_ffft(vec: &mut [Complex]) { + let n = vec.len() / 2; + + for i in 0..n { + vec[i] = Complex { + re: vec[2 * i].re, + im: vec[2 * i + 1].re, + }; + } + + fft(&mut vec[..n], false); + + vec.copy_within(..n, n); + vec[n + 1..].reverse(); + + // To see why this works, consider the case where all imaginary components are zero. The FFT famously exhibits conjugate symmetry + // on real data, meaning that vec[-i] = vec[i]*. Therefore `z_c` (= vec[-i]* = vec[i]** = vec[i]) is equal to `z`, and + // `(z + z_c) / 2 == z` and `-i * (z - z_c) / 2 == 0`. These `z` values are simply the DHT of the input array, as expected. + // + // The same argument works if all real components are zero, except the output will have been multiplied by `i`. Using linearity one + // can show that vec[..n] will contain the DHT of the even-indexed values in the input and vec[n..] will contain the DHT of the + // odd-indexed values in the input. + for i in 0..n { + let z = vec[i]; + let z_c = vec[i + n].conj(); + + vec[i] = (z + z_c) / 2.; + vec[i + n] = -Complex::I * (z - z_c) / 2.; + } +} + +/// Calculates the FFT of a sequence of real numbers +pub fn real_ffft(vec: &mut [Complex]) { + let n = vec.len(); + + double_real_ffft(vec); + + let mut w = Complex::ONE; + let w_d = Complex::cis(2. * PI / n as f64); + for i in 0..n / 2 { + (vec[i], vec[i + n / 2]) = (vec[i] + w * vec[i + n / 2], vec[i] - w * vec[i + n / 2]); + w *= w_d; + } +} + +// These functions were written by manually inverting the functions above. + +/// Inverse of double_real_fft, up to normalisation +pub fn double_real_ifft(vec: &mut [Complex]) { + // a, b = (x + y)/2, (x - y) / 2i + // a + bi = x + // a - bi = y + + let n = vec.len() / 2; + + for i in 0..n { + vec[i] = vec[i] + Complex::I * vec[i + n]; + } + + fft(&mut vec[..n], true); + + for i in (0..n).rev() { + (vec[2 * i], vec[2 * i + 1]) = (vec[i].re.into(), vec[i].im.into()); + } +} + +/// Calculates the IFFT of a sequence which is the DFT of a sequence of real numbers. Does NOT perform normalisation. +pub fn real_ifft(vec: &mut [Complex]) { + // a, b = x + w y, x - w y + // a + b = x + w y + x - w y = 2 x + // a - b = x + w y - x + w y = 2 w y + // (a + b) / 2 = x + // (a - b) / 2w = y + + let n = vec.len(); + + let mut w = Complex::ONE; + let w_d = Complex::cis(-2. * PI / n as f64); + + for i in 0..n/2 { + (vec[i], vec[i + n / 2]) = ((vec[i] + vec[i + n / 2]) / 2., w * (vec[i] - vec[i + n / 2]) / 2.); + w *= w_d; + } + + double_real_ifft(vec); + + // apply missing denormalisation + for i in 0..n { + vec[i] *= 2.; + } +} + +/// Like real_ffft and real_ifft but using an extra parameter to choose direction +pub fn real_fft(vec: &mut [Complex], inverse: bool) { + if inverse { + real_ifft(vec); + } else { + real_ffft(vec); + } +} diff --git a/hcpl_fft/src/stack_stack.rs b/hcpl_fft/src/stack_stack.rs new file mode 100644 index 0000000..a9ae7e3 --- /dev/null +++ b/hcpl_fft/src/stack_stack.rs @@ -0,0 +1,42 @@ +use std::mem::MaybeUninit; + +const BUF_SIZE: usize = 50; + +/// A stack-allocated stack +pub(crate) struct StackStack { + buffer: [MaybeUninit; BUF_SIZE], + len: usize, +} + +impl StackStack { + pub(crate) fn new() -> Self { + Self { + buffer: unsafe { + MaybeUninit::<[MaybeUninit; BUF_SIZE]>::uninit().assume_init() + }, + len: 0, + } + } + + pub(crate) fn push(&mut self, val: T) -> Option<()> { + if self.len == BUF_SIZE { + return None; + } + + self.buffer[self.len].write(val); + self.len += 1; + + Some(()) + } + + pub(crate) fn pop(&mut self) -> Option { + if self.len == 0 { + return None; + } + + self.len -= 1; + + // SAFETY: this element will have been written to exactly once, and won't be read again + Some(unsafe { self.buffer[self.len].assume_init_read() }) + } +} diff --git a/hcpl_modnum/Cargo.toml b/hcpl_modnum/Cargo.toml index 60e2682..ab197cf 100644 --- a/hcpl_modnum/Cargo.toml +++ b/hcpl_modnum/Cargo.toml @@ -7,3 +7,4 @@ license = "MIT" [dependencies] hcpl_algebra = { path = "../hcpl_algebra" } +hcpl_number_theory = { path = "../hcpl_number_theory" } diff --git a/hcpl_modnum/src/lib.rs b/hcpl_modnum/src/lib.rs index f885922..4ca797f 100644 --- a/hcpl_modnum/src/lib.rs +++ b/hcpl_modnum/src/lib.rs @@ -135,3 +135,46 @@ impl hcpl_algebra::monoid::AdditiveIdentity for Modnum { impl hcpl_algebra::monoid::MultiplicativeIdentity for Modnum { const VALUE: Self = Self::new(1); } + +impl hcpl_number_theory::roots::TryNthRootOfUnity for Modnum { + type Error = &'static str; + + fn try_nth_root_of_unity(n: usize) -> Result { + if n == 0 { + Err("n must be positive") + } else if n == 1 { + Ok(Self::new(1)) + } else if (N - 1) as usize % n != 0 { + Err("n is not a divisor of MOD - 1") + } else { + for base in 2.. { + #[cfg(debug_assertions)] + eprintln!("trying base {}", base); + let attempt = Self::new(base).pow((N - 1) as usize / n); + let mut d = 0; + if loop { + d += 1; + if d * d > n { + break true; + } + if n % d != 0 { + continue; + } + if attempt.pow(d as usize) == Self::new(1) { + break false; + } + if d != 1 && d * d != n && attempt.pow((n / d) as usize) == Self::new(1) { + break false; + } + } { + return Ok(attempt); + } + } + unreachable!() + } + } + + fn try_nth_root_of_unity_inv(n: usize) -> Result { + Ok(Self::try_nth_root_of_unity(n)?.inv()) + } +} diff --git a/hcpl_number_theory/Cargo.toml b/hcpl_number_theory/Cargo.toml index 86677b1..e38e715 100644 --- a/hcpl_number_theory/Cargo.toml +++ b/hcpl_number_theory/Cargo.toml @@ -6,4 +6,5 @@ repository = "https://github.com/THE-nio/hcpl" license = "MIT" [dependencies] +hcpl_algebra = { path = "../hcpl_algebra" } hcpl_integer = { path = "../hcpl_integer" } diff --git a/hcpl_number_theory/src/lib.rs b/hcpl_number_theory/src/lib.rs index 7af74c6..9cad09a 100644 --- a/hcpl_number_theory/src/lib.rs +++ b/hcpl_number_theory/src/lib.rs @@ -1,4 +1,5 @@ mod gcd; pub mod prime; +pub mod roots; pub use gcd::gcd; diff --git a/hcpl_number_theory/src/roots.rs b/hcpl_number_theory/src/roots.rs new file mode 100644 index 0000000..62eb189 --- /dev/null +++ b/hcpl_number_theory/src/roots.rs @@ -0,0 +1,27 @@ +use hcpl_algebra::{monoid::MultiplicativeIdentity, Ring}; + +pub trait TryNthRootOfUnity: Sized + Copy + Ring { + type Error; + + /// Returns a principal nth root of unity. Must be deterministic. + fn try_nth_root_of_unity(n: usize) -> Result; + + /// Returns a value s.t. try_nth_root_of_unity(n) * try_nth_root_of_unity_inv(n) = 1 + fn try_nth_root_of_unity_inv(n: usize) -> Result { + let mut root = Self::try_nth_root_of_unity(n)?; + + let mut result = ::VALUE; + let mut exp = n - 1; + + while exp > 0 { + if exp % 2 == 1 { + result = result * root; + } + + root = root * root; + exp >>= 1; + } + + Ok(root) + } +}