diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 36d9d78380..78bfe7f8fc 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,5 +20,7 @@ FFT irfft2 rfftn irfftn + fftfreq + rfftfreq fftshift ifftshift diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 6cad15c33b..8ddc1aca46 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -310,4 +310,32 @@ array ifftshift(const array& a, StreamOrDevice s /* = {} */) { return ifftshift(a, axes, s); } +array fftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) { + if (n <= 0) { + throw std::invalid_argument("[fftfreq] `n` must be greater than 0."); + } + if (d == 0.0) { + throw std::invalid_argument("[fftfreq] `d` must be non-zero."); + } + auto pos = arange(0, (n + 1) / 2, float32, s); + auto neg = arange(-(n / 2), 0, float32, s); + auto freqs = concatenate({pos, neg}, s); + auto scale = + array(static_cast(1.0 / (static_cast(n) * d)), float32); + return multiply(freqs, scale, s); +} + +array rfftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) { + if (n <= 0) { + throw std::invalid_argument("[rfftfreq] `n` must be greater than 0."); + } + if (d == 0.0) { + throw std::invalid_argument("[rfftfreq] `d` must be non-zero."); + } + auto freqs = arange(0, n / 2 + 1, float32, s); + auto scale = + array(static_cast(1.0 / (static_cast(n) * d)), float32); + return multiply(freqs, scale, s); +} + } // namespace mlx::core::fft diff --git a/mlx/fft.h b/mlx/fft.h index a36af78de9..f69b8aa2c4 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -212,6 +212,13 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, norm, s); } +/** Compute the discrete Fourier Transform sample frequencies. */ +MLX_API array fftfreq(int n, double d = 1.0, StreamOrDevice s = {}); + +/** Compute the discrete Fourier Transform sample frequencies for + * `rfft`/`irfft`. */ +MLX_API array rfftfreq(int n, double d = 1.0, StreamOrDevice s = {}); + /** Shift the zero-frequency component to the center of the spectrum. */ MLX_API array fftshift(const array& a, StreamOrDevice s = {}); diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 8333c3f1f4..8c95ffe02d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -18,6 +18,19 @@ using namespace nb::literals; namespace { +using IntOrVec = std::variant>; +using AxesArg = std::optional; + +std::optional> normalize_axes(const AxesArg& axes) { + if (!axes.has_value()) { + return std::nullopt; + } + if (auto axis = std::get_if(&axes.value())) { + return std::vector{*axis}; + } + return std::get>(axes.value()); +} + mx::fft::FFTNorm parse_norm(std::string_view norm, std::string_view op) { if (norm == "backward") { return mx::fft::FFTNorm::Backward; @@ -541,13 +554,52 @@ void init_fft(nb::module_& parent_module) { Returns: array: The real array containing the inverse of :func:`rfftn`. )pbdoc"); + m.def( + "fftfreq", + [](int n, double d, mx::StreamOrDevice s) { + return mx::fft::fftfreq(n, d, s); + }, + "n"_a, + "d"_a = 1.0, + "stream"_a = nb::none(), + R"pbdoc( + Return the discrete Fourier Transform sample frequencies. + + Args: + n (int): Window length. + d (float, optional): Sample spacing. The default is ``1.0``. + + Returns: + array: The sample frequencies as a one-dimensional array of type ``float32``. + )pbdoc"); + m.def( + "rfftfreq", + [](int n, double d, mx::StreamOrDevice s) { + return mx::fft::rfftfreq(n, d, s); + }, + "n"_a, + "d"_a = 1.0, + "stream"_a = nb::none(), + R"pbdoc( + Return the discrete Fourier Transform sample frequencies + for use with :func:`rfft` and :func:`irfft`. + + The returned array contains the non-negative frequency terms + in the range ``[0, floor(n/2)]``. + + Args: + n (int): Window length. + d (float, optional): Sample spacing. The default is ``1.0``. + + Returns: + array: The sample frequencies as a one-dimensional array of type ``float32``. + )pbdoc"); m.def( "fftshift", - [](const mx::array& a, - const std::optional>& axes, - mx::StreamOrDevice s) { - if (axes.has_value()) { - return mx::fft::fftshift(a, axes.value(), s); + [](const mx::array& a, const AxesArg& axes, mx::StreamOrDevice s) { + auto normalized_axes = normalize_axes(axes); + if (normalized_axes.has_value()) { + return mx::fft::fftshift(a, normalized_axes.value(), s); } else { return mx::fft::fftshift(a, s); } @@ -560,7 +612,7 @@ void init_fft(nb::module_& parent_module) { Args: a (array): The input array. - axes (list(int), optional): Axes over which to perform the shift. + axes (int or list(int), optional): Axis or axes over which to perform the shift. If ``None``, shift all axes. Returns: @@ -568,11 +620,10 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "ifftshift", - [](const mx::array& a, - const std::optional>& axes, - mx::StreamOrDevice s) { - if (axes.has_value()) { - return mx::fft::ifftshift(a, axes.value(), s); + [](const mx::array& a, const AxesArg& axes, mx::StreamOrDevice s) { + auto normalized_axes = normalize_axes(axes); + if (normalized_axes.has_value()) { + return mx::fft::ifftshift(a, normalized_axes.value(), s); } else { return mx::fft::ifftshift(a, s); } @@ -586,7 +637,7 @@ void init_fft(nb::module_& parent_module) { Args: a (array): The input array. - axes (list(int), optional): Axes over which to perform the inverse shift. + axes (int or list(int), optional): Axis or axes over which to perform the inverse shift. If ``None``, shift all axes. Returns: diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 23aaf27af0..9c03530b04 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -246,6 +246,48 @@ def test_fft_throws(self): with self.assertRaises(ValueError): mx.fft.fft(mx.array([1.0]), norm="invalid") + def test_fftfreq(self): + for n, d in [(1, 1.0), (4, 0.5), (5, 0.25), (8, -0.5), (6, 1.0)]: + out = mx.fft.fftfreq(n, d=d) + expected = np.fft.fftfreq(n, d=d).astype(np.float32) + self.assertEqual(out.dtype, mx.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + with self.assertRaises(ValueError): + mx.fft.fftfreq(0) + + with self.assertRaises(ValueError): + mx.fft.fftfreq(-1) + + # Test default d=1.0 + out = mx.fft.fftfreq(8) + expected = np.fft.fftfreq(8).astype(np.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + with self.assertRaises(ValueError): + mx.fft.fftfreq(4, d=0.0) + + def test_rfftfreq(self): + for n, d in [(1, 1.0), (4, 0.5), (5, 0.25), (8, -0.5), (6, 1.0)]: + out = mx.fft.rfftfreq(n, d=d) + expected = np.fft.rfftfreq(n, d=d).astype(np.float32) + self.assertEqual(out.dtype, mx.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + # Test default d=1.0 + out = mx.fft.rfftfreq(8) + expected = np.fft.rfftfreq(8).astype(np.float32) + np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0) + + with self.assertRaises(ValueError): + mx.fft.rfftfreq(0) + + with self.assertRaises(ValueError): + mx.fft.rfftfreq(-1) + + with self.assertRaises(ValueError): + mx.fft.rfftfreq(4, d=0.0) + def test_fftshift(self): # Test 1D arrays r = np.random.rand(100).astype(np.float32) @@ -253,11 +295,14 @@ def test_fftshift(self): # Test with specific axis r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=0) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=1) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1]) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1]) # Test with negative axes + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=-1) self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1]) # Test with odd lengths @@ -278,11 +323,14 @@ def test_ifftshift(self): # Test with specific axis r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) # Test with negative axes + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) # Test with odd lengths