diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h
index f9fc8317a5..7d87ebc3f9 100644
--- a/mlx/backend/cpu/simd/math.h
+++ b/mlx/backend/cpu/simd/math.h
@@ -190,4 +190,68 @@ Simd<T, N> erfinv(Simd<T, N> a_) {
   }
 }
 
+template <typename T, int N>
+Simd<T, N> lgamma(Simd<T, N> x) {
+  // Delegate to std::lgamma element-wise (evaluated in float32).
+  Simd<T, N> result;
+  for (int i = 0; i < N; i++) {
+    result[i] = static_cast<T>(std::lgamma(static_cast<float>(x[i])));
+  }
+  return result;
+}
+
+template <typename T>
+Simd<T, 1> lgamma(Simd<T, 1> x) {
+  return Simd<T, 1>(static_cast<T>(std::lgamma(static_cast<float>(x.value))));
+}
+
+template <typename T, int N>
+Simd<T, N> digamma(Simd<T, N> x_) {
+  // Asymptotic expansion + recurrence, matching the Metal kernel
+  Simd<T, N> result;
+  for (int i = 0; i < N; i++) {
+    float v = static_cast<float>(x_[i]);
+    float r = 0.0f;
+    if (v < 0.0f) {
+      r = -M_PI / std::tan(M_PI * v);
+      v = 1.0f - v;
+    }
+    while (v < 10.0f) {
+      r -= 1.0f / v;
+      v += 1.0f;
+    }
+    float z = 1.0f / (v * v);
+    float y = 3.96825396825e-3f;
+    y = y * z + (-4.16666666667e-3f);
+    y = y * z + 7.57575757576e-3f;
+    y = y * z + (-2.10927960928e-2f);
+    y = y * z + 8.33333333333e-2f;
+    r += std::log(v) - 0.5f / v - y * z;
+    result[i] = static_cast<T>(r);
+  }
+  return result;
+}
+
+template <typename T>
+Simd<T, 1> digamma(Simd<T, 1> x_) {
+  float v = static_cast<float>(x_.value);
+  float r = 0.0f;
+  if (v < 0.0f) {
+    r = -M_PI / std::tan(M_PI * v);
+    v = 1.0f - v;
+  }
+  while (v < 10.0f) {
+    r -= 1.0f / v;
+    v += 1.0f;
+  }
+  float z = 1.0f / (v * v);
+  float y = 3.96825396825e-3f;
+  y = y * z + (-4.16666666667e-3f);
+  y = y * z + 7.57575757576e-3f;
+  y = y * z + (-2.10927960928e-2f);
+  y = y * z + 8.33333333333e-2f;
+  r += std::log(v) - 0.5f / v - y * z;
+  return Simd<T, 1>(static_cast<T>(r));
+}
+
 } // namespace mlx::core::simd
diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp
index eafe98866f..ed50bb4f75 100644
--- a/mlx/backend/cpu/unary.cpp
+++ b/mlx/backend/cpu/unary.cpp
@@ -103,6 +103,18 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
   unary_real_fp(in, out, detail::ErfInv(), stream());
 }
 
+void LogGamma::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_real_fp(in, out, detail::LogGamma(), stream());
+}
+
+void Digamma::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_real_fp(in, out, detail::Digamma(), stream());
+}
+
 void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h
index f441e88bd5..21e14992ab 100644
--- a/mlx/backend/cpu/unary_ops.h
+++ b/mlx/backend/cpu/unary_ops.h
@@ -41,6 +41,8 @@ DEFAULT_OP(Cos, cos)
 DEFAULT_OP(Cosh, cosh)
 DEFAULT_OP(Erf, erf)
 DEFAULT_OP(ErfInv, erfinv)
+DEFAULT_OP(LogGamma, lgamma)
+DEFAULT_OP(Digamma, digamma)
 DEFAULT_OP(Exp, exp)
 DEFAULT_OP(Expm1, expm1)
 DEFAULT_OP(Floor, floor);
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index 8234cf8ecc..36bd6c0bd2 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -131,6 +131,43 @@ struct ErfInv {
   }
 };
 
