From b3d265b0e75fabb1ca678466b6db4b91283a0b66 Mon Sep 17 00:00:00 2001 From: jakobrs Date: Fri, 19 May 2023 13:44:05 +0200 Subject: [PATCH 1/4] Complex numbers and FFTs --- Cargo.toml | 2 + hcpl_complex/Cargo.toml | 6 ++ hcpl_complex/src/lib.rs | 190 ++++++++++++++++++++++++++++++++++++++++ hcpl_fft/Cargo.toml | 11 +++ hcpl_fft/src/lib.rs | 66 ++++++++++++++ 5 files changed, 275 insertions(+) create mode 100644 hcpl_complex/Cargo.toml create mode 100644 hcpl_complex/src/lib.rs create mode 100644 hcpl_fft/Cargo.toml create mode 100644 hcpl_fft/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 7ea340f..f97dc32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,9 @@ [workspace] members = [ "hcpl_algebra", + "hcpl_complex", "hcpl_divide_and_conquer_dp", + "hcpl_fft", "hcpl_fwht", "hcpl_integer", "hcpl_io", diff --git a/hcpl_complex/Cargo.toml b/hcpl_complex/Cargo.toml new file mode 100644 index 0000000..99c2290 --- /dev/null +++ b/hcpl_complex/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "hcpl_complex" +version = "0.1.0" +edition = "2021" +repository = "https://github.com/THE-nio/hcpl" +license = "MIT" diff --git a/hcpl_complex/src/lib.rs b/hcpl_complex/src/lib.rs new file mode 100644 index 0000000..b9af6ec --- /dev/null +++ b/hcpl_complex/src/lib.rs @@ -0,0 +1,190 @@ +#[derive(Clone, Copy, Debug, Default)] +pub struct Complex { + pub re: f64, + pub im: f64, +} + +impl Complex { + pub const ZERO: Complex = Complex { re: 0., im: 0. }; + pub const ONE: Complex = Complex { re: 1., im: 0. }; + pub const I: Complex = Complex { re: 0., im: 1. 
}; + + pub fn powi(mut self, mut rhs: u32) -> Self { + let mut result = Self::ONE; + + while rhs > 0 { + if (rhs & 1) != 0 { + result *= self; + } + + self *= self; + rhs >>= 1; + } + + result + } + + pub fn cis(angle: f64) -> Self { + Self { + re: angle.cos(), + im: angle.sin(), + } + } +} + +impl std::fmt::Display for Complex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{} + i {}", self.re, self.im)) + } +} + +impl From for Complex { + fn from(value: f64) -> Self { + Self { re: value, im: 0. } + } +} + +impl std::ops::Add for Complex { + type Output = Complex; + + fn add(self, rhs: Complex) -> Self::Output { + Self { + re: self.re + rhs.re, + im: self.im + rhs.im, + } + } +} + +impl std::ops::Add for Complex { + type Output = Complex; + + fn add(self, rhs: f64) -> Self::Output { + self + Complex::from(rhs) + } +} + +impl std::ops::Neg for Complex { + type Output = Complex; + + fn neg(self) -> Self::Output { + Self { + re: -self.re, + im: -self.im, + } + } +} + +impl std::ops::Sub for Complex { + type Output = Complex; + + fn sub(self, rhs: Complex) -> Self::Output { + self + -rhs + } +} + +impl std::ops::Sub for Complex { + type Output = Complex; + + fn sub(self, rhs: f64) -> Self::Output { + self + -rhs + } +} + +impl std::ops::Mul for Complex { + type Output = Complex; + + fn mul(self, rhs: Complex) -> Self::Output { + Self { + re: self.re * rhs.re - self.im * rhs.im, + im: self.re * rhs.im + self.im * rhs.re, + } + } +} + +impl std::ops::Mul for Complex { + type Output = Complex; + + fn mul(self, rhs: f64) -> Self::Output { + Self { + re: self.re * rhs, + im: self.im * rhs, + } + } +} + +impl std::ops::Div for Complex { + type Output = Complex; + + fn div(self, rhs: f64) -> Self::Output { + Self { + re: self.re / rhs, + im: self.im / rhs, + } + } +} + +impl std::ops::AddAssign for Complex { + fn add_assign(&mut self, rhs: Complex) { + *self = *self + rhs; + } +} + +impl std::ops::AddAssign for Complex { + 
fn add_assign(&mut self, rhs: f64) { + *self = *self + rhs; + } +} + +impl std::ops::SubAssign for Complex { + fn sub_assign(&mut self, rhs: Complex) { + *self = *self - rhs; + } +} + +impl std::ops::SubAssign for Complex { + fn sub_assign(&mut self, rhs: f64) { + *self = *self - rhs; + } +} + +impl std::ops::MulAssign for Complex { + fn mul_assign(&mut self, rhs: Complex) { + *self = *self * rhs; + } +} + +impl std::ops::MulAssign for Complex { + fn mul_assign(&mut self, rhs: f64) { + *self = *self * rhs; + } +} + +impl std::ops::DivAssign for Complex { + fn div_assign(&mut self, rhs: f64) { + *self = *self / rhs; + } +} + +impl std::iter::Sum for Complex { + fn sum>(iter: I) -> Self { + let mut result = Complex::ZERO; + + for item in iter { + result += item; + } + + result + } +} + +impl std::iter::Product for Complex { + fn product>(iter: I) -> Self { + let mut result = Complex::ONE; + + for item in iter { + result *= item; + } + + result + } +} \ No newline at end of file diff --git a/hcpl_fft/Cargo.toml b/hcpl_fft/Cargo.toml new file mode 100644 index 0000000..84965a0 --- /dev/null +++ b/hcpl_fft/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "hcpl_fft" +version = "0.1.0" +edition = "2021" +repository = "https://github.com/THE-nio/hcpl" +license = "MIT" + +[dependencies] +hcpl_algebra = { path = "../hcpl_algebra" } +# Maybe this can be generalised to cover both FFTs and NTTs +hcpl_complex = { path = "../hcpl_complex" } \ No newline at end of file diff --git a/hcpl_fft/src/lib.rs b/hcpl_fft/src/lib.rs new file mode 100644 index 0000000..43fd2d0 --- /dev/null +++ b/hcpl_fft/src/lib.rs @@ -0,0 +1,66 @@ +use hcpl_complex::Complex; +use std::f64::consts::PI; + +/// Performs the bit-reversal permutation on `a` +pub fn bit_reversal(vec: &mut [T]) { + let n = vec.len(); + + let mut j = 0; + for i in 1..n { + let mut bit = n >> 1; + while (j & bit) != 0 { + j ^= bit; + bit >>= 1; + } + j ^= bit; + + if i < j { + vec.swap(i, j); + } + } +} + +/// Performs the in-place 
Fast Fourier Transform on the slice `a`, whose lenght must be a power of two +pub fn fft(vec: &mut [Complex], inv: bool) { + bit_reversal(vec); + + let n = vec.len(); + let sign = if inv { -1. } else { 1. }; + + let mut width = 2; + while width <= n { + let w_d = Complex::cis(sign * 2. * PI / width as f64); + + for i in (0..n).step_by(width) { + let mut w = Complex::ONE; + for j in 0..width / 2 { + let l = i + j; + let r = i + j + width / 2; + + (vec[l], vec[r]) = (vec[l] + w * vec[r], vec[l] - w * vec[r]); + w *= w_d; + } + } + + width *= 2; + } + + if inv { + for i in 0..n { + vec[i] /= n as f64; + } + } +} + +/// Performs the in-place discrete convolution of the two vectors `a` and `b`. +/// +/// `a` will contain the convolution of `a` and `b`, and `b` will contain the discrete +/// Fourier transform of `b`. +pub fn convolution(a: &mut [Complex], b: &mut [Complex]) { + fft(a, false); + fft(b, false); + for (x, y) in a.iter_mut().zip(b.iter().copied()) { + *x = *x * y; + } + fft(a, true); +} From bbc0cefac9a5270217c0f69520d5b8498ed943e3 Mon Sep 17 00:00:00 2001 From: jakobrs Date: Fri, 19 May 2023 15:49:24 +0200 Subject: [PATCH 2/4] NTTs --- hcpl_complex/Cargo.toml | 4 ++ hcpl_complex/src/lib.rs | 68 ++++++++++++++++++++++- hcpl_fft/Cargo.toml | 3 +- hcpl_fft/src/lib.rs | 96 +++++++++++++++++++++++++++------ hcpl_modnum/Cargo.toml | 1 + hcpl_modnum/src/lib.rs | 43 +++++++++++++++ hcpl_number_theory/Cargo.toml | 1 + hcpl_number_theory/src/lib.rs | 1 + hcpl_number_theory/src/roots.rs | 27 ++++++++++ 9 files changed, 225 insertions(+), 19 deletions(-) create mode 100644 hcpl_number_theory/src/roots.rs diff --git a/hcpl_complex/Cargo.toml b/hcpl_complex/Cargo.toml index 99c2290..c1d6d44 100644 --- a/hcpl_complex/Cargo.toml +++ b/hcpl_complex/Cargo.toml @@ -4,3 +4,7 @@ version = "0.1.0" edition = "2021" repository = "https://github.com/THE-nio/hcpl" license = "MIT" + +[dependencies] +hcpl_algebra = { version = "0.1.0", path = "../hcpl_algebra" } +hcpl_number_theory = 
{ path = "../hcpl_number_theory" } diff --git a/hcpl_complex/src/lib.rs b/hcpl_complex/src/lib.rs index b9af6ec..ff395ec 100644 --- a/hcpl_complex/src/lib.rs +++ b/hcpl_complex/src/lib.rs @@ -30,6 +30,24 @@ impl Complex { im: angle.sin(), } } + + pub fn conj(self) -> Self { + Self { + re: self.re, + im: -self.im, + } + } + + pub fn squared_norm(self) -> f64 { + self.re * self.re + self.im * self.im + } + + pub fn inv(self) -> Self { + Self { + re: self.re / self.squared_norm(), + im: -self.im / self.squared_norm(), + } + } } impl std::fmt::Display for Complex { @@ -44,6 +62,15 @@ impl From for Complex { } } +impl From for Complex { + fn from(value: usize) -> Self { + Self { + re: value as f64, + im: 0., + } + } +} + impl std::ops::Add for Complex { type Output = Complex; @@ -112,6 +139,14 @@ impl std::ops::Mul for Complex { } } +impl std::ops::Div for Complex { + type Output = Complex; + + fn div(self, rhs: Complex) -> Self::Output { + self * rhs.inv() + } +} + impl std::ops::Div for Complex { type Output = Complex; @@ -159,6 +194,12 @@ impl std::ops::MulAssign for Complex { } } +impl std::ops::DivAssign for Complex { + fn div_assign(&mut self, rhs: Complex) { + *self = *self / rhs; + } +} + impl std::ops::DivAssign for Complex { fn div_assign(&mut self, rhs: f64) { *self = *self / rhs; @@ -187,4 +228,29 @@ impl std::iter::Product for Complex { result } -} \ No newline at end of file +} + +impl hcpl_algebra::monoid::AdditiveIdentity for Complex { + const VALUE: Self = Complex::ZERO; +} +impl hcpl_algebra::monoid::MultiplicativeIdentity for Complex { + const VALUE: Self = Complex::ONE; +} + +impl hcpl_number_theory::roots::TryNthRootOfUnity for Complex { + type Error = std::convert::Infallible; + + fn try_nth_root_of_unity(n: usize) -> Result + where + Self: Sized, + { + Ok(Self::cis(2. * std::f64::consts::PI / n as f64)) + } + + fn try_nth_root_of_unity_inv(n: usize) -> Result + where + Self: Sized, + { + Ok(Self::cis(-2. 
* std::f64::consts::PI / n as f64)) + } +} diff --git a/hcpl_fft/Cargo.toml b/hcpl_fft/Cargo.toml index 84965a0..d0a749c 100644 --- a/hcpl_fft/Cargo.toml +++ b/hcpl_fft/Cargo.toml @@ -7,5 +7,4 @@ license = "MIT" [dependencies] hcpl_algebra = { path = "../hcpl_algebra" } -# Maybe this can be generalised to cover both FFTs and NTTs -hcpl_complex = { path = "../hcpl_complex" } \ No newline at end of file +hcpl_number_theory = { path = "../hcpl_number_theory" } diff --git a/hcpl_fft/src/lib.rs b/hcpl_fft/src/lib.rs index 43fd2d0..05cff3e 100644 --- a/hcpl_fft/src/lib.rs +++ b/hcpl_fft/src/lib.rs @@ -1,5 +1,6 @@ -use hcpl_complex::Complex; -use std::f64::consts::PI; +use hcpl_algebra::monoid::MultiplicativeIdentity; +use hcpl_number_theory::roots::TryNthRootOfUnity; +use std::fmt::Debug; /// Performs the bit-reversal permutation on `a` pub fn bit_reversal(vec: &mut [T]) { @@ -20,43 +21,106 @@ pub fn bit_reversal(vec: &mut [T]) { } } -/// Performs the in-place Fast Fourier Transform on the slice `a`, whose lenght must be a power of two -pub fn fft(vec: &mut [Complex], inv: bool) { +mod stack_stack { + use std::mem::MaybeUninit; + + const BUF_SIZE: usize = 100; + + /// A stack-allocated stack + pub(crate) struct StackStack { + buffer: [MaybeUninit; BUF_SIZE], + len: usize, + } + + impl StackStack { + pub(crate) fn new() -> Self { + Self { + buffer: unsafe { + MaybeUninit::<[MaybeUninit; BUF_SIZE]>::uninit().assume_init() + }, + len: 0, + } + } + + pub(crate) fn push(&mut self, val: T) -> Option<()> { + if self.len == BUF_SIZE { + return None; + } + + self.buffer[self.len].write(val); + self.len += 1; + + Some(()) + } + + pub(crate) fn pop(&mut self) -> Option { + if self.len == 0 { + return None; + } + + self.len -= 1; + + // SAFETY: this element will have been written to exactly once, and won't be read again + Some(unsafe { self.buffer[self.len].assume_init_read() }) + } + } +} + +/// Performs the in-place Fast Fourier Transform on the slice `a`, whose lenght must be a 
power of two. DOES NOT +/// PERFORM NORMALISATION +pub fn fft(vec: &mut [T], inv: bool) +where + T: TryNthRootOfUnity, + ::Error: Debug, +{ bit_reversal(vec); let n = vec.len(); - let sign = if inv { -1. } else { 1. }; + + let mut roots = stack_stack::StackStack::new(); + + let mut q = n; + let mut last: T = if inv { + TryNthRootOfUnity::try_nth_root_of_unity_inv(n) + } else { + TryNthRootOfUnity::try_nth_root_of_unity(n) + } + .unwrap(); + while q >= 2 { + roots.push(last); + q /= 2; + last = last * last; + } let mut width = 2; while width <= n { - let w_d = Complex::cis(sign * 2. * PI / width as f64); + let w_d = roots.pop().unwrap(); for i in (0..n).step_by(width) { - let mut w = Complex::ONE; + let mut w = ::VALUE; for j in 0..width / 2 { let l = i + j; let r = i + j + width / 2; (vec[l], vec[r]) = (vec[l] + w * vec[r], vec[l] - w * vec[r]); - w *= w_d; + w = w * w_d; } } width *= 2; } - - if inv { - for i in 0..n { - vec[i] /= n as f64; - } - } } -/// Performs the in-place discrete convolution of the two vectors `a` and `b`. +/// Performs the in-place discrete convolution of the two vectors `a` and `b`. DOES NOT PERFORM +/// NORMALISATION /// /// `a` will contain the convolution of `a` and `b`, and `b` will contain the discrete /// Fourier transform of `b`. 
-pub fn convolution(a: &mut [Complex], b: &mut [Complex]) { +pub fn convolution(a: &mut [T], b: &mut [T]) +where + T: TryNthRootOfUnity, + ::Error: Debug, +{ fft(a, false); fft(b, false); for (x, y) in a.iter_mut().zip(b.iter().copied()) { diff --git a/hcpl_modnum/Cargo.toml b/hcpl_modnum/Cargo.toml index 60e2682..ab197cf 100644 --- a/hcpl_modnum/Cargo.toml +++ b/hcpl_modnum/Cargo.toml @@ -7,3 +7,4 @@ license = "MIT" [dependencies] hcpl_algebra = { path = "../hcpl_algebra" } +hcpl_number_theory = { path = "../hcpl_number_theory" } diff --git a/hcpl_modnum/src/lib.rs b/hcpl_modnum/src/lib.rs index f885922..4ca797f 100644 --- a/hcpl_modnum/src/lib.rs +++ b/hcpl_modnum/src/lib.rs @@ -135,3 +135,46 @@ impl hcpl_algebra::monoid::AdditiveIdentity for Modnum { impl hcpl_algebra::monoid::MultiplicativeIdentity for Modnum { const VALUE: Self = Self::new(1); } + +impl hcpl_number_theory::roots::TryNthRootOfUnity for Modnum { + type Error = &'static str; + + fn try_nth_root_of_unity(n: usize) -> Result { + if n == 0 { + Err("n must be positive") + } else if n == 1 { + Ok(Self::new(1)) + } else if (N - 1) as usize % n != 0 { + Err("n is not a divisor of MOD - 1") + } else { + for base in 2.. 
{ + #[cfg(debug_assertions)] + eprintln!("trying base {}", base); + let attempt = Self::new(base).pow((N - 1) as usize / n); + let mut d = 0; + if loop { + d += 1; + if d * d > n { + break true; + } + if n % d != 0 { + continue; + } + if attempt.pow(d as usize) == Self::new(1) { + break false; + } + if d != 1 && d * d != n && attempt.pow((n / d) as usize) == Self::new(1) { + break false; + } + } { + return Ok(attempt); + } + } + unreachable!() + } + } + + fn try_nth_root_of_unity_inv(n: usize) -> Result { + Ok(Self::try_nth_root_of_unity(n)?.inv()) + } +} diff --git a/hcpl_number_theory/Cargo.toml b/hcpl_number_theory/Cargo.toml index 86677b1..e38e715 100644 --- a/hcpl_number_theory/Cargo.toml +++ b/hcpl_number_theory/Cargo.toml @@ -6,4 +6,5 @@ repository = "https://github.com/THE-nio/hcpl" license = "MIT" [dependencies] +hcpl_algebra = { path = "../hcpl_algebra" } hcpl_integer = { path = "../hcpl_integer" } diff --git a/hcpl_number_theory/src/lib.rs b/hcpl_number_theory/src/lib.rs index 7af74c6..9cad09a 100644 --- a/hcpl_number_theory/src/lib.rs +++ b/hcpl_number_theory/src/lib.rs @@ -1,4 +1,5 @@ mod gcd; pub mod prime; +pub mod roots; pub use gcd::gcd; diff --git a/hcpl_number_theory/src/roots.rs b/hcpl_number_theory/src/roots.rs new file mode 100644 index 0000000..62eb189 --- /dev/null +++ b/hcpl_number_theory/src/roots.rs @@ -0,0 +1,27 @@ +use hcpl_algebra::{monoid::MultiplicativeIdentity, Ring}; + +pub trait TryNthRootOfUnity: Sized + Copy + Ring { + type Error; + + /// Returns a principal nth root of unity. Must be deterministic. + fn try_nth_root_of_unity(n: usize) -> Result; + + /// Returns a value s.t. 
try_nth_root_of_unity(n) * try_nth_root_of_unity_inv(n) = 1. The default implementation computes root^(n - 1) by binary exponentiation, which is the inverse because root^n = 1. + fn try_nth_root_of_unity_inv(n: usize) -> Result { + let mut root = Self::try_nth_root_of_unity(n)?; + + let mut result = ::VALUE; + let mut exp = n - 1; + + while exp > 0 { + if exp % 2 == 1 { + result = result * root; + } + + root = root * root; + exp >>= 1; + } + + Ok(result) + } +} From c9c9adaa0596568e9a9ce77e52f22a60b5262ead Mon Sep 17 00:00:00 2001 From: jakobrs Date: Sat, 20 May 2023 13:43:43 +0200 Subject: [PATCH 3/4] Real FFTs --- hcpl_fft/Cargo.toml | 4 ++ hcpl_fft/src/lib.rs | 3 ++ hcpl_fft/src/real.rs | 113 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) create mode 100644 hcpl_fft/src/real.rs diff --git a/hcpl_fft/Cargo.toml b/hcpl_fft/Cargo.toml index d0a749c..d9230f9 100644 --- a/hcpl_fft/Cargo.toml +++ b/hcpl_fft/Cargo.toml @@ -8,3 +8,7 @@ license = "MIT" [dependencies] hcpl_algebra = { path = "../hcpl_algebra" } hcpl_number_theory = { path = "../hcpl_number_theory" } +hcpl_complex = { path = "../hcpl_complex", optional = true } + +[features] +real = [ "dep:hcpl_complex" ] diff --git a/hcpl_fft/src/lib.rs b/hcpl_fft/src/lib.rs index 05cff3e..2f5a558 100644 --- a/hcpl_fft/src/lib.rs +++ b/hcpl_fft/src/lib.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "real")] +pub mod real; + use hcpl_algebra::monoid::MultiplicativeIdentity; use hcpl_number_theory::roots::TryNthRootOfUnity; use std::fmt::Debug; diff --git a/hcpl_fft/src/real.rs b/hcpl_fft/src/real.rs new file mode 100644 index 0000000..6bda867 --- /dev/null +++ b/hcpl_fft/src/real.rs @@ -0,0 +1,113 @@ +//! FFTs of real-valued data +//! +//! Functions in this module should act the same as functions in the crate root but may assume that the input or output is real-valued. +//! +//! PLEASE DON'T RELY ON THIS AT THE MOMENT AS IT LACKS TESTING + +use std::f64::consts::PI; + +use hcpl_complex::Complex; + +use crate::fft; + +/// Performs both recursive calls in the Cooley-Tukey algorithm with only one call to `fft`. 
`vec` must be real-valued. +pub fn double_real_ffft(vec: &mut [Complex]) { + let n = vec.len() / 2; + + for i in 0..n { + vec[i] = Complex { + re: vec[2 * i].re, + im: vec[2 * i + 1].re, + }; + } + + fft(&mut vec[..n], false); + + vec.copy_within(..n, n); + vec[n + 1..].reverse(); + + // To see why this works, consider the case where all imaginary components are zero. The FFT famously exhibits conjugate symmetry + // on real data, meaning that vec[-i] = vec[i]*. Therefore `z_c` (= vec[-i]* = vec[i]** = vec[i]) is equal to `z`, and + // `(z + z_c) / 2 == z` and `-i * (z - z_c) / 2 == 0`. These `z` values are simply the DHT of the input array, as expected. + // + // The same argument works if all real components are zero, except the output will have been multiplied by `i`. Using linearity one + // can show that vec[..n] will contain the DHT of the even-indexed values in the input and vec[n..] will contain the DHT of the + // odd-indexed values in the input. + for i in 0..n { + let z = vec[i]; + let z_c = vec[i + n].conj(); + + vec[i] = (z + z_c) / 2.; + vec[i + n] = -Complex::I * (z - z_c) / 2.; + } +} + +/// Calculates the FFT of a sequence of real numbers +pub fn real_ffft(vec: &mut [Complex]) { + let n = vec.len(); + + double_real_ffft(vec); + + let mut w = Complex::ONE; + let w_d = Complex::cis(2. * PI / n as f64); + for i in 0..n / 2 { + (vec[i], vec[i + n / 2]) = (vec[i] + w * vec[i + n / 2], vec[i] - w * vec[i + n / 2]); + w *= w_d; + } +} + +// These functions were written by manually inverting the functions above. 
+ +/// Inverse of double_real_ffft, up to normalisation +pub fn double_real_ifft(vec: &mut [Complex]) { + // a, b = (x + y)/2, (x - y) / 2i + // a + bi = x + // a - bi = y + + let n = vec.len() / 2; + + for i in 0..n { + vec[i] = vec[i] + Complex::I * vec[i + n]; + } + + fft(&mut vec[..n], true); + + for i in (0..n).rev() { + (vec[2 * i], vec[2 * i + 1]) = (vec[i].re.into(), vec[i].im.into()); + } +} + +/// Calculates the IFFT of a sequence which is the DFT of a sequence of real numbers. Does NOT perform normalisation. +pub fn real_ifft(vec: &mut [Complex]) { + // a, b = x + w y, x - w y + // a + b = x + w y + x - w y = 2 x + // a - b = x + w y - x + w y = 2 w y + // (a + b) / 2 = x + // (a - b) / 2w = y + + let n = vec.len(); + + let mut w = Complex::ONE; + let w_d = Complex::cis(-2. * PI / n as f64); + + for i in 0..n/2 { + (vec[i], vec[i + n / 2]) = ((vec[i] + vec[i + n / 2]) / 2., w * (vec[i] - vec[i + n / 2]) / 2.); + w *= w_d; + } + + double_real_ifft(vec); + + // apply missing denormalisation + for i in 0..n { + vec[i] *= 2.; + } +} + +/// Like real_ffft and real_ifft but using an extra parameter to choose direction +pub fn real_fft(vec: &mut [Complex], inverse: bool) { + if inverse { + real_ifft(vec); + } else { + real_ffft(vec); + } +} From 85c3799b74b721a15b02e918ec09b64ff42eca99 Mon Sep 17 00:00:00 2001 From: jakobrs Date: Sat, 20 May 2023 17:32:17 +0200 Subject: [PATCH 4/4] Move StackStack to its own module --- hcpl_fft/src/lib.rs | 77 ++++++++++--------------------------- hcpl_fft/src/stack_stack.rs | 42 ++++++++++++++++++++ 2 files changed, 63 insertions(+), 56 deletions(-) create mode 100644 hcpl_fft/src/stack_stack.rs diff --git a/hcpl_fft/src/lib.rs b/hcpl_fft/src/lib.rs index 2f5a558..2443687 100644 --- a/hcpl_fft/src/lib.rs +++ b/hcpl_fft/src/lib.rs @@ -1,11 +1,13 @@ #[cfg(feature = "real")] pub mod real; +mod stack_stack; + use hcpl_algebra::monoid::MultiplicativeIdentity; use hcpl_number_theory::roots::TryNthRootOfUnity; use std::fmt::Debug; 
-/// Performs the bit-reversal permutation on `a` +/// Performs the bit-reversal permutation on `vec` pub fn bit_reversal(vec: &mut [T]) { let n = vec.len(); @@ -24,77 +26,40 @@ pub fn bit_reversal(vec: &mut [T]) { } } -mod stack_stack { - use std::mem::MaybeUninit; - - const BUF_SIZE: usize = 100; - - /// A stack-allocated stack - pub(crate) struct StackStack { - buffer: [MaybeUninit; BUF_SIZE], - len: usize, - } - - impl StackStack { - pub(crate) fn new() -> Self { - Self { - buffer: unsafe { - MaybeUninit::<[MaybeUninit; BUF_SIZE]>::uninit().assume_init() - }, - len: 0, - } - } - - pub(crate) fn push(&mut self, val: T) -> Option<()> { - if self.len == BUF_SIZE { - return None; - } - - self.buffer[self.len].write(val); - self.len += 1; - - Some(()) - } - - pub(crate) fn pop(&mut self) -> Option { - if self.len == 0 { - return None; - } - - self.len -= 1; - - // SAFETY: this element will have been written to exactly once, and won't be read again - Some(unsafe { self.buffer[self.len].assume_init_read() }) - } - } -} - -/// Performs the in-place Fast Fourier Transform on the slice `a`, whose lenght must be a power of two. DOES NOT -/// PERFORM NORMALISATION -pub fn fft(vec: &mut [T], inv: bool) +fn get_roots(mut n: usize, inv: bool) -> stack_stack::StackStack where T: TryNthRootOfUnity, ::Error: Debug, { - bit_reversal(vec); - - let n = vec.len(); - let mut roots = stack_stack::StackStack::new(); - let mut q = n; let mut last: T = if inv { TryNthRootOfUnity::try_nth_root_of_unity_inv(n) } else { TryNthRootOfUnity::try_nth_root_of_unity(n) } .unwrap(); - while q >= 2 { + while n >= 2 { roots.push(last); - q /= 2; + n /= 2; last = last * last; } + roots +} + +/// Performs the in-place Fast Fourier Transform on the slice `vec`, whose length must be a power of two. 
DOES NOT +/// PERFORM NORMALISATION +pub fn fft(vec: &mut [T], inv: bool) +where + T: TryNthRootOfUnity, + ::Error: Debug, +{ + bit_reversal(vec); + + let n = vec.len(); + + let mut roots = get_roots(n, inv); let mut width = 2; while width <= n { let w_d = roots.pop().unwrap(); diff --git a/hcpl_fft/src/stack_stack.rs b/hcpl_fft/src/stack_stack.rs new file mode 100644 index 0000000..a9ae7e3 --- /dev/null +++ b/hcpl_fft/src/stack_stack.rs @@ -0,0 +1,42 @@ +use std::mem::MaybeUninit; + +const BUF_SIZE: usize = 50; + +/// A stack-allocated stack +pub(crate) struct StackStack { + buffer: [MaybeUninit; BUF_SIZE], + len: usize, +} + +impl StackStack { + pub(crate) fn new() -> Self { + Self { + buffer: unsafe { + MaybeUninit::<[MaybeUninit; BUF_SIZE]>::uninit().assume_init() + }, + len: 0, + } + } + + pub(crate) fn push(&mut self, val: T) -> Option<()> { + if self.len == BUF_SIZE { + return None; + } + + self.buffer[self.len].write(val); + self.len += 1; + + Some(()) + } + + pub(crate) fn pop(&mut self) -> Option { + if self.len == 0 { + return None; + } + + self.len -= 1; + + // SAFETY: this element will have been written to exactly once, and won't be read again + Some(unsafe { self.buffer[self.len].assume_init_read() }) + } +}