diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 6aabfc6954..b9eac968b8 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -31,6 +31,8 @@ Operations atleast_1d atleast_2d atleast_3d + bessel_i0e + bessel_i1e bitwise_and bitwise_invert bitwise_or diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h index f9fc8317a5..d4eb363f91 100644 --- a/mlx/backend/cpu/simd/math.h +++ b/mlx/backend/cpu/simd/math.h @@ -190,4 +190,137 @@ Simd erfinv(Simd a_) { } } +// Clenshaw recurrence for Chebyshev series evaluation. +template +float chbevl_scalar(float x, const float (&coeffs)[M]) { + float b0 = coeffs[0]; + float b1 = 0.0f; + float b2; + for (int i = 1; i < M; i++) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + coeffs[i]; + } + return 0.5f * (b0 - b2); +} + +inline float i0e_scalar(float x) { + // Cephes Chebyshev coefficients for exp(-x) I0(x), interval [0, 8]. + static const float A_i0[30] = { + -4.41534164647933937950e-18f, 3.33079451882223809783e-17f, + -2.43127984654795469359e-16f, 1.71539128555513303061e-15f, + -1.16853328779934516808e-14f, 7.67618549860493561688e-14f, + -4.85644678311192946090e-13f, 2.95505266312963983461e-12f, + -1.72682629144155570723e-11f, 9.67580903537323691224e-11f, + -5.18979560163526290666e-10f, 2.65982372468238665035e-09f, + -1.30002500998624804212e-08f, 6.04699502254191894932e-08f, + -2.67079385394061173391e-07f, 1.11738753912010371815e-06f, + -4.41673835845875056359e-06f, 1.64484480707288970893e-05f, + -5.75419501008210370398e-05f, 1.88502885095841655729e-04f, + -5.76375574538582365885e-04f, 1.63947561694133579842e-03f, + -4.32430999505057594430e-03f, 1.05464603945949983183e-02f, + -2.37374148058994688156e-02f, 4.93052842396707084878e-02f, + -9.49010970480476444210e-02f, 1.71620901522208775349e-01f, + -3.04682672343198398683e-01f, 6.76795274409476084995e-01f, + }; + // Cephes Chebyshev coefficients for exp(-x) sqrt(x) I0(x), interval [8, inf]. + static const float B_i0[25] = { + -7.23318048787475395456e-18f, -4.83050448594418207126e-18f, + 4.46562142029675999901e-17f, 3.46122286769746109310e-17f, + -2.82762398051658348494e-16f, -3.42548561967721913462e-16f, + 1.77256013305652638360e-15f, 3.81168066935262242075e-15f, + -9.55484669882830764870e-15f, -4.15056934728722208663e-14f, + 1.54008621752140982691e-14f, 3.85277838274214270114e-13f, + 7.18012445138366623367e-13f, -1.79417853150680611778e-12f, + -1.32158118404477131188e-11f, -3.14991652796324136454e-11f, + 1.18891471078464383424e-11f, 4.94060238822496958910e-10f, + 3.39623202570838634515e-09f, 2.26666899049817806459e-08f, + 2.04891858946906374183e-07f, 2.89137052083475648297e-06f, + 6.88975834691682398426e-05f, 3.36911647825569408990e-03f, + 8.04490411014108831608e-01f, + }; + + float ax = std::abs(x); + if (ax <= 8.0f) { + float y = ax * 0.5f - 2.0f; + return chbevl_scalar(y, A_i0); + } + return chbevl_scalar(32.0f / ax - 2.0f, B_i0) / std::sqrt(ax); +} + +inline float i1e_scalar(float x) { + // Cephes Chebyshev coefficients for exp(-x) I1(x) / x, interval [0, 8]. + static const float A_i1[29] = { + 2.77791411276104639959e-18f, -2.11142121435816608115e-17f, + 1.55363195773620046921e-16f, -1.10559694773538630805e-15f, + 7.60068429473540693410e-15f, -5.04218550472791168711e-14f, + 3.22379336594557470981e-13f, -1.98397439776494371520e-12f, + 1.17361862988909016308e-11f, -6.66348972350202774223e-11f, + 3.62559028155211703701e-10f, -1.88724975172282928790e-09f, + 9.38153738649577178388e-09f, -4.44505912879632808065e-08f, + 2.00329475355213526229e-07f, -8.56872026469545474066e-07f, + 3.47025130813767847674e-06f, -1.32731636560394358279e-05f, + 4.78156510755005422638e-05f, -1.61760815825896745588e-04f, + 5.12285956168575772895e-04f, -1.51357245063125314899e-03f, + 4.15642294431288815669e-03f, -1.05640848946261981558e-02f, + 2.47264490306265168283e-02f, -5.29459812080949914269e-02f, + 1.02643658689847095384e-01f, -1.76416518357834055153e-01f, + 2.52587186443633654823e-01f, + }; + // Cephes Chebyshev coefficients for exp(-x) sqrt(x) I1(x), interval [8, inf]. + static const float B_i1[25] = { + 7.51729631084210481353e-18f, 4.41434832307170791151e-18f, + -4.65030536848935832153e-17f, -3.20952592199342395980e-17f, + 2.96262899764595013876e-16f, 3.30820231092092828324e-16f, + -1.88035477551078244854e-15f, -3.81440307243700780478e-15f, + 1.04202769841288027642e-14f, 4.27244001671195135429e-14f, + -2.10154184277266431302e-14f, -4.08355111109219731823e-13f, + -7.19855177624590851209e-13f, 2.03562854414708950722e-12f, + 1.41258074366137813316e-11f, 3.25260358301548823856e-11f, + -1.89749581235054123450e-11f, -5.58974346219658380687e-10f, + -3.83538038596423702205e-09f, -2.63146884688951950684e-08f, + -2.51223623787020892529e-07f, -3.88256480887769039346e-06f, + -1.10588938762623716291e-04f, -9.76109749136146840777e-03f, + 7.78576235018280120474e-01f, + }; + + float ax = std::abs(x); + float result; + if (ax <= 8.0f) { + float y = ax * 0.5f - 2.0f; + result = chbevl_scalar(y, A_i1) * ax; + } else { + result = chbevl_scalar(32.0f / ax - 2.0f, B_i1) / std::sqrt(ax); + } + return x < 0.0f ? -result : result; +} + +template +Simd i0e(Simd x) { + Simd result; + for (int i = 0; i < N; i++) { + result[i] = static_cast(i0e_scalar(static_cast(x[i]))); + } + return result; +} + +template +Simd i0e(Simd x) { + return Simd(static_cast(i0e_scalar(static_cast(x.value)))); +} + +template +Simd i1e(Simd x) { + Simd result; + for (int i = 0; i < N; i++) { + result[i] = static_cast(i1e_scalar(static_cast(x[i]))); + } + return result; +} + +template +Simd i1e(Simd x) { + return Simd(static_cast(i1e_scalar(static_cast(x.value)))); +} + } // namespace mlx::core::simd diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp index eafe98866f..142141890c 100644 --- a/mlx/backend/cpu/unary.cpp +++ b/mlx/backend/cpu/unary.cpp @@ -103,6 +103,18 @@ void ErfInv::eval_cpu(const std::vector& inputs, array& out) { unary_real_fp(in, out, detail::ErfInv(), stream()); } +void BesselI0e::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + unary_real_fp(in, out, detail::BesselI0e(), stream()); +} + +void BesselI1e::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + unary_real_fp(in, out, detail::BesselI1e(), stream()); +} + void Exp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h index f441e88bd5..f98a2c2652 100644 --- a/mlx/backend/cpu/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -41,6 +41,8 @@ DEFAULT_OP(Cos, cos) DEFAULT_OP(Cosh, cosh) DEFAULT_OP(Erf, erf) DEFAULT_OP(ErfInv, erfinv) +DEFAULT_OP(BesselI0e, i0e) +DEFAULT_OP(BesselI1e, i1e) DEFAULT_OP(Exp, exp) DEFAULT_OP(Expm1, expm1) DEFAULT_OP(Floor, floor); diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index 8234cf8ecc..63f2cfade6 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -131,6 +131,129 @@ struct ErfInv { } }; +struct BesselI0e { + template + __device__ T operator()(T x) { + float v = static_cast(x); + float ax = v < 0.0f ? -v : v; + // Cephes Chebyshev coefficients for i0e, |x| <= 8 + float A[30] = { + -4.41534164647933937950e-18f, 3.33079451882223809783e-17f, + -2.43127984654795469359e-16f, 1.71539128555513303061e-15f, + -1.16853328779934516808e-14f, 7.67618549860493561688e-14f, + -4.85644678311192946090e-13f, 2.95505266312963983461e-12f, + -1.72682629144155570723e-11f, 9.67580903537323691224e-11f, + -5.18979560163526290666e-10f, 2.65982372468238665035e-09f, + -1.30002500998624804212e-08f, 6.04699502254191894932e-08f, + -2.67079385394061173391e-07f, 1.11738753912010371815e-06f, + -4.41673835845875056359e-06f, 1.64484480707288970893e-05f, + -5.75419501008210370398e-05f, 1.88502885095841655729e-04f, + -5.76375574538582365885e-04f, 1.63947561694133579842e-03f, + -4.32430999505057594430e-03f, 1.05464603945949983183e-02f, + -2.37374148058994688156e-02f, 4.93052842396707084878e-02f, + -9.49010970480476444210e-02f, 1.71620901522208775349e-01f, + -3.04682672343198398683e-01f, 6.76795274409476084995e-01f, + }; + // Cephes Chebyshev coefficients for i0e, |x| > 8 + float B[25] = { + -7.23318048787475395456e-18f, -4.83050448594418207126e-18f, + 4.46562142029675999901e-17f, 3.46122286769746109310e-17f, + -2.82762398051658348494e-16f, -3.42548561967721913462e-16f, + 1.77256013305652638360e-15f, 3.81168066935262242075e-15f, + -9.55484669882830764870e-15f, -4.15056934728722208663e-14f, + 1.54008621752140982691e-14f, 3.85277838274214270114e-13f, + 7.18012445138366623367e-13f, -1.79417853150680611778e-12f, + -1.32158118404477131188e-11f, -3.14991652796324136454e-11f, + 1.18891471078464383424e-11f, 4.94060238822496958910e-10f, + 3.39623202570838634515e-09f, 2.26666899049817806459e-08f, + 2.04891858946906374183e-07f, 2.89137052083475648297e-06f, + 6.88975834691682398426e-05f, 3.36911647825569408990e-03f, + 8.04490411014108831608e-01f, + }; + // Clenshaw recurrence + auto chbevl = [](float x, float* coeffs, int n) { + float b0 = coeffs[0]; + float b1 = 0.0f; + float b2; + for (int i = 1; i < n; i++) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + coeffs[i]; + } + return 0.5f * (b0 - b2); + }; + float result; + if (ax <= 8.0f) { + result = chbevl(ax * 0.5f - 2.0f, A, 30); + } else { + result = chbevl(32.0f / ax - 2.0f, B, 25) / ::sqrtf(ax); + } + return static_cast(result); + } +}; + +struct BesselI1e { + template + __device__ T operator()(T x) { + float v = static_cast(x); + float ax = v < 0.0f ? -v : v; + // Cephes Chebyshev coefficients for i1e, |x| <= 8 + float A[29] = { + 2.77791411276104639959e-18f, -2.11142121435816608115e-17f, + 1.55363195773620046921e-16f, -1.10559694773538630805e-15f, + 7.60068429473540693410e-15f, -5.04218550472791168711e-14f, + 3.22379336594557470981e-13f, -1.98397439776494371520e-12f, + 1.17361862988909016308e-11f, -6.66348972350202774223e-11f, + 3.62559028155211703701e-10f, -1.88724975172282928790e-09f, + 9.38153738649577178388e-09f, -4.44505912879632808065e-08f, + 2.00329475355213526229e-07f, -8.56872026469545474066e-07f, + 3.47025130813767847674e-06f, -1.32731636560394358279e-05f, + 4.78156510755005422638e-05f, -1.61760815825896745588e-04f, + 5.12285956168575772895e-04f, -1.51357245063125314899e-03f, + 4.15642294431288815669e-03f, -1.05640848946261981558e-02f, + 2.47264490306265168283e-02f, -5.29459812080949914269e-02f, + 1.02643658689847095384e-01f, -1.76416518357834055153e-01f, + 2.52587186443633654823e-01f, + }; + // Cephes Chebyshev coefficients for i1e, |x| > 8 + float B[25] = { + 7.51729631084210481353e-18f, 4.41434832307170791151e-18f, + -4.65030536848935832153e-17f, -3.20952592199342395980e-17f, + 2.96262899764595013876e-16f, 3.30820231092092828324e-16f, + -1.88035477551078244854e-15f, -3.81440307243700780478e-15f, + 1.04202769841288027642e-14f, 4.27244001671195135429e-14f, + -2.10154184277266431302e-14f, -4.08355111109219731823e-13f, + -7.19855177624590851209e-13f, 2.03562854414708950722e-12f, + 1.41258074366137813316e-11f, 3.25260358301548823856e-11f, + -1.89749581235054123450e-11f, -5.58974346219658380687e-10f, + -3.83538038596423702205e-09f, -2.63146884688951950684e-08f, + -2.51223623787020892529e-07f, -3.88256480887769039346e-06f, + -1.10588938762623716291e-04f, -9.76109749136146840777e-03f, + 7.78576235018280120474e-01f, + }; + auto chbevl = [](float x, float* coeffs, int n) { + float b0 = coeffs[0]; + float b1 = 0.0f; + float b2; + for (int i = 1; i < n; i++) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + coeffs[i]; + } + return 0.5f * (b0 - b2); + }; + float result; + if (ax <= 8.0f) { + result = chbevl(ax * 0.5f - 2.0f, A, 29) * ax; + } else { + result = chbevl(32.0f / ax - 2.0f, B, 25) / ::sqrtf(ax); + } + if (v < 0.0f) + result = -result; + return static_cast(result); + } +}; + struct Exp { template __device__ T operator()(T x) { diff --git a/mlx/backend/cuda/unary/CMakeLists.txt b/mlx/backend/cuda/unary/CMakeLists.txt index 532c5645ed..4e36db6fe9 100644 --- a/mlx/backend/cuda/unary/CMakeLists.txt +++ b/mlx/backend/cuda/unary/CMakeLists.txt @@ -14,6 +14,8 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bessel_i0e.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bessel_i1e.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu diff --git a/mlx/backend/cuda/unary/bessel_i0e.cu b/mlx/backend/cuda/unary/bessel_i0e.cu new file mode 100644 index 0000000000..92acb46ddc --- /dev/null +++ b/mlx/backend/cuda/unary/bessel_i0e.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(BesselI0e) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/bessel_i1e.cu b/mlx/backend/cuda/unary/bessel_i1e.cu new file mode 100644 index 0000000000..b7b1fd5826 --- /dev/null +++ b/mlx/backend/cuda/unary/bessel_i1e.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(BesselI1e) +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels/bessel.h b/mlx/backend/metal/kernels/bessel.h new file mode 100644 index 0000000000..39dfd39af5 --- /dev/null +++ b/mlx/backend/metal/kernels/bessel.h @@ -0,0 +1,132 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +/* + * Exponentially-scaled modified Bessel functions of the first kind, + * orders 0 and 1: i0e(x) = exp(-|x|) * I_0(x), i1e(x) = exp(-|x|) * I_1(x). + * + * Uses Cephes Chebyshev polynomial approximation (public domain, S. Moshier). + * Same algorithm used by SciPy, JAX/XLA, PyTorch, and TensorFlow. + * + * Accuracy: float32 relative error < 2e-7 across full range. + */ + +namespace { + +// Clenshaw recurrence for evaluating a Chebyshev series. +// Evaluates sum of coefficients[k] * T_k(x) using the recurrence: +// b_{n+1} = b_{n+2} = 0 +// b_k = x * b_{k+1} - b_{k+2} + coefficients[k] +// result = 0.5 * (b_0 - b_2) +template +float chbevl(float x, constant const float (&coeffs)[N]) { + float b0 = coeffs[0]; + float b1 = 0.0f; + float b2; + for (int i = 1; i < N; i++) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + coeffs[i]; + } + return 0.5f * (b0 - b2); +} + +// Chebyshev coefficients for exp(-x) I0(x) in the interval [0, 8]. +// lim(x->0){ exp(-x) I0(x) } = 1. +constant float A_i0[30] = { + -4.41534164647933937950e-18f, 3.33079451882223809783e-17f, + -2.43127984654795469359e-16f, 1.71539128555513303061e-15f, + -1.16853328779934516808e-14f, 7.67618549860493561688e-14f, + -4.85644678311192946090e-13f, 2.95505266312963983461e-12f, + -1.72682629144155570723e-11f, 9.67580903537323691224e-11f, + -5.18979560163526290666e-10f, 2.65982372468238665035e-09f, + -1.30002500998624804212e-08f, 6.04699502254191894932e-08f, + -2.67079385394061173391e-07f, 1.11738753912010371815e-06f, + -4.41673835845875056359e-06f, 1.64484480707288970893e-05f, + -5.75419501008210370398e-05f, 1.88502885095841655729e-04f, + -5.76375574538582365885e-04f, 1.63947561694133579842e-03f, + -4.32430999505057594430e-03f, 1.05464603945949983183e-02f, + -2.37374148058994688156e-02f, 4.93052842396707084878e-02f, + -9.49010970480476444210e-02f, 1.71620901522208775349e-01f, + -3.04682672343198398683e-01f, 6.76795274409476084995e-01f, +}; + +// Chebyshev coefficients for exp(-x) sqrt(x) I0(x) in [8, infinity]. +// lim(x->inf){ sqrt(x) exp(-x) I0(x) } = 1/sqrt(2 pi). +constant float B_i0[25] = { + -7.23318048787475395456e-18f, -4.83050448594418207126e-18f, + 4.46562142029675999901e-17f, 3.46122286769746109310e-17f, + -2.82762398051658348494e-16f, -3.42548561967721913462e-16f, + 1.77256013305652638360e-15f, 3.81168066935262242075e-15f, + -9.55484669882830764870e-15f, -4.15056934728722208663e-14f, + 1.54008621752140982691e-14f, 3.85277838274214270114e-13f, + 7.18012445138366623367e-13f, -1.79417853150680611778e-12f, + -1.32158118404477131188e-11f, -3.14991652796324136454e-11f, + 1.18891471078464383424e-11f, 4.94060238822496958910e-10f, + 3.39623202570838634515e-09f, 2.26666899049817806459e-08f, + 2.04891858946906374183e-07f, 2.89137052083475648297e-06f, + 6.88975834691682398426e-05f, 3.36911647825569408990e-03f, + 8.04490411014108831608e-01f, +}; + +// Chebyshev coefficients for exp(-x) I1(x) / x in the interval [0, 8]. +constant float A_i1[29] = { + 2.77791411276104639959e-18f, -2.11142121435816608115e-17f, + 1.55363195773620046921e-16f, -1.10559694773538630805e-15f, + 7.60068429473540693410e-15f, -5.04218550472791168711e-14f, + 3.22379336594557470981e-13f, -1.98397439776494371520e-12f, + 1.17361862988909016308e-11f, -6.66348972350202774223e-11f, + 3.62559028155211703701e-10f, -1.88724975172282928790e-09f, + 9.38153738649577178388e-09f, -4.44505912879632808065e-08f, + 2.00329475355213526229e-07f, -8.56872026469545474066e-07f, + 3.47025130813767847674e-06f, -1.32731636560394358279e-05f, + 4.78156510755005422638e-05f, -1.61760815825896745588e-04f, + 5.12285956168575772895e-04f, -1.51357245063125314899e-03f, + 4.15642294431288815669e-03f, -1.05640848946261981558e-02f, + 2.47264490306265168283e-02f, -5.29459812080949914269e-02f, + 1.02643658689847095384e-01f, -1.76416518357834055153e-01f, + 2.52587186443633654823e-01f, +}; + +// Chebyshev coefficients for exp(-x) sqrt(x) I1(x) in [8, infinity]. +constant float B_i1[25] = { + 7.51729631084210481353e-18f, 4.41434832307170791151e-18f, + -4.65030536848935832153e-17f, -3.20952592199342395980e-17f, + 2.96262899764595013876e-16f, 3.30820231092092828324e-16f, + -1.88035477551078244854e-15f, -3.81440307243700780478e-15f, + 1.04202769841288027642e-14f, 4.27244001671195135429e-14f, + -2.10154184277266431302e-14f, -4.08355111109219731823e-13f, + -7.19855177624590851209e-13f, 2.03562854414708950722e-12f, + 1.41258074366137813316e-11f, 3.25260358301548823856e-11f, + -1.89749581235054123450e-11f, -5.58974346219658380687e-10f, + -3.83538038596423702205e-09f, -2.63146884688951950684e-08f, + -2.51223623787020892529e-07f, -3.88256480887769039346e-06f, + -1.10588938762623716291e-04f, -9.76109749136146840777e-03f, + 7.78576235018280120474e-01f, +}; + +} // anonymous namespace + +float i0e_impl(float x) { + float ax = metal::abs(x); + if (ax <= 8.0f) { + float y = ax * 0.5f - 2.0f; + return chbevl(y, A_i0); + } + return chbevl(32.0f / ax - 2.0f, B_i0) / metal::sqrt(ax); +} + +float i1e_impl(float x) { + float ax = metal::abs(x); + float result; + if (ax <= 8.0f) { + float y = ax * 0.5f - 2.0f; + result = chbevl(y, A_i1) * ax; + } else { + result = chbevl(32.0f / ax - 2.0f, B_i1) / metal::sqrt(ax); + } + return x < 0.0f ? -result : result; +} diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 54a0f566c8..0bfca2f944 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -67,6 +67,8 @@ instantiate_unary_types(Negative) instantiate_unary_float(Sigmoid) instantiate_unary_float(Erf) instantiate_unary_float(ErfInv) +instantiate_unary_float(BesselI0e) +instantiate_unary_float(BesselI1e) instantiate_unary_types(Sign) instantiate_unary_float(Sin) instantiate_unary_float(Sinh) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 327bb5a940..0426909e80 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/metal/kernels/bessel.h" #include "mlx/backend/metal/kernels/cexpf.h" #include "mlx/backend/metal/kernels/erf.h" #include "mlx/backend/metal/kernels/expm1f.h" @@ -174,6 +175,20 @@ struct ErfInv { }; }; +struct BesselI0e { + template + T operator()(T x) { + return static_cast(i0e_impl(static_cast(x))); + }; +}; + +struct BesselI1e { + template + T operator()(T x) { + return static_cast(i1e_impl(static_cast(x))); + }; +}; + struct Exp { template T operator()(T x) { diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 833b23f632..82cc530d3a 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -125,6 +125,8 @@ UNARY_GPU(Cos) UNARY_GPU(Cosh) UNARY_GPU(Erf) UNARY_GPU(ErfInv) +UNARY_GPU(BesselI0e) +UNARY_GPU(BesselI1e) UNARY_GPU(Exp) UNARY_GPU(Expm1) UNARY_GPU(Imag) diff --git a/mlx/export.cpp b/mlx/export.cpp index 1160e8cfde..022eb121e7 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -361,6 +361,8 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Equal, "NaNEqual"), SERIALIZE_PRIMITIVE(Erf), SERIALIZE_PRIMITIVE(ErfInv), + SERIALIZE_PRIMITIVE(BesselI0e), + SERIALIZE_PRIMITIVE(BesselI1e), SERIALIZE_PRIMITIVE(Exp), SERIALIZE_PRIMITIVE(Expm1), SERIALIZE_PRIMITIVE(ExpandDims), diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0cafd479a2..b80f303eb5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2996,6 +2996,24 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) { {astype(a, dtype, s)}); } +array bessel_i0e(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_shared(to_stream(s)), + {astype(a, dtype, s)}); +} + +array bessel_i1e(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_shared(to_stream(s)), + {astype(a, dtype, s)}); +} + array stop_gradient(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); diff --git a/mlx/ops.h b/mlx/ops.h index 1f8331c1d9..a21e1b1552 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -948,6 +948,12 @@ MLX_API array erf(const array& a, StreamOrDevice s = {}); /** Computes the inverse error function of the elements of an array. */ MLX_API array erfinv(const array& a, StreamOrDevice s = {}); +/** Element-wise exponentially scaled modified Bessel function of order 0. */ +MLX_API array bessel_i0e(const array& a, StreamOrDevice s = {}); + +/** Element-wise exponentially scaled modified Bessel function of order 1. */ +MLX_API array bessel_i1e(const array& a, StreamOrDevice s = {}); + /** Computes the expm1 function of the elements of an array. */ MLX_API array expm1(const array& a, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 92e54f9991..42e2d29c92 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1930,6 +1930,101 @@ std::pair, std::vector> ErfInv::vmap( return {{erfinv(inputs[0], stream())}, axes}; } +// d/dx i0e(x) = i1e(x) - sign(x) * i0e(x) +std::vector BesselI0e::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + auto x = primals[0]; + auto grad = subtract( + bessel_i1e(x, stream()), + multiply(sign(x, stream()), bessel_i0e(x, stream()), stream()), + stream()); + return {multiply(cotangents[0], grad, stream())}; +} + +std::vector BesselI0e::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto x = primals[0]; + auto grad = subtract( + bessel_i1e(x, stream()), + multiply(sign(x, stream()), bessel_i0e(x, stream()), stream()), + stream()); + return {multiply(tangents[0], grad, stream())}; +} + +std::pair, std::vector> BesselI0e::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{bessel_i0e(inputs[0], stream())}, axes}; +} + +// d/dx i1e(x) = i0e(x) - (sign(x) + 1/x) * i1e(x) +// At x=0: i1e(0)=0 and 1/x*i1e(x) -> 0 via L'Hopital, so gradient = i0e(0) = 1 +std::vector BesselI1e::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + auto x = primals[0]; + auto i0e_val = bessel_i0e(x, stream()); + auto i1e_val = bessel_i1e(x, stream()); + auto s = sign(x, stream()); + auto ax = abs(x, stream()); + auto eps = array(1e-10f, x.dtype()); + auto safe = greater(ax, eps, stream()); + auto inv_x = where( + safe, + divide(array(1.0f, x.dtype()), x, stream()), + array(0.0f, x.dtype()), + stream()); + auto grad = subtract( + i0e_val, + multiply(add(s, inv_x, stream()), i1e_val, stream()), + stream()); + return {multiply(cotangents[0], grad, stream())}; +} + +std::vector BesselI1e::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + auto x = primals[0]; + auto i0e_val = bessel_i0e(x, stream()); + auto i1e_val = bessel_i1e(x, stream()); + auto s = sign(x, stream()); + auto ax = abs(x, stream()); + auto eps = array(1e-10f, x.dtype()); + auto safe = greater(ax, eps, stream()); + auto inv_x = where( + safe, + divide(array(1.0f, x.dtype()), x, stream()), + array(0.0f, x.dtype()), + stream()); + auto grad = subtract( + i0e_val, + multiply(add(s, inv_x, stream()), i1e_val, stream()), + stream()); + return {multiply(tangents[0], grad, stream())}; +} + +std::pair, std::vector> BesselI1e::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{bessel_i1e(inputs[0], stream())}, axes}; +} + std::vector Exp::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 4091aafcfb..dae87af355 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1018,6 +1018,34 @@ class ErfInv : public UnaryPrimitive { DEFINE_INPUT_OUTPUT_SHAPE() }; +class BesselI0e : public UnaryPrimitive { + public: + explicit BesselI0e(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(BesselI0e) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class BesselI1e : public UnaryPrimitive { + public: + explicit BesselI1e(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(BesselI1e) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + class MLX_API Exp : public UnaryPrimitive { public: explicit Exp(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1923170069..743b9b9d3a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -919,6 +919,50 @@ void init_ops(nb::module_& m) { Returns: array: The inverse error function of ``a``. )pbdoc"); + m.def( + "bessel_i0e", + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::bessel_i0e(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def bessel_i0e(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise exponentially scaled modified Bessel function of order 0. + + Computes :math:`e^{-|x|} I_0(x)` where :math:`I_0` is the modified + Bessel function of the first kind of order 0. + + Args: + a (array): Input array. + + Returns: + array: The exponentially scaled Bessel I0 of ``a``. + )pbdoc"); + m.def( + "bessel_i1e", + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::bessel_i1e(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def bessel_i1e(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise exponentially scaled modified Bessel function of order 1. + + Computes :math:`e^{-|x|} I_1(x)` where :math:`I_1` is the modified + Bessel function of the first kind of order 1. + + Args: + a (array): Input array. + + Returns: + array: The exponentially scaled Bessel I1 of ``a``. + )pbdoc"); m.def( "sin", [](const ScalarOrArray& a, mx::StreamOrDevice s) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index a498116166..4b2fdc082f 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1020,6 +1020,50 @@ def test_erfinv(self): expected = mx.array([3.8325066566467285] * 8) self.assertTrue(mx.allclose(result, expected)) + def test_bessel_i0e(self): + # i0e(0) = 1 + self.assertTrue( + np.allclose(mx.bessel_i0e(mx.array([0.0])), [1.0], atol=1e-6) + ) + # Known values: i0e(1) ≈ 0.4657596, i0e(5) ≈ 0.1835408 + x = mx.array([0.0, 1.0, 5.0, 10.0]) + expected = np.array([1.0, 0.4657596, 0.1835408, 0.1278333]).astype( + np.float32 + ) + self.assertTrue(np.allclose(mx.bessel_i0e(x), expected, atol=1e-5)) + # Type promotion + self.assertEqual(mx.bessel_i0e(mx.array(2)).dtype, mx.float32) + # Even symmetry + self.assertTrue( + np.allclose( + mx.bessel_i0e(mx.array([-3.0])), + mx.bessel_i0e(mx.array([3.0])), + atol=1e-6, + ) + ) + + def test_bessel_i1e(self): + # i1e(0) = 0 + self.assertTrue( + np.allclose(mx.bessel_i1e(mx.array([0.0])), [0.0], atol=1e-6) + ) + # Known values: i1e(1) ≈ 0.2079104 + x = mx.array([0.0, 1.0, 5.0, 10.0]) + expected = np.array([0.0, 0.2079104, 0.1639723, 0.1212627]).astype( + np.float32 + ) + self.assertTrue(np.allclose(mx.bessel_i1e(x), expected, atol=1e-5)) + # Type promotion + self.assertEqual(mx.bessel_i1e(mx.array(2)).dtype, mx.float32) + # Odd symmetry + self.assertTrue( + np.allclose( + mx.bessel_i1e(mx.array([-3.0])), + -mx.bessel_i1e(mx.array([3.0])), + atol=1e-6, + ) + ) + def test_sin(self): a = mx.array( [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi] diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..91f09f66e5 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -4157,3 +4157,77 @@ TEST_CASE("test max min with nan") { CHECK(array_equal(max_result, expected, true).item()); CHECK(array_equal(min_result, expected, true).item()); } + +TEST_CASE("test bessel_i0e and bessel_i1e") { + // i0e(0) = 1 + CHECK_EQ( + bessel_i0e(array(0.0f)).item(), + doctest::Approx(1.0f).epsilon(1e-5)); + + // i1e(0) = 0 + CHECK_EQ( + bessel_i1e(array(0.0f)).item(), + doctest::Approx(0.0f).epsilon(1e-5)); + + // Known value: i0e(1) = exp(-1) * I0(1) ≈ 0.4657596 + CHECK_EQ( + bessel_i0e(array(1.0f)).item(), + doctest::Approx(0.4657596f).epsilon(1e-5)); + + // Known value: i1e(1) = exp(-1) * I1(1) ≈ 0.2079104 + CHECK_EQ( + bessel_i1e(array(1.0f)).item(), + doctest::Approx(0.2079104f).epsilon(1e-5)); + + // Even symmetry: i0e(-x) = i0e(x) + CHECK_EQ( + bessel_i0e(array(-5.0f)).item(), + doctest::Approx(bessel_i0e(array(5.0f)).item()).epsilon(1e-6)); + + // Odd symmetry: i1e(-x) = -i1e(x) + CHECK_EQ( + bessel_i1e(array(-5.0f)).item(), + doctest::Approx(-bessel_i1e(array(5.0f)).item()).epsilon(1e-6)); + + // Large x (asymptotic regime): i0e(100) ≈ 1/sqrt(2*pi*100) ≈ 0.03989 + CHECK_EQ( + bessel_i0e(array(100.0f)).item(), + doctest::Approx(0.0398942f).epsilon(1e-4)); + + // Type promotion: int -> float + CHECK_EQ(bessel_i0e(array(2, int32)).dtype(), float32); + CHECK_EQ(bessel_i1e(array(2, int32)).dtype(), float32); + + // Vectorized + { + auto x = array({0.0f, 1.0f, 5.0f, 10.0f}); + CHECK_EQ(bessel_i0e(x).shape(), Shape{4}); + CHECK_EQ(bessel_i1e(x).shape(), Shape{4}); + } + + // float16 + { + array x(1.0f, float16); + CHECK_EQ(bessel_i0e(x).dtype(), float16); + CHECK_EQ(bessel_i1e(x).dtype(), float16); + } + + // Gradient of i0e: d/dx i0e(x) = i1e(x) - sign(x) * i0e(x) + { + auto f = [](array x) { return bessel_i0e(x); }; + auto dfdx = grad(f)(array(3.0f)); + auto expected = bessel_i1e(array(3.0f)).item() - + bessel_i0e(array(3.0f)).item(); + CHECK_EQ(dfdx.item(), doctest::Approx(expected).epsilon(1e-5)); + } + + // Gradient of i1e: d/dx i1e(x) = i0e(x) - (sign(x) + 1/x) * i1e(x) + { + auto f = [](array x) { return bessel_i1e(x); }; + auto dfdx = grad(f)(array(3.0f)); + float x_val = 3.0f; + float expected = bessel_i0e(array(x_val)).item() - + (1.0f + 1.0f / x_val) * bessel_i1e(array(x_val)).item(); + CHECK_EQ(dfdx.item(), doctest::Approx(expected).epsilon(1e-5)); + } +}