Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/python/fft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ FFT
irfft2
rfftn
irfftn
fftfreq
rfftfreq
fftshift
ifftshift
28 changes: 28 additions & 0 deletions mlx/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,32 @@ array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
return ifftshift(a, axes, s);
}

array fftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) {
if (n <= 0) {
throw std::invalid_argument("[fftfreq] `n` must be greater than 0.");
}
if (d == 0.0) {
throw std::invalid_argument("[fftfreq] `d` must be non-zero.");
}
auto pos = arange(0, (n + 1) / 2, float32, s);
auto neg = arange(-(n / 2), 0, float32, s);
auto freqs = concatenate({pos, neg}, s);
auto scale =
array(static_cast<float>(1.0 / (static_cast<double>(n) * d)), float32);
return multiply(freqs, scale, s);
}

array rfftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) {
if (n <= 0) {
throw std::invalid_argument("[rfftfreq] `n` must be greater than 0.");
}
if (d == 0.0) {
throw std::invalid_argument("[rfftfreq] `d` must be non-zero.");
}
auto freqs = arange(0, n / 2 + 1, float32, s);
auto scale =
array(static_cast<float>(1.0 / (static_cast<double>(n) * d)), float32);
return multiply(freqs, scale, s);
}

} // namespace mlx::core::fft
7 changes: 7 additions & 0 deletions mlx/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ inline array irfft2(
StreamOrDevice s = {}) {
return irfftn(a, axes, norm, s);
}
/** Compute the discrete Fourier Transform sample frequencies. */
MLX_API array fftfreq(int n, double d = 1.0, StreamOrDevice s = {});

/** Compute the discrete Fourier Transform sample frequencies for
* `rfft`/`irfft`. */
MLX_API array rfftfreq(int n, double d = 1.0, StreamOrDevice s = {});

/** Shift the zero-frequency component to the center of the spectrum. */
MLX_API array fftshift(const array& a, StreamOrDevice s = {});

Expand Down
75 changes: 63 additions & 12 deletions python/src/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ using namespace nb::literals;

namespace {

using IntOrVec = std::variant<int, std::vector<int>>;
using AxesArg = std::optional<IntOrVec>;

std::optional<std::vector<int>> normalize_axes(const AxesArg& axes) {
if (!axes.has_value()) {
return std::nullopt;
}
if (auto axis = std::get_if<int>(&axes.value())) {
return std::vector<int>{*axis};
}
return std::get<std::vector<int>>(axes.value());
}

mx::fft::FFTNorm parse_norm(std::string_view norm, std::string_view op) {
if (norm == "backward") {
return mx::fft::FFTNorm::Backward;
Expand Down Expand Up @@ -541,13 +554,52 @@ void init_fft(nb::module_& parent_module) {
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");
m.def(
"fftfreq",
[](int n, double d, mx::StreamOrDevice s) {
return mx::fft::fftfreq(n, d, s);
},
"n"_a,
"d"_a = 1.0,
"stream"_a = nb::none(),
R"pbdoc(
Return the discrete Fourier Transform sample frequencies.

Args:
n (int): Window length.
d (float, optional): Sample spacing. The default is ``1.0``.

Returns:
array: The sample frequencies as a one-dimensional array of type ``float32``.
)pbdoc");
m.def(
"rfftfreq",
[](int n, double d, mx::StreamOrDevice s) {
return mx::fft::rfftfreq(n, d, s);
},
"n"_a,
"d"_a = 1.0,
"stream"_a = nb::none(),
R"pbdoc(
Return the discrete Fourier Transform sample frequencies
for use with :func:`rfft` and :func:`irfft`.

The returned array contains the non-negative frequency terms
in the range ``[0, floor(n/2)]``.

Args:
n (int): Window length.
d (float, optional): Sample spacing. The default is ``1.0``.

Returns:
array: The sample frequencies as a one-dimensional array of type ``float32``.
)pbdoc");
m.def(
"fftshift",
[](const mx::array& a,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value()) {
return mx::fft::fftshift(a, axes.value(), s);
[](const mx::array& a, const AxesArg& axes, mx::StreamOrDevice s) {
auto normalized_axes = normalize_axes(axes);
if (normalized_axes.has_value()) {
return mx::fft::fftshift(a, normalized_axes.value(), s);
} else {
return mx::fft::fftshift(a, s);
}
Expand All @@ -560,19 +612,18 @@ void init_fft(nb::module_& parent_module) {

Args:
a (array): The input array.
axes (list(int), optional): Axes over which to perform the shift.
axes (int or list(int), optional): Axis or axes over which to perform the shift.
If ``None``, shift all axes.

Returns:
array: The shifted array with the same shape as the input.
)pbdoc");
m.def(
"ifftshift",
[](const mx::array& a,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value()) {
return mx::fft::ifftshift(a, axes.value(), s);
[](const mx::array& a, const AxesArg& axes, mx::StreamOrDevice s) {
auto normalized_axes = normalize_axes(axes);
if (normalized_axes.has_value()) {
return mx::fft::ifftshift(a, normalized_axes.value(), s);
} else {
return mx::fft::ifftshift(a, s);
}
Expand All @@ -586,7 +637,7 @@ void init_fft(nb::module_& parent_module) {

Args:
a (array): The input array.
axes (list(int), optional): Axes over which to perform the inverse shift.
axes (int or list(int), optional): Axis or axes over which to perform the inverse shift.
If ``None``, shift all axes.

Returns:
Expand Down
48 changes: 48 additions & 0 deletions python/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,18 +246,63 @@ def test_fft_throws(self):
with self.assertRaises(ValueError):
mx.fft.fft(mx.array([1.0]), norm="invalid")

def test_fftfreq(self):
for n, d in [(1, 1.0), (4, 0.5), (5, 0.25), (8, -0.5), (6, 1.0)]:
out = mx.fft.fftfreq(n, d=d)
expected = np.fft.fftfreq(n, d=d).astype(np.float32)
self.assertEqual(out.dtype, mx.float32)
np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0)

with self.assertRaises(ValueError):
mx.fft.fftfreq(0)

with self.assertRaises(ValueError):
mx.fft.fftfreq(-1)

# Test default d=1.0
out = mx.fft.fftfreq(8)
expected = np.fft.fftfreq(8).astype(np.float32)
np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0)

with self.assertRaises(ValueError):
mx.fft.fftfreq(4, d=0.0)

def test_rfftfreq(self):
for n, d in [(1, 1.0), (4, 0.5), (5, 0.25), (8, -0.5), (6, 1.0)]:
out = mx.fft.rfftfreq(n, d=d)
expected = np.fft.rfftfreq(n, d=d).astype(np.float32)
self.assertEqual(out.dtype, mx.float32)
np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0)

# Test default d=1.0
out = mx.fft.rfftfreq(8)
expected = np.fft.rfftfreq(8).astype(np.float32)
np.testing.assert_allclose(out, expected, atol=0.0, rtol=0.0)

with self.assertRaises(ValueError):
mx.fft.rfftfreq(0)

with self.assertRaises(ValueError):
mx.fft.rfftfreq(-1)

with self.assertRaises(ValueError):
mx.fft.rfftfreq(4, d=0.0)

def test_fftshift(self):
# Test 1D arrays
r = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)

# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=0)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=1)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1])
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1])

# Test with negative axes
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=-1)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1])

# Test with odd lengths
Expand All @@ -278,11 +323,14 @@ def test_ifftshift(self):

# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1])
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1])

# Test with negative axes
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1])

# Test with odd lengths
Expand Down