diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 6510faec1e..6cad15c33b 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,4 +1,6 @@ // Copyright © 2023 Apple Inc. +#include +#include #include #include @@ -9,12 +11,30 @@ namespace mlx::core::fft { +double fft_scale_factor(const Shape& n, FFTNorm norm, bool inverse) { + if (n.empty()) { + return 1.0; + } + double n_elements = + std::accumulate(n.begin(), n.end(), 1.0, std::multiplies()); + switch (norm) { + case FFTNorm::Backward: + return 1.0; + case FFTNorm::Ortho: + return inverse ? std::sqrt(n_elements) : 1.0 / std::sqrt(n_elements); + case FFTNorm::Forward: + return inverse ? n_elements : 1.0 / n_elements; + } + throw std::invalid_argument("[fftn] Invalid FFT normalization mode."); +} + array fft_impl( const array& a, Shape n, const std::vector& axes, bool real, bool inverse, + FFTNorm norm, StreamOrDevice s) { if (a.ndim() < 1) { throw std::invalid_argument( @@ -91,11 +111,16 @@ array fft_impl( auto in_type = real && !inverse ? float32 : complex64; auto out_type = real && inverse ? float32 : complex64; - return array( + auto out = array( out_shape, out_type, std::make_shared(to_stream(s), valid_axes, inverse, real), {astype(in, in_type, s)}); + auto scale = fft_scale_factor(n, norm, inverse); + if (scale != 1.0) { + return multiply(out, array(scale, out.dtype()), s); + } + return out; } array fft_impl( @@ -103,6 +128,7 @@ array fft_impl( const std::vector& axes, bool real, bool inverse, + FFTNorm norm, StreamOrDevice s) { Shape n; for (auto ax : axes) { @@ -111,82 +137,107 @@ array fft_impl( if (real && inverse && a.ndim() > 0) { n.back() = (n.back() - 1) * 2; } - return fft_impl(a, n, axes, real, inverse, s); + return fft_impl(a, n, axes, real, inverse, norm, s); } -array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) { +array fft_impl( + const array& a, + bool real, + bool inverse, + FFTNorm norm, + StreamOrDevice s) { std::vector axes(a.ndim()); std::iota(axes.begin(), axes.end(), 0); - return fft_impl(a, axes, real, inverse, s); + return fft_impl(a, axes, real, inverse, norm, s); } array fftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, n, axes, false, false, s); + return fft_impl(a, n, axes, false, false, norm, s); } array fftn( const array& a, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, axes, false, false, s); + return fft_impl(a, axes, false, false, norm, s); } -array fftn(const array& a, StreamOrDevice s /* = {} */) { - return fft_impl(a, false, false, s); +array fftn( + const array& a, + FFTNorm norm /* = FFTNorm::Backward */, + StreamOrDevice s /* = {} */) { + return fft_impl(a, false, false, norm, s); } array ifftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, n, axes, false, true, s); + return fft_impl(a, n, axes, false, true, norm, s); } array ifftn( const array& a, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, axes, false, true, s); + return fft_impl(a, axes, false, true, norm, s); } -array ifftn(const array& a, StreamOrDevice s /* = {} */) { - return fft_impl(a, false, true, s); +array ifftn( + const array& a, + FFTNorm norm /* = FFTNorm::Backward */, + StreamOrDevice s /* = {} */) { + return fft_impl(a, false, true, norm, s); } array rfftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, n, axes, true, false, s); + return fft_impl(a, n, axes, true, false, norm, s); } array rfftn( const array& a, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, axes, true, false, s); + return fft_impl(a, axes, true, false, norm, s); } -array rfftn(const array& a, StreamOrDevice s /* = {} */) { - return fft_impl(a, true, false, s); +array rfftn( + const array& a, + FFTNorm norm /* = FFTNorm::Backward */, + StreamOrDevice s /* = {} */) { + return fft_impl(a, true, false, norm, s); } array irfftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, n, axes, true, true, s); + return fft_impl(a, n, axes, true, true, norm, s); } array irfftn( const array& a, const std::vector& axes, + FFTNorm norm /* = FFTNorm::Backward */, StreamOrDevice s /* = {} */) { - return fft_impl(a, axes, true, true, s); + return fft_impl(a, axes, true, true, norm, s); } -array irfftn(const array& a, StreamOrDevice s /* = {} */) { - return fft_impl(a, true, true, s); +array irfftn( + const array& a, + FFTNorm norm /* = FFTNorm::Backward */, + StreamOrDevice s /* = {} */) { + return fft_impl(a, true, true, norm, s); } array fftshift( diff --git a/mlx/fft.h b/mlx/fft.h index 9abf2b18c5..a36af78de9 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include "array.h" @@ -11,40 +12,74 @@ namespace mlx::core::fft { +enum class FFTNorm { + Backward, + Ortho, + Forward, +}; + /** Compute the n-dimensional Fourier Transform. */ MLX_API array fftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}); +MLX_API array fftn( + const array& a, + const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); MLX_API array -fftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); -MLX_API array fftn(const array& a, StreamOrDevice s = {}); +fftn(const array& a, FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); /** Compute the n-dimensional inverse Fourier Transform. */ MLX_API array ifftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}); +MLX_API array ifftn( + const array& a, + const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); MLX_API array -ifftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); -MLX_API array ifftn(const array& a, StreamOrDevice s = {}); +ifftn(const array& a, FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); /** Compute the one-dimensional Fourier Transform. */ -inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) { - return fftn(a, {n}, {axis}, s); +inline array fft( + const array& a, + int n, + int axis, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return fftn(a, {n}, {axis}, norm, s); } -inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) { - return fftn(a, {axis}, s); +inline array fft( + const array& a, + int axis = -1, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return fftn(a, {axis}, norm, s); } /** Compute the one-dimensional inverse Fourier Transform. */ -inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) { - return ifftn(a, {n}, {axis}, s); +inline array ifft( + const array& a, + int n, + int axis, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return ifftn(a, {n}, {axis}, norm, s); } -inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) { - return ifftn(a, {axis}, s); +inline array ifft( + const array& a, + int axis = -1, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return ifftn(a, {axis}, norm, s); } /** Compute the two-dimensional Fourier Transform. */ @@ -52,14 +87,16 @@ inline array fft2( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return fftn(a, n, axes, s); + return fftn(a, n, axes, norm, s); } inline array fft2( const array& a, const std::vector& axes = {-2, -1}, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return fftn(a, axes, s); + return fftn(a, axes, norm, s); } /** Compute the two-dimensional inverse Fourier Transform. */ @@ -67,14 +104,16 @@ inline array ifft2( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return ifftn(a, n, axes, s); + return ifftn(a, n, axes, norm, s); } inline array ifft2( const array& a, const std::vector& axes = {-2, -1}, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return ifftn(a, axes, s); + return ifftn(a, axes, norm, s); } /** Compute the n-dimensional Fourier Transform on a real input. */ @@ -82,34 +121,62 @@ MLX_API array rfftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}); +MLX_API array rfftn( + const array& a, + const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); MLX_API array -rfftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); -MLX_API array rfftn(const array& a, StreamOrDevice s = {}); +rfftn(const array& a, FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); /** Compute the n-dimensional inverse of `rfftn`. */ MLX_API array irfftn( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}); +MLX_API array irfftn( + const array& a, + const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); MLX_API array -irfftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); -MLX_API array irfftn(const array& a, StreamOrDevice s = {}); +irfftn(const array& a, FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}); /** Compute the one-dimensional Fourier Transform on a real input. */ -inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) { - return rfftn(a, {n}, {axis}, s); +inline array rfft( + const array& a, + int n, + int axis, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return rfftn(a, {n}, {axis}, norm, s); } -inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) { - return rfftn(a, {axis}, s); +inline array rfft( + const array& a, + int axis = -1, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return rfftn(a, {axis}, norm, s); } /** Compute the one-dimensional inverse of `rfft`. */ -inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) { - return irfftn(a, {n}, {axis}, s); +inline array irfft( + const array& a, + int n, + int axis, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return irfftn(a, {n}, {axis}, norm, s); } -inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) { - return irfftn(a, {axis}, s); +inline array irfft( + const array& a, + int axis = -1, + FFTNorm norm = FFTNorm::Backward, + StreamOrDevice s = {}) { + return irfftn(a, {axis}, norm, s); } /** Compute the two-dimensional Fourier Transform on a real input. */ @@ -117,14 +184,16 @@ inline array rfft2( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return rfftn(a, n, axes, s); + return rfftn(a, n, axes, norm, s); } inline array rfft2( const array& a, const std::vector& axes = {-2, -1}, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return rfftn(a, axes, s); + return rfftn(a, axes, norm, s); } /** Compute the two-dimensional inverse of `rfft2`. */ @@ -132,14 +201,16 @@ inline array irfft2( const array& a, const Shape& n, const std::vector& axes, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return irfftn(a, n, axes, s); + return irfftn(a, n, axes, norm, s); } inline array irfft2( const array& a, const std::vector& axes = {-2, -1}, + FFTNorm norm = FFTNorm::Backward, StreamOrDevice s = {}) { - return irfftn(a, axes, s); + return irfftn(a, axes, norm, s); } /** Shift the zero-frequency component to the center of the spectrum. */ MLX_API array fftshift(const array& a, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fa9e55700e..4e92ff634c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2210,8 +2210,10 @@ std::vector FFT::vjp( two, one, stream()); - return { - multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; + return {multiply( + fft::rfftn(cotangents[0], axes, fft::FFTNorm::Backward, stream()), + mask, + stream())}; } else if (real_) { Shape n; for (auto ax : axes_) { @@ -2237,17 +2239,22 @@ std::vector FFT::vjp( one, stream()); return {multiply( - fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), + fft::irfftn( + multiply(cotangents[0], mask, stream()), + n, + axes, + fft::FFTNorm::Backward, + stream()), array(n_elements, in.dtype()), stream())}; } else if (inverse_) { return {multiply( - fft::fftn(cotangents[0], axes, stream()), + fft::fftn(cotangents[0], axes, fft::FFTNorm::Backward, stream()), array(1 / n_elements, complex64), stream())}; } else { return {multiply( - fft::ifftn(cotangents[0], axes, stream()), + fft::ifftn(cotangents[0], axes, fft::FFTNorm::Backward, stream()), array(n_elements, complex64), stream())}; } @@ -2261,13 +2268,13 @@ std::vector FFT::jvp( assert(argnums.size() == 1); auto& tan = tangents[0]; if (real_ & inverse_) { - return {fft::irfftn(tan, stream())}; + return {fft::irfftn(tan, fft::FFTNorm::Backward, stream())}; } else if (real_) { - return {fft::rfftn(tan, stream())}; + return {fft::rfftn(tan, fft::FFTNorm::Backward, stream())}; } else if (inverse_) { - return {fft::ifftn(tan, stream())}; + return {fft::ifftn(tan, fft::FFTNorm::Backward, stream())}; } else { - return {fft::fftn(tan, stream())}; + return {fft::fftn(tan, fft::FFTNorm::Backward, stream())}; } } diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 522e1064c7..8333c3f1f4 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -2,9 +2,11 @@ #include #include +#include #include #include #include +#include #include "mlx/fft.h" #include "mlx/ops.h" @@ -14,6 +16,25 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; +namespace { + +mx::fft::FFTNorm parse_norm(std::string_view norm, std::string_view op) { + if (norm == "backward") { + return mx::fft::FFTNorm::Backward; + } + if (norm == "ortho") { + return mx::fft::FFTNorm::Ortho; + } + if (norm == "forward") { + return mx::fft::FFTNorm::Forward; + } + throw std::invalid_argument( + std::string("[") + std::string(op) + + "] Invalid norm. Expected one of {'backward', 'ortho', 'forward'}."); +} + +} // namespace + void init_fft(nb::module_& parent_module) { auto m = parent_module.def_submodule( "fft", "mlx.core.fft: Fast Fourier Transforms."); @@ -22,16 +43,19 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, int axis, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "fft"); if (n.has_value()) { - return mx::fft::fft(a, n.value(), axis, s); + return mx::fft::fft(a, n.value(), axis, fft_norm, s); } else { - return mx::fft::fft(a, axis, s); + return mx::fft::fft(a, axis, fft_norm, s); } }, "a"_a, "n"_a = nb::none(), "axis"_a = -1, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( One dimensional discrete Fourier Transform. @@ -43,6 +67,8 @@ void init_fft(nb::module_& parent_module) { zeros to match ``n``. The default value is ``a.shape[axis]``. axis (int, optional): Axis along which to perform the FFT. The default is ``-1``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The DFT of the input along the given axis. @@ -52,16 +78,19 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, int axis, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "ifft"); if (n.has_value()) { - return mx::fft::ifft(a, n.value(), axis, s); + return mx::fft::ifft(a, n.value(), axis, fft_norm, s); } else { - return mx::fft::ifft(a, axis, s); + return mx::fft::ifft(a, axis, fft_norm, s); } }, "a"_a, "n"_a = nb::none(), "axis"_a = -1, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( One dimensional inverse discrete Fourier Transform. @@ -73,6 +102,8 @@ void init_fft(nb::module_& parent_module) { zeros to match ``n``. The default value is ``a.shape[axis]``. axis (int, optional): Axis along which to perform the FFT. The default is ``-1``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The inverse DFT of the input along the given axis. @@ -82,21 +113,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "fft2"); if (axes.has_value() && n.has_value()) { - return mx::fft::fftn(a, n.value(), axes.value(), s); + return mx::fft::fftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::fftn(a, axes.value(), s); + return mx::fft::fftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[fft2] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::fftn(a, s); + return mx::fft::fftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a.none() = std::vector{-2, -1}, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( Two dimensional discrete Fourier Transform. @@ -109,6 +143,8 @@ void init_fft(nb::module_& parent_module) { sizes of ``a`` along ``axes``. axes (list(int), optional): Axes along which to perform the FFT. The default is ``[-2, -1]``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The DFT of the input along the given axes. @@ -118,21 +154,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "ifft2"); if (axes.has_value() && n.has_value()) { - return mx::fft::ifftn(a, n.value(), axes.value(), s); + return mx::fft::ifftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::ifftn(a, axes.value(), s); + return mx::fft::ifftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[ifft2] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::ifftn(a, s); + return mx::fft::ifftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a.none() = std::vector{-2, -1}, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( Two dimensional inverse discrete Fourier Transform. @@ -145,6 +184,8 @@ void init_fft(nb::module_& parent_module) { sizes of ``a`` along ``axes``. axes (list(int), optional): Axes along which to perform the FFT. The default is ``[-2, -1]``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The inverse DFT of the input along the given axes. @@ -154,21 +195,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "fftn"); if (axes.has_value() && n.has_value()) { - return mx::fft::fftn(a, n.value(), axes.value(), s); + return mx::fft::fftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::fftn(a, axes.value(), s); + return mx::fft::fftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[fftn] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::fftn(a, s); + return mx::fft::fftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a = nb::none(), + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( n-dimensional discrete Fourier Transform. @@ -182,6 +226,8 @@ void init_fft(nb::module_& parent_module) { axes (list(int), optional): Axes along which to perform the FFT. The default is ``None`` in which case the FFT is over the last ``len(s)`` axes are or all axes if ``s`` is also ``None``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The DFT of the input along the given axes. @@ -191,21 +237,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "ifftn"); if (axes.has_value() && n.has_value()) { - return mx::fft::ifftn(a, n.value(), axes.value(), s); + return mx::fft::ifftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::ifftn(a, axes.value(), s); + return mx::fft::ifftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[ifftn] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::ifftn(a, s); + return mx::fft::ifftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a = nb::none(), + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( n-dimensional inverse discrete Fourier Transform. @@ -219,6 +268,8 @@ void init_fft(nb::module_& parent_module) { axes (list(int), optional): Axes along which to perform the FFT. The default is ``None`` in which case the FFT is over the last ``len(s)`` axes or all axes if ``s`` is also ``None``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The inverse DFT of the input along the given axes. @@ -228,16 +279,19 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, int axis, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "rfft"); if (n.has_value()) { - return mx::fft::rfft(a, n.value(), axis, s); + return mx::fft::rfft(a, n.value(), axis, fft_norm, s); } else { - return mx::fft::rfft(a, axis, s); + return mx::fft::rfft(a, axis, fft_norm, s); } }, "a"_a, "n"_a = nb::none(), "axis"_a = -1, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( One dimensional discrete Fourier Transform on a real input. @@ -253,6 +307,8 @@ void init_fft(nb::module_& parent_module) { zeros to match ``n``. The default value is ``a.shape[axis]``. axis (int, optional): Axis along which to perform the FFT. The default is ``-1``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The DFT of the input along the given axis. The output @@ -263,16 +319,19 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, int axis, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "irfft"); if (n.has_value()) { - return mx::fft::irfft(a, n.value(), axis, s); + return mx::fft::irfft(a, n.value(), axis, fft_norm, s); } else { - return mx::fft::irfft(a, axis, s); + return mx::fft::irfft(a, axis, fft_norm, s); } }, "a"_a, "n"_a = nb::none(), "axis"_a = -1, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( The inverse of :func:`rfft`. @@ -288,6 +347,8 @@ void init_fft(nb::module_& parent_module) { ``a.shape[axis] // 2 + 1``. axis (int, optional): Axis along which to perform the FFT. The default is ``-1``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The real array containing the inverse of :func:`rfft`. @@ -297,21 +358,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "rfft2"); if (axes.has_value() && n.has_value()) { - return mx::fft::rfftn(a, n.value(), axes.value(), s); + return mx::fft::rfftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::rfftn(a, axes.value(), s); + return mx::fft::rfftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[rfft2] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::rfftn(a, s); + return mx::fft::rfftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a.none() = std::vector{-2, -1}, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( Two dimensional real discrete Fourier Transform. @@ -329,6 +393,8 @@ void init_fft(nb::module_& parent_module) { sizes of ``a`` along ``axes``. axes (list(int), optional): Axes along which to perform the FFT. The default is ``[-2, -1]``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The real DFT of the input along the given axes. The output @@ -339,21 +405,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "irfft2"); if (axes.has_value() && n.has_value()) { - return mx::fft::irfftn(a, n.value(), axes.value(), s); + return mx::fft::irfftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::irfftn(a, axes.value(), s); + return mx::fft::irfftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[irfft2] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::irfftn(a, s); + return mx::fft::irfftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a.none() = std::vector{-2, -1}, + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( The inverse of :func:`rfft2`. @@ -372,6 +441,8 @@ void init_fft(nb::module_& parent_module) { sizes of ``a`` along ``axes``. axes (list(int), optional): Axes along which to perform the FFT. The default is ``[-2, -1]``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The real array containing the inverse of :func:`rfft2`. @@ -381,21 +452,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "rfftn"); if (axes.has_value() && n.has_value()) { - return mx::fft::rfftn(a, n.value(), axes.value(), s); + return mx::fft::rfftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::rfftn(a, axes.value(), s); + return mx::fft::rfftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[rfftn] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::rfftn(a, s); + return mx::fft::rfftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a = nb::none(), + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( n-dimensional real discrete Fourier Transform. @@ -414,6 +488,8 @@ void init_fft(nb::module_& parent_module) { axes (list(int), optional): Axes along which to perform the FFT. The default is ``None`` in which case the FFT is over the last ``len(s)`` axes or all axes if ``s`` is also ``None``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The real DFT of the input along the given axes. The output @@ -423,21 +499,24 @@ void init_fft(nb::module_& parent_module) { [](const mx::array& a, const std::optional& n, const std::optional>& axes, + const std::string& norm, mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "irfftn"); if (axes.has_value() && n.has_value()) { - return mx::fft::irfftn(a, n.value(), axes.value(), s); + return mx::fft::irfftn(a, n.value(), axes.value(), fft_norm, s); } else if (axes.has_value()) { - return mx::fft::irfftn(a, axes.value(), s); + return mx::fft::irfftn(a, axes.value(), fft_norm, s); } else if (n.has_value()) { throw std::invalid_argument( "[irfftn] `axes` should not be `None` if `s` is not `None`."); } else { - return mx::fft::irfftn(a, s); + return mx::fft::irfftn(a, fft_norm, s); } }, "a"_a, "s"_a = nb::none(), "axes"_a = nb::none(), + "norm"_a = "backward", "stream"_a = nb::none(), R"pbdoc( The inverse of :func:`rfftn`. @@ -456,6 +535,8 @@ void init_fft(nb::module_& parent_module) { axes (list(int), optional): Axes along which to perform the FFT. The default is ``None`` in which case the FFT is over the last ``len(s)`` axes or all axes if ``s`` is also ``None``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. Returns: array: The real array containing the inverse of :func:`rfftn`. diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 9921c7d985..23aaf27af0 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -102,6 +102,33 @@ def test_fftn(self): irfft_in = np.ascontiguousarray(np.fft.rfftn(rt, axes=(2, 0))) self.check_mx_np(mx.fft.irfftn, np.fft.irfftn, irfft_in, axes=(2, 0)) + def test_fft_norm(self): + norms = ["backward", "ortho", "forward"] + + r = np.random.randn(8, 6).astype(np.float32) + i = np.random.randn(8, 6).astype(np.float32) + c = r + 1j * i + + for norm in norms: + self.check_mx_np(mx.fft.fft, np.fft.fft, c, axis=1, norm=norm) + self.check_mx_np(mx.fft.ifft, np.fft.ifft, c, axis=1, norm=norm) + self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, axis=1, norm=norm) + + cr = np.fft.rfft(r, axis=1) + self.check_mx_np(mx.fft.irfft, np.fft.irfft, cr, axis=1, norm=norm) + + self.check_mx_np(mx.fft.fft2, np.fft.fft2, c, axes=(0, 1), norm=norm) + self.check_mx_np(mx.fft.ifft2, np.fft.ifft2, c, axes=(0, 1), norm=norm) + self.check_mx_np(mx.fft.fftn, np.fft.fftn, c, axes=(0, 1), norm=norm) + self.check_mx_np(mx.fft.ifftn, np.fft.ifftn, c, axes=(0, 1), norm=norm) + + self.check_mx_np(mx.fft.rfft2, np.fft.rfft2, r, axes=(0, 1), norm=norm) + self.check_mx_np(mx.fft.rfftn, np.fft.rfftn, r, axes=(0, 1), norm=norm) + + cr2 = np.fft.rfft2(r, axes=(0, 1)) + self.check_mx_np(mx.fft.irfft2, np.fft.irfft2, cr2, axes=(0, 1), norm=norm) + self.check_mx_np(mx.fft.irfftn, np.fft.irfftn, cr2, axes=(0, 1), norm=norm) + def _run_ffts(self, shape, atol=1e-4, rtol=1e-4): np.random.seed(9) @@ -216,6 +243,8 @@ def test_fft_throws(self): x = mx.array(3.0) with self.assertRaises(ValueError): mx.fft.irfftn(x) + with self.assertRaises(ValueError): + mx.fft.fft(mx.array([1.0]), norm="invalid") def test_fftshift(self): # Test 1D arrays diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 39ed80bdeb..7dc1b6e666 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -44,6 +44,11 @@ TEST_CASE("test fft basics") { CHECK_EQ(y.size(), x.size()); CHECK(array_equal(y, array(expected)).item()); + auto y_ortho = fft::fft(x, -1, fft::FFTNorm::Ortho); + auto y_forward = fft::fft(x, -1, fft::FFTNorm::Forward); + CHECK(allclose(y_ortho, y * array(0.5f), 1e-6, 1e-6).item()); + CHECK(allclose(y_forward, y * array(0.25f), 1e-6, 1e-6).item()); + y = fft::ifft(x); std::initializer_list expected_inv = { {1.5, 0.0}, @@ -52,6 +57,12 @@ TEST_CASE("test fft basics") { {-0.5, 0.5}, }; CHECK(array_equal(y, array(expected_inv)).item()); + + auto yi = fft::ifft(x); + auto yi_ortho = fft::ifft(x, -1, fft::FFTNorm::Ortho); + auto yi_forward = fft::ifft(x, -1, fft::FFTNorm::Forward); + CHECK(allclose(yi_ortho, yi * array(2.0f), 1e-6, 1e-6).item()); + CHECK(allclose(yi_forward, yi * array(4.0f), 1e-6, 1e-6).item()); } { @@ -126,8 +137,11 @@ TEST_CASE("test fftn") { CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument); CHECK_THROWS_AS(fft::fftn(x, {}, {0, 0}), std::invalid_argument); CHECK_THROWS_AS(fft::fftn(x, {5, 5, 5}, {0}), std::invalid_argument); - CHECK_THROWS_AS(fft::fftn(x, {0}, {}, {}), std::invalid_argument); - CHECK_THROWS_AS(fft::fftn(x, {1, -1}, {}, {}), std::invalid_argument); + CHECK_THROWS_AS( + fft::fftn(x, {0}, {}, fft::FFTNorm::Backward, {}), std::invalid_argument); + CHECK_THROWS_AS( + fft::fftn(x, {1, -1}, {}, fft::FFTNorm::Backward, {}), + std::invalid_argument); // Test 2D FFT {