+struct LogGamma {
+  template <typename T>
+  __device__ T operator()(T x) {
+    if constexpr (cuda::std::is_same_v<T, __half>) {
+      return ::lgamma(__half2float(x));
+    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
+      return ::lgamma(__bfloat162float(x));
+    } else {
+      return ::lgamma(x);
+    }
+  }
+};
+
+struct Digamma {
+  template <typename T>
+  __device__ T operator()(T x) {
+    float v = static_cast<float>(x);
+    float r = 0.0f;
+    if (v < 0.0f) {
+      r = -M_PI / ::tan(M_PI * v);
+      v = 1.0f - v;
+    }
+    while (v < 10.0f) {
+      r -= 1.0f / v;
+      v += 1.0f;
+    }
+    float z = 1.0f / (v * v);
+    float y = 3.96825396825e-3f;
+    y = y * z + (-4.16666666667e-3f);
+    y = y * z + 7.57575757576e-3f;
+    y = y * z + (-2.10927960928e-2f);
+    y = y * z + 8.33333333333e-2f;
+    r += ::logf(v) - 0.5f / v - y * z;
+    return static_cast<T>(r);
+  }
+};
+
 struct Exp {
   template <typename T>
   __device__ T operator()(T x) {
diff --git a/mlx/backend/cuda/unary/CMakeLists.txt b/mlx/backend/cuda/unary/CMakeLists.txt
index 532c5645ed..b18c25d4ad 100644
--- a/mlx/backend/cuda/unary/CMakeLists.txt
+++ b/mlx/backend/cuda/unary/CMakeLists.txt
@@ -14,6 +14,8 @@ target_sources(
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/lgamma.cu
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/digamma.cu
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu
diff --git a/mlx/backend/cuda/unary/digamma.cu b/mlx/backend/cuda/unary/digamma.cu
new file mode 100644
index 0000000000..0e19486b8e
--- /dev/null
+++ b/mlx/backend/cuda/unary/digamma.cu
@@ -0,0 +1,7 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/unary/unary.cuh"
+
+namespace mlx::core {
+UNARY_GPU(Digamma)
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/unary/lgamma.cu b/mlx/backend/cuda/unary/lgamma.cu
new file mode 100644
index 0000000000..31abdada72
--- /dev/null
+++ b/mlx/backend/cuda/unary/lgamma.cu
@@ -0,0 +1,7 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/unary/unary.cuh"
+
+namespace mlx::core {
+UNARY_GPU(LogGamma)
+} // namespace mlx::core
diff --git a/mlx/backend/metal/kernels/lgamma.h b/mlx/backend/metal/kernels/lgamma.h
new file mode 100644
index 0000000000..d1d8c82645
--- /dev/null
+++ b/mlx/backend/metal/kernels/lgamma.h
@@ -0,0 +1,83 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <metal_math>
+
+/*
+ * Log-gamma function via Lanczos approximation (g=5, 7-term).
+ * Coefficients from Numerical Recipes; ~10 digits accuracy,
+ * sufficient for float32.
+ *
+ * Uses reflection formula for x < 0.5.
+ */
+float lgamma_impl(float x) {
+  // Poles at non-positive integers
+  if (x <= 0.0f && x == metal::floor(x)) {
+    return metal::numeric_limits<float>::infinity();
+  }
+
+  // Reflection formula for x < 0.5:
+  //   lgamma(x) = log(pi) - log(|sin(pi*x)|) - lgamma(1-x)
+  // Lanczos for 1-x is inlined (no function calls — Metal has no call stack).
+  float reflection = 0.0f;
+  if (x < 0.5f) {
+    float sin_pi_x = metal::precise::sin(M_PI_F * x);
+    reflection = 1.1447298858494002f // log(pi)
+        - metal::precise::log(metal::abs(sin_pi_x));
+    x = 1.0f - x;
+  }
+
+  // Lanczos g=5, 7-term (Numerical Recipes)
+  float z = x - 1.0f;
+  float t = z + 5.5f; // z + g + 0.5
+  float s = 1.000000000190015f;
+  s = metal::fma(76.18009172947146f, 1.0f / (z + 1.0f), s);
+  s = metal::fma(-86.50532032941677f, 1.0f / (z + 2.0f), s);
+  s = metal::fma(24.01409824083091f, 1.0f / (z + 3.0f), s);
+  s = metal::fma(-1.231739572450155f, 1.0f / (z + 4.0f), s);
+  s = metal::fma(1.208650973866179e-3f, 1.0f / (z + 5.0f), s);
+  s = metal::fma(-5.395239384953e-6f, 1.0f / (z + 6.0f), s);
+
+  float lanczos = 0.9189385332046727f // 0.5 * log(2*pi)
+      + (z + 0.5f) * metal::precise::log(t) - t
+      + metal::precise::log(s);
+
+  // For x < 0.5: reflection - lanczos(1-x). For x >= 0.5: just lanczos(x).
+  return (reflection != 0.0f) ? (reflection - lanczos) : lanczos;
+}
+
+/*
+ * Digamma (psi) function: d/dx[lgamma(x)].
+ * Uses asymptotic expansion for x >= 10, recurrence to shift
+ * small x, and reflection formula for negative x.
+ *
+ * Asymptotic coefficients are Bernoulli-number derived (float32-adequate).
+ */
+float digamma_impl(float x) {
+  float result = 0.0f;
+
+  // Reflection for negative x:
+  //   digamma(x) = digamma(1-x) - pi/tan(pi*x)
+  if (x < 0.0f) {
+    result = -M_PI_F / metal::precise::tan(M_PI_F * x);
+    x = 1.0f - x;
+  }
+
+  // Recurrence: shift x to >= 10
+  //   digamma(x) = digamma(x+1) - 1/x
+  while (x < 10.0f) {
+    result -= 1.0f / x;
+    x += 1.0f;
+  }
+
+  // Asymptotic expansion (Bernoulli numbers, Horner form)
+  float z = 1.0f / (x * x);
+  float y = metal::fma(3.96825396825e-3f, z, -4.16666666667e-3f);
+  y = metal::fma(y, z, 7.57575757576e-3f);
+  y = metal::fma(y, z, -2.10927960928e-2f);
+  y = metal::fma(y, z, 8.33333333333e-2f);
+
+  result += metal::precise::log(x) - 0.5f / x - y * z;
+  return result;
+}
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 54a0f566c8..ff8980e05f 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -67,6 +67,8 @@ instantiate_unary_types(Negative)
 instantiate_unary_float(Sigmoid)
 instantiate_unary_float(Erf)
 instantiate_unary_float(ErfInv)
+instantiate_unary_float(LogGamma)
+instantiate_unary_float(Digamma)
 instantiate_unary_types(Sign)
 instantiate_unary_float(Sin)
 instantiate_unary_float(Sinh)
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index 327bb5a940..48a0b9893e 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/kernels/cexpf.h"
 #include "mlx/backend/metal/kernels/erf.h"
 #include "mlx/backend/metal/kernels/expm1f.h"
 #include "mlx/backend/metal/kernels/fp8.h"
+#include "mlx/backend/metal/kernels/lgamma.h"
 
@@ -174,6 +175,20 @@ struct ErfInv {
   };
 };
 
+struct LogGamma {
+  template <typename T>
+  T operator()(T x) {
+    return static_cast<T>(lgamma_impl(static_cast<float>(x)));
+  };
+};
+
+struct Digamma {
+  template <typename T>
+  T operator()(T x) {
+    return static_cast<T>(digamma_impl(static_cast<float>(x)));
+  };
+};
+
 struct Exp {
   template <typename T>
   T operator()(T x) {
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 833b23f632..c46a3c7c42 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -125,6 +125,8 @@ UNARY_GPU(Cos)
 UNARY_GPU(Cosh)
 UNARY_GPU(Erf)
 UNARY_GPU(ErfInv)
+UNARY_GPU(LogGamma)
+UNARY_GPU(Digamma)
 UNARY_GPU(Exp)
 UNARY_GPU(Expm1)
 UNARY_GPU(Imag)
diff --git a/mlx/export.cpp b/mlx/export.cpp
index 1160e8cfde..3259663158 100644
--- a/mlx/export.cpp
+++ b/mlx/export.cpp
@@ -361,6 +361,8 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(Equal, "NaNEqual"),
       SERIALIZE_PRIMITIVE(Erf),
       SERIALIZE_PRIMITIVE(ErfInv),
+      SERIALIZE_PRIMITIVE(LogGamma),
+      SERIALIZE_PRIMITIVE(Digamma),
       SERIALIZE_PRIMITIVE(Exp),
       SERIALIZE_PRIMITIVE(Expm1),
       SERIALIZE_PRIMITIVE(ExpandDims),
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 0cafd479a2..e673a16dd7 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2996,6 +2996,24 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
       {astype(a, dtype, s)});
 }
 
+array lgamma(const array& a, StreamOrDevice s /* = {} */) {
+  auto dtype = at_least_float(a.dtype());
+  return array(
+      a.shape(),
+      dtype,
+      std::make_shared<LogGamma>(to_stream(s)),
+      {astype(a, dtype, s)});
+}
+
+array digamma(const array& a, StreamOrDevice s /* = {} */) {
+  auto dtype = at_least_float(a.dtype());
+  return array(
+      a.shape(),
+      dtype,
+      std::make_shared<Digamma>(to_stream(s)),
+      {astype(a, dtype, s)});
+}
+
 array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
diff --git a/mlx/ops.h b/mlx/ops.h
index 1f8331c1d9..1de31a4b45 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -948,6 +948,12 @@ MLX_API array erf(const array& a, StreamOrDevice s = {});
 /** Computes the inverse error function of the elements of an array. */
 MLX_API array erfinv(const array& a, StreamOrDevice s = {});
 
+/** Element-wise log-gamma (log of the absolute value of the gamma function). */
+MLX_API array lgamma(const array& a, StreamOrDevice s = {});
+
+/** Element-wise digamma (psi) function, the derivative of lgamma. */
+MLX_API array digamma(const array& a, StreamOrDevice s = {});
+
 /** Computes the expm1 function of the elements of an array. */
 MLX_API array expm1(const array& a, StreamOrDevice s = {});
 
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 92e54f9991..bd5f17e966 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1930,6 +1930,59 @@ std::pair<std::vector<array>, std::vector<int>> ErfInv::vmap(
   return {{erfinv(inputs[0], stream())}, axes};
 }
 
+std::vector<array> LogGamma::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  // d/dx lgamma(x) = digamma(x)
+  return {multiply(cotangents[0], digamma(primals[0], stream()), stream())};
+}
+
+std::vector<array> LogGamma::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {multiply(tangents[0], digamma(primals[0], stream()), stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> LogGamma::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{lgamma(inputs[0], stream())}, axes};
+}
+
+std::vector<array> Digamma::vjp(
+    const std::vector<array>&,
+    const std::vector<array>&,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  throw std::runtime_error(
+      "[Digamma] The gradient of digamma (trigamma) is not implemented, "
+      "so second-order gradients of lgamma are not supported.");
+}
+
+std::vector<array> Digamma::jvp(
+    const std::vector<array>&,
+    const std::vector<array>&,
+    const std::vector<int>&) {
+  throw std::runtime_error(
+      "[Digamma] The gradient of digamma (trigamma) is not implemented, "
+      "so second-order gradients of lgamma are not supported.");
+}
+
+std::pair<std::vector<array>, std::vector<int>> Digamma::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{digamma(inputs[0], stream())}, axes};
+}
+
 std::vector<array> Exp::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 4091aafcfb..7acf61789c 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1018,6 +1018,34 @@ class ErfInv : public UnaryPrimitive {
   DEFINE_INPUT_OUTPUT_SHAPE()
 };
 
