From e7574f793af4f7e75ad586b564d7346102f7d203 Mon Sep 17 00:00:00 2001 From: declanhealy2 Date: Mon, 23 Mar 2026 00:14:10 +0000 Subject: [PATCH 1/2] Add fftfreq, rfftfreq and scalar axes support for fftshift/ifftshift --- docs/src/python/fft.rst | 2 + python/src/fft.cpp | 103 ++++++++++++++++++++++++++++++++++----- python/tests/test_fft.py | 48 ++++++++++++++++++ 3 files changed, 141 insertions(+), 12 deletions(-) diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 36d9d78380..78bfe7f8fc 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,5 +20,7 @@ FFT irfft2 rfftn irfftn + fftfreq + rfftfreq fftshift ifftshift diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 522e1064c7..48123895c1 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -14,6 +14,23 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; +namespace { + +using IntOrVec = std::variant>; +using AxesArg = std::optional; + +std::optional> normalize_axes(const AxesArg& axes) { + if (!axes.has_value()) { + return std::nullopt; + } + if (auto axis = std::get_if(&axes.value())) { + return std::vector{*axis}; + } + return std::get>(axes.value()); +} + +} // namespace + void init_fft(nb::module_& parent_module) { auto m = parent_module.def_submodule( "fft", "mlx.core.fft: Fast Fourier Transforms."); @@ -460,13 +477,76 @@ void init_fft(nb::module_& parent_module) { Returns: array: The real array containing the inverse of :func:`rfftn`. )pbdoc"); + m.def( + "fftfreq", + [](int n, double d, mx::StreamOrDevice s) { + if (n <= 0) { + throw std::invalid_argument("[fftfreq] `n` must be greater than 0."); + } + if (d == 0.0) { + throw std::invalid_argument("[fftfreq] `d` must be non-zero."); + } + + auto pos = mx::arange(0, (n + 1) / 2, mx::float32, s); + auto neg = mx::arange(-(n / 2), 0, mx::float32, s); + auto freqs = mx::concatenate({pos, neg}, s); + auto scale = mx::array( + static_cast(1.0 / (static_cast(n) * d)), + mx::float32); + return mx::multiply(freqs, scale, s); + }, + "n"_a, + "d"_a = 1.0, + "stream"_a = nb::none(), + R"pbdoc( + Return the discrete Fourier Transform sample frequencies. + + Args: + n (int): Window length. + d (float, optional): Sample spacing. The default is ``1.0``. + + Returns: + array: The sample frequencies as a one-dimensional array of type ``float32``. + )pbdoc"); + m.def( + "rfftfreq", + [](int n, double d, mx::StreamOrDevice s) { + if (n <= 0) { + throw std::invalid_argument("[rfftfreq] `n` must be greater than 0."); + } + if (d == 0.0) { + throw std::invalid_argument("[rfftfreq] `d` must be non-zero."); + } + + auto freqs = mx::arange(0, n / 2 + 1, mx::float32, s); + auto scale = mx::array( + static_cast(1.0 / (static_cast(n) * d)), + mx::float32); + return mx::multiply(freqs, scale, s); + }, + "n"_a, + "d"_a = 1.0, + "stream"_a = nb::none(), + R"pbdoc( + Return the discrete Fourier Transform sample frequencies + for use with :func:`rfft` and :func:`irfft`. + + The returned array contains the non-negative frequency terms + in the range ``[0, floor(n/2)]``. + + Args: + n (int): Window length. + d (float, optional): Sample spacing. The default is ``1.0``. + + Returns: + array: The sample frequencies as a one-dimensional array of type ``float32``. + )pbdoc"); m.def( "fftshift", - [](const mx::array& a, - const std::optional>& axes, - mx::StreamOrDevice s) { - if (axes.has_value()) { - return mx::fft::fftshift(a, axes.value(), s); + [](const mx::array& a, const AxesArg& axes, mx::StreamOrDevice s) { + auto normalized_axes = normalize_axes(axes); + if (normalized_axes.has_value()) { + return mx::fft::fftshift(a, normalized_axes.value(), s); } else { return mx::fft::fftshift(a, s); } @@ -479,7 +559,7 @@ void init_fft(nb::module_& parent_module) { Args: a (array): The input array. - axes (list(int), optional): Axes over which to perform the shift. + axes (int or list(int), optional): Axis or axes over which to perform the shift. If ``None``, shift all axes. Returns: @@ -487,11 +567,10 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "ifftshift", - [](const mx::array& a, - const std::optional>& axes, - mx::StreamOrDevice s) { - if (axes.has_value()) { - return mx::fft::ifftshift(a, axes.value(), s); + [](const mx::array& a, const AxesArg& axes, mx::StreamOrDevice s) { + auto normalized_axes = normalize_axes(axes); + if (normalized_axes.has_value()) { + return mx::fft::ifftshift(a, normalized_axes.value(), s); } else { return mx::fft::ifftshift(a, s); } @@ -505,7 +584,7 @@ void init_fft(nb::module_& parent_module) { Args: a (array): The input array. - axes (list(int), optional): Axes over which to perform the inverse shift. + axes (int or list(int), optional): Axis or axes over which to perform the inverse shift. If ``None``, shift all axes. Returns: diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 9921c7d985..b3f78fe80f 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -217,6 +217,48 @@ def test_fft_throws(self): with self.assertRaises(ValueError): mx.fft.irfftn(x) + def test_fftfreq(self): + for n, d in [(1, 1.0), (4, 0.5), (5, 0.25), (8, -0.5), (6, 1.0)]: + out = mx.fft.fftfreq(n, d=d) + expected = np.fft.fftfreq(n, d=d).astype(np.float32) + self.assertEqual(out.dtype, mx.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + with self.assertRaises(ValueError): + mx.fft.fftfreq(0) + + with self.assertRaises(ValueError): + mx.fft.fftfreq(-1) + + # Test default d=1.0 + out = mx.fft.fftfreq(8) + expected = np.fft.fftfreq(8).astype(np.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + with self.assertRaises(ValueError): + mx.fft.fftfreq(4, d=0.0) + + def test_rfftfreq(self): + for n, d in [(1, 1.0), (4, 0.5), (5, 0.25), (8, -0.5), (6, 1.0)]: + out = mx.fft.rfftfreq(n, d=d) + expected = np.fft.rfftfreq(n, d=d).astype(np.float32) + self.assertEqual(out.dtype, mx.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + # Test default d=1.0 + out = mx.fft.rfftfreq(8) + expected = np.fft.rfftfreq(8).astype(np.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + with self.assertRaises(ValueError): + mx.fft.rfftfreq(0) + + with self.assertRaises(ValueError): + mx.fft.rfftfreq(-1) + + with self.assertRaises(ValueError): + mx.fft.rfftfreq(4, d=0.0) + def test_fftshift(self): # Test 1D arrays r = np.random.rand(100).astype(np.float32) @@ -224,11 +266,14 @@ def test_fftshift(self): # Test with specific axis r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=0) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=1) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1]) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1]) # Test with negative axes + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=-1) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1]) # Test with odd lengths @@ -249,11 +294,14 @@ def test_ifftshift(self): # Test with specific axis r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) # Test with negative axes + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) # Test with odd lengths From 2f7d331e0ad435656e77d74fc2747521fa1e8747 Mon Sep 17 00:00:00 2001 From: declanhealy2 Date: Wed, 25 Mar 2026 11:36:54 +0000 Subject: [PATCH 2/2] Move fftfreq and rfftfreq to C++ side --- mlx/fft.cpp | 28 ++++++++++++++++++++++++++++ mlx/fft.h | 7 +++++++ python/src/fft.cpp | 28 ++-------------------------- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 6cad15c33b..8ddc1aca46 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -310,4 +310,32 @@ array ifftshift(const array& a, StreamOrDevice s /* = {} */) { return ifftshift(a, axes, s); } +array fftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) { + if (n <= 0) { + throw std::invalid_argument("[fftfreq] `n` must be greater than 0."); + } + if (d == 0.0) { + throw std::invalid_argument("[fftfreq] `d` must be non-zero."); + } + auto pos = arange(0, (n + 1) / 2, float32, s); + auto neg = arange(-(n / 2), 0, float32, s); + auto freqs = concatenate({pos, neg}, s); + auto scale = + array(static_cast(1.0 / (static_cast(n) * d)), float32); + return multiply(freqs, scale, s); +} + +array rfftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) { + if (n <= 0) { + throw std::invalid_argument("[rfftfreq] `n` must be greater than 0."); + } + if (d == 0.0) { + throw std::invalid_argument("[rfftfreq] `d` must be non-zero."); + } + auto freqs = arange(0, n / 2 + 1, float32, s); + auto scale = + array(static_cast(1.0 / (static_cast(n) * d)), float32); + return multiply(freqs, scale, s); +} + } // namespace mlx::core::fft diff --git a/mlx/fft.h b/mlx/fft.h index a36af78de9..f69b8aa2c4 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -212,6 +212,13 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, norm, s); } +/** Compute the discrete Fourier Transform sample frequencies. */ +MLX_API array fftfreq(int n, double d = 1.0, StreamOrDevice s = {}); + +/** Compute the discrete Fourier Transform sample frequencies for + * `rfft`/`irfft`. */ +MLX_API array rfftfreq(int n, double d = 1.0, StreamOrDevice s = {}); + /** Shift the zero-frequency component to the center of the spectrum. */ MLX_API array fftshift(const array& a, StreamOrDevice s = {}); diff --git a/python/src/fft.cpp b/python/src/fft.cpp index dc9bee537a..8c95ffe02d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -557,20 +557,7 @@ void init_fft(nb::module_& parent_module) { m.def( "fftfreq", [](int n, double d, mx::StreamOrDevice s) { - if (n <= 0) { - throw std::invalid_argument("[fftfreq] `n` must be greater than 0."); - } - if (d == 0.0) { - throw std::invalid_argument("[fftfreq] `d` must be non-zero."); - } - - auto pos = mx::arange(0, (n + 1) / 2, mx::float32, s); - auto neg = mx::arange(-(n / 2), 0, mx::float32, s); - auto freqs = mx::concatenate({pos, neg}, s); - auto scale = mx::array( - static_cast(1.0 / (static_cast(n) * d)), - mx::float32); - return mx::multiply(freqs, scale, s); + return mx::fft::fftfreq(n, d, s); }, "n"_a, "d"_a = 1.0, @@ -588,18 +575,7 @@ void init_fft(nb::module_& parent_module) { m.def( "rfftfreq", [](int n, double d, mx::StreamOrDevice s) { - if (n <= 0) { - throw std::invalid_argument("[rfftfreq] `n` must be greater than 0."); - } - if (d == 0.0) { - throw std::invalid_argument("[rfftfreq] `d` must be non-zero."); - } - - auto freqs = mx::arange(0, n / 2 + 1, mx::float32, s); - auto scale = mx::array( - static_cast(1.0 / (static_cast(n) * d)), - mx::float32); - return mx::multiply(freqs, scale, s); + return mx::fft::rfftfreq(n, d, s); }, "n"_a, "d"_a = 1.0,