+class LogGamma : public UnaryPrimitive {
+ public:
+  explicit LogGamma(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_NAME(LogGamma)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  DEFINE_INPUT_OUTPUT_SHAPE()
+};
+
+class Digamma : public UnaryPrimitive {
+ public:
+  explicit Digamma(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_NAME(Digamma)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  DEFINE_INPUT_OUTPUT_SHAPE()
+};
+
 class MLX_API Exp : public UnaryPrimitive {
  public:
   explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 1923170069..545c0775d5 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -919,6 +919,56 @@ void init_ops(nb::module_& m) {
       Returns:
           array: The inverse error function of ``a``.
       )pbdoc");
+  m.def(
+      "lgamma",
+      [](const ScalarOrArray& a, mx::StreamOrDevice s) {
+        return mx::lgamma(to_array(a), s);
+      },
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def lgamma(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise log-gamma function.
+
+        Computes the natural logarithm of the absolute value of the
+        gamma function.
+
+        .. math::
+          \mathrm{lgamma}(x) = \log |\Gamma(x)|
+
+        Args:
+            a (array): Input array.
+
+        Returns:
+            array: The log-gamma of ``a``.
+      )pbdoc");
+  m.def(
+      "digamma",
+      [](const ScalarOrArray& a, mx::StreamOrDevice s) {
+        return mx::digamma(to_array(a), s);
+      },
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def digamma(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise digamma function.
+
+        Computes the logarithmic derivative of the gamma function
+        (the derivative of :func:`lgamma`).
+
+        .. math::
+          \psi(x) = \frac{d}{dx} \log \Gamma(x)
+
+        Args:
+            a (array): Input array.
+
+        Returns:
+            array: The digamma of ``a``.
+      )pbdoc");
   m.def(
       "sin",
       [](const ScalarOrArray& a, mx::StreamOrDevice s) {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index a498116166..e554fac580 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1020,6 +1020,41 @@ def test_erfinv(self):
         expected = mx.array([3.8325066566467285] * 8)
         self.assertTrue(mx.allclose(result, expected))
 
+    def test_lgamma(self):
+        # Known values: lgamma(1) = 0, lgamma(2) = 0
+        x = mx.array([0.5, 1.0, 2.0, 5.0, 10.0])
+        expected = np.array(
+            [math.lgamma(v) for v in [0.5, 1.0, 2.0, 5.0, 10.0]]
+        ).astype(np.float32)
+        self.assertTrue(np.allclose(mx.lgamma(x), expected, atol=1e-5))
+
+        # Integer input promotion
+        self.assertEqual(mx.lgamma(mx.array(2)).dtype, mx.float32)
+
+        # Negative non-integer exercises the reflection formula.
+        # math.lgamma is defined for negative non-integers (lgamma(-0.5) ≈ 1.26551).
+        self.assertTrue(
+            np.allclose(mx.lgamma(mx.array(-0.5)), math.lgamma(-0.5), atol=1e-3)
+        )
+
+    def test_digamma(self):
+        # digamma(1) = -euler_gamma ≈ -0.5772156649
+        result = mx.digamma(mx.array([1.0]))
+        self.assertTrue(np.allclose(result, [-0.5772157], atol=1e-5))
+
+        # digamma(2) = 1 - euler_gamma ≈ 0.4227843351
+        result = mx.digamma(mx.array([2.0]))
+        self.assertTrue(np.allclose(result, [0.4227843], atol=1e-5))
+
+        # Vectorized
+        x = mx.array([1.0, 2.0, 5.0, 10.0])
+        result = mx.digamma(x)
+        self.assertEqual(list(result.shape), [4])
+
+        # Integer input promotion
+        self.assertEqual(mx.digamma(mx.array(2)).dtype, mx.float32)
+
     def test_sin(self):
         a = mx.array(
             [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi]
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 62fd8c5923..5baa482cd4 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1731,6 +1731,83 @@ TEST_CASE("test error functions") {
   }
 }
 
+TEST_CASE("test lgamma and digamma") {
+  // lgamma known values
+  CHECK_EQ(lgamma(array(1.0f)).item<float>(), doctest::Approx(0.0f).epsilon(1e-5));
+  CHECK_EQ(lgamma(array(2.0f)).item<float>(), doctest::Approx(0.0f).epsilon(1e-5));
+  CHECK_EQ(
+      lgamma(array(0.5f)).item<float>(),
+      doctest::Approx(0.5723649f).epsilon(1e-5));
+  CHECK_EQ(
+      lgamma(array(5.0f)).item<float>(),
+      doctest::Approx(3.1780538f).epsilon(1e-4));
+  CHECK_EQ(
+      lgamma(array(10.0f)).item<float>(),
+      doctest::Approx(12.801827f).epsilon(1e-3));
+
+  // lgamma at poles
+  CHECK(std::isinf(lgamma(array(0.0f)).item<float>()));
+  CHECK(std::isinf(lgamma(array(-1.0f)).item<float>()));
+
+  // lgamma negative non-integer (reflection)
+  CHECK_EQ(
+      lgamma(array(-0.5f)).item<float>(),
+      doctest::Approx(1.2655121f).epsilon(1e-4));
+
+  // lgamma integer promotion
+  CHECK_EQ(lgamma(array(2, int32)).dtype(), float32);
+
+  // lgamma vectorized
+  {
+    auto x = array({1.0f, 2.0f, 5.0f, 10.0f});
+    auto result = lgamma(x);
+    CHECK_EQ(result.shape(), Shape{4});
+    CHECK_EQ(result.dtype(), float32);
+  }
+
+  // lgamma float16
+  {
+    array x(2.0f, float16);
+    auto out = lgamma(x);
+    CHECK_EQ(out.dtype(), float16);
+    CHECK_EQ(out.item<float16_t>(), doctest::Approx(0.0f).epsilon(1e-2));
+  }
+
+  // digamma known values
+  // digamma(1) = -euler_gamma
+  CHECK_EQ(
+      digamma(array(1.0f)).item<float>(),
+      doctest::Approx(-0.5772157f).epsilon(1e-5));
+  // digamma(2) = 1 - euler_gamma
+  CHECK_EQ(
+      digamma(array(2.0f)).item<float>(),
+      doctest::Approx(0.4227843f).epsilon(1e-5));
+  // digamma(0.5) = -euler_gamma - 2*ln(2)
+  CHECK_EQ(
+      digamma(array(0.5f)).item<float>(),
+      doctest::Approx(-1.9635100f).epsilon(1e-4));
+
+  // digamma integer promotion
+  CHECK_EQ(digamma(array(2, int32)).dtype(), float32);
+
+  // digamma vectorized
+  {
+    auto x = array({1.0f, 2.0f, 5.0f, 10.0f});
+    auto result = digamma(x);
+    CHECK_EQ(result.shape(), Shape{4});
+    CHECK_EQ(result.dtype(), float32);
+  }
+
+  // gradient: d/dx lgamma(x) = digamma(x)
+  {
+    auto f = [](array x) { return lgamma(x); };
+    auto dfdx = grad(f)(array(3.0f));
+    auto expected = digamma(array(3.0f));
+    CHECK_EQ(
+        dfdx.item<float>(), doctest::Approx(expected.item<float>()).epsilon(1e-5));
+  }
+}
+
 TEST_CASE("test arithmetic binary ops") {
   array x(1.0);
   array y(1.0);