Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions mlx/backend/cpu/simd/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,68 @@ Simd<T, N> erfinv(Simd<T, N> a_) {
}
}

template <typename T, int N>
Simd<T, N> lgamma(Simd<T, N> x) {
  // Lane-by-lane std::lgamma. Each lane is widened to float before the
  // call (covers half/bfloat16 element types) and narrowed back to T.
  Simd<T, N> out;
  for (int lane = 0; lane < N; ++lane) {
    float xf = static_cast<float>(x[lane]);
    out[lane] = static_cast<T>(std::lgamma(xf));
  }
  return out;
}

template <typename T>
Simd<T, 1> lgamma(Simd<T, 1> x) {
  // Scalar (single-lane) specialization: same float-precision lgamma.
  float xf = static_cast<float>(x.value);
  float r = std::lgamma(xf);
  return Simd<T, 1>(static_cast<T>(r));
}

template <typename T, int N>
Simd<T, N> digamma(Simd<T, N> x_) {
  // Digamma (psi) per lane: reflection for negative arguments, the
  // recurrence psi(x) = psi(x+1) - 1/x to shift x up to >= 10, then an
  // asymptotic Bernoulli series — same scheme as the Metal kernel.
  Simd<T, N> out;
  for (int lane = 0; lane < N; ++lane) {
    float x = static_cast<float>(x_[lane]);
    float acc = 0.0f;
    if (x < 0.0f) {
      // Reflection: psi(x) = psi(1 - x) - pi / tan(pi * x).
      acc = -M_PI / std::tan(M_PI * x);
      x = 1.0f - x;
    }
    // Shift into the region where the asymptotic series is accurate.
    for (; x < 10.0f; x += 1.0f) {
      acc -= 1.0f / x;
    }
    // Horner evaluation of the series in z = 1/x^2.
    float z = 1.0f / (x * x);
    float poly = 3.96825396825e-3f;
    poly = poly * z + (-4.16666666667e-3f);
    poly = poly * z + 7.57575757576e-3f;
    poly = poly * z + (-2.10927960928e-2f);
    poly = poly * z + 8.33333333333e-2f;
    acc += std::log(x) - 0.5f / x - poly * z;
    out[lane] = static_cast<T>(acc);
  }
  return out;
}

template <typename T>
Simd<T, 1> digamma(Simd<T, 1> x_) {
  // Scalar digamma: reflection, recurrence to x >= 10, asymptotic series.
  float x = static_cast<float>(x_.value);
  float acc = 0.0f;
  if (x < 0.0f) {
    // psi(x) = psi(1 - x) - pi / tan(pi * x)
    acc = -M_PI / std::tan(M_PI * x);
    x = 1.0f - x;
  }
  // psi(x) = psi(x + 1) - 1/x
  for (; x < 10.0f; x += 1.0f) {
    acc -= 1.0f / x;
  }
  // Bernoulli series in z = 1/x^2, Horner form.
  float z = 1.0f / (x * x);
  float poly = 3.96825396825e-3f;
  poly = poly * z + (-4.16666666667e-3f);
  poly = poly * z + 7.57575757576e-3f;
  poly = poly * z + (-2.10927960928e-2f);
  poly = poly * z + 8.33333333333e-2f;
  acc += std::log(x) - 0.5f / x - poly * z;
  return Simd<T, 1>(static_cast<T>(acc));
}

} // namespace mlx::core::simd
12 changes: 12 additions & 0 deletions mlx/backend/cpu/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_real_fp(in, out, detail::ErfInv(), stream());
}

// CPU evaluation of element-wise log-gamma: dispatch the SIMD
// detail::LogGamma op over real floating-point inputs.
void LogGamma::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  unary_real_fp(inputs[0], out, detail::LogGamma(), stream());
}

// CPU evaluation of element-wise digamma: dispatch the SIMD
// detail::Digamma op over real floating-point inputs.
void Digamma::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  unary_real_fp(inputs[0], out, detail::Digamma(), stream());
}

void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/cpu/unary_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ DEFAULT_OP(Cos, cos)
DEFAULT_OP(Cosh, cosh)
DEFAULT_OP(Erf, erf)
DEFAULT_OP(ErfInv, erfinv)
DEFAULT_OP(LogGamma, lgamma)
DEFAULT_OP(Digamma, digamma)
DEFAULT_OP(Exp, exp)
DEFAULT_OP(Expm1, expm1)
DEFAULT_OP(Floor, floor);
Expand Down
37 changes: 37 additions & 0 deletions mlx/backend/cuda/device/unary_ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,43 @@ struct ErfInv {
}
};

struct LogGamma {
  // Element-wise log-gamma. Reduced-precision inputs are widened to
  // float before calling the device lgamma; the float result narrows
  // back to T on return.
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return ::lgamma(__bfloat162float(x));
    } else if constexpr (cuda::std::is_same_v<T, __half>) {
      return ::lgamma(__half2float(x));
    } else {
      return ::lgamma(x);
    }
  }
};

struct Digamma {
  // Element-wise digamma (psi): reflection for x < 0, recurrence
  // psi(x) = psi(x+1) - 1/x to shift x >= 10, then an asymptotic
  // Bernoulli series — same scheme as the Metal kernel and CPU path.
  template <typename T>
  __device__ T operator()(T x) {
    float v = static_cast<float>(x);
    float r = 0.0f;
    if (v < 0.0f) {
      // Reflection: psi(x) = psi(1 - x) - pi / tan(pi * x).
      // Use a float constant and ::tanf — the previous M_PI / ::tan
      // promoted this expression to double, forcing double-precision
      // tan and divide on the device (slow on consumer GPUs and
      // inconsistent with the ::logf used below and with the
      // float-only Metal kernel).
      constexpr float pi_f = 3.14159274101257324219f; // float(pi)
      r = -pi_f / ::tanf(pi_f * v);
      v = 1.0f - v;
    }
    // Shift into the region where the asymptotic series is accurate.
    while (v < 10.0f) {
      r -= 1.0f / v;
      v += 1.0f;
    }
    // Horner evaluation of the series in z = 1/v^2.
    float z = 1.0f / (v * v);
    float y = 3.96825396825e-3f;
    y = y * z + (-4.16666666667e-3f);
    y = y * z + 7.57575757576e-3f;
    y = y * z + (-2.10927960928e-2f);
    y = y * z + 8.33333333333e-2f;
    r += ::logf(v) - 0.5f / v - y * z;
    return static_cast<T>(r);
  }
};

struct Exp {
template <typename T>
__device__ T operator()(T x) {
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/cuda/unary/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/lgamma.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/digamma.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu
Expand Down
7 changes: 7 additions & 0 deletions mlx/backend/cuda/unary/digamma.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {
// Instantiate the CUDA unary kernels for the Digamma primitive.
UNARY_GPU(Digamma)
} // namespace mlx::core
7 changes: 7 additions & 0 deletions mlx/backend/cuda/unary/lgamma.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {
// Instantiate the CUDA unary kernels for the LogGamma primitive.
UNARY_GPU(LogGamma)
} // namespace mlx::core
83 changes: 83 additions & 0 deletions mlx/backend/metal/kernels/lgamma.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <metal_math>

/*
* Log-gamma function via Lanczos approximation (g=5, 7-term).
* Coefficients from Numerical Recipes; ~10 digits accuracy,
* sufficient for float32.
*
* Uses reflection formula for x < 0.5.
*/
/*
 * Log-gamma via Lanczos approximation (g=5, 7-term, Numerical Recipes
 * coefficients; ~10 digits, sufficient for float32). Reflection formula
 * for x < 0.5; +inf at the poles (non-positive integers).
 */
float lgamma_impl(float x) {
  // Poles at non-positive integers: lgamma diverges to +inf.
  if (x <= 0.0f && x == metal::floor(x)) {
    return metal::numeric_limits<float>::infinity();
  }

  // Reflection formula for x < 0.5:
  //   lgamma(x) = log(pi) - log(|sin(pi*x)|) - lgamma(1 - x)
  // Track the reflected case with an explicit flag rather than testing
  // reflection != 0.0f afterwards — the sentinel was value-dependent
  // (it only worked because log(pi) - log|sin| >= log(pi) > 0), which
  // is fragile under future edits.
  bool reflected = false;
  float reflection = 0.0f;
  if (x < 0.5f) {
    reflected = true;
    float sin_pi_x = metal::precise::sin(M_PI_F * x);
    reflection = 1.1447298858494002f // log(pi)
        - metal::precise::log(metal::abs(sin_pi_x));
    x = 1.0f - x;
  }

  // Lanczos g=5, 7-term (Numerical Recipes); inlined because Metal has
  // no call stack for a recursive lgamma(1 - x).
  float z = x - 1.0f;
  float t = z + 5.5f; // z + g + 0.5
  float s = 1.000000000190015f;
  s = metal::fma(76.18009172947146f, 1.0f / (z + 1.0f), s);
  s = metal::fma(-86.50532032941677f, 1.0f / (z + 2.0f), s);
  s = metal::fma(24.01409824083091f, 1.0f / (z + 3.0f), s);
  s = metal::fma(-1.231739572450155f, 1.0f / (z + 4.0f), s);
  s = metal::fma(1.208650973866179e-3f, 1.0f / (z + 5.0f), s);
  s = metal::fma(-5.395239384953e-6f, 1.0f / (z + 6.0f), s);

  float lanczos = 0.9189385332046727f // 0.5 * log(2*pi)
      + (z + 0.5f) * metal::precise::log(t) - t +
      metal::precise::log(s);

  // For reflected inputs x was replaced by 1 - x above, so lanczos is
  // lgamma(1 - x_orig) and the reflection formula applies.
  return reflected ? (reflection - lanczos) : lanczos;
}

/*
* Digamma (psi) function: d/dx[lgamma(x)].
* Uses asymptotic expansion for x >= 10, recurrence to shift
* small x, and reflection formula for negative x.
*
* Asymptotic coefficients are Bernoulli-number derived (float32-adequate).
*/
/*
 * Digamma (psi): d/dx[lgamma(x)]. Reflection for negative x,
 * recurrence to shift x up to >= 10, then an asymptotic expansion
 * with Bernoulli-derived coefficients (float32-adequate).
 */
float digamma_impl(float x) {
  float acc = 0.0f;

  // Reflection: psi(x) = psi(1 - x) - pi / tan(pi * x) for x < 0.
  if (x < 0.0f) {
    acc = -M_PI_F / metal::precise::tan(M_PI_F * x);
    x = 1.0f - x;
  }

  // Recurrence psi(x) = psi(x + 1) - 1/x until the series is accurate.
  for (; x < 10.0f; x += 1.0f) {
    acc -= 1.0f / x;
  }

  // Asymptotic series in z = 1/x^2, evaluated with fused multiply-adds.
  float z = 1.0f / (x * x);
  float y = metal::fma(3.96825396825e-3f, z, -4.16666666667e-3f);
  y = metal::fma(y, z, 7.57575757576e-3f);
  y = metal::fma(y, z, -2.10927960928e-2f);
  y = metal::fma(y, z, 8.33333333333e-2f);

  acc += metal::precise::log(x) - 0.5f / x - y * z;
  return acc;
}
2 changes: 2 additions & 0 deletions mlx/backend/metal/kernels/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ instantiate_unary_types(Negative)
instantiate_unary_float(Sigmoid)
instantiate_unary_float(Erf)
instantiate_unary_float(ErfInv)
instantiate_unary_float(LogGamma)
instantiate_unary_float(Digamma)
instantiate_unary_types(Sign)
instantiate_unary_float(Sin)
instantiate_unary_float(Sinh)
Expand Down
15 changes: 15 additions & 0 deletions mlx/backend/metal/kernels/unary_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "mlx/backend/metal/kernels/cexpf.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/lgamma.h"
#include "mlx/backend/metal/kernels/expm1f.h"
#include "mlx/backend/metal/kernels/fp8.h"

Expand Down Expand Up @@ -174,6 +175,20 @@ struct ErfInv {
};
};

struct LogGamma {
  // Evaluate in float, then narrow back to T (half/bfloat included).
  template <typename T>
  T operator()(T x) {
    float r = lgamma_impl(static_cast<float>(x));
    return static_cast<T>(r);
  }
};

struct Digamma {
  // Evaluate in float, then narrow back to T (half/bfloat included).
  template <typename T>
  T operator()(T x) {
    float r = digamma_impl(static_cast<float>(x));
    return static_cast<T>(r);
  }
};

struct Exp {
template <typename T>
T operator()(T x) {
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/metal/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(LogGamma)
UNARY_GPU(Digamma)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Imag)
Expand Down
2 changes: 2 additions & 0 deletions mlx/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Equal, "NaNEqual"),
SERIALIZE_PRIMITIVE(Erf),
SERIALIZE_PRIMITIVE(ErfInv),
SERIALIZE_PRIMITIVE(LogGamma),
SERIALIZE_PRIMITIVE(Digamma),
SERIALIZE_PRIMITIVE(Exp),
SERIALIZE_PRIMITIVE(Expm1),
SERIALIZE_PRIMITIVE(ExpandDims),
Expand Down
18 changes: 18 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,24 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
{astype(a, dtype, s)});
}

// Element-wise log-gamma. Integral inputs are promoted to a floating
// dtype before the primitive runs; output shape matches the input.
array lgamma(const array& a, StreamOrDevice s /* = {} */) {
  auto out_type = at_least_float(a.dtype());
  return array(
      a.shape(),
      out_type,
      std::make_shared<LogGamma>(to_stream(s)),
      {astype(a, out_type, s)});
}

// Element-wise digamma (psi). Integral inputs are promoted to a
// floating dtype before the primitive runs; output shape matches input.
array digamma(const array& a, StreamOrDevice s /* = {} */) {
  auto out_type = at_least_float(a.dtype());
  return array(
      a.shape(),
      out_type,
      std::make_shared<Digamma>(to_stream(s)),
      {astype(a, out_type, s)});
}

array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
Expand Down
6 changes: 6 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,12 @@ MLX_API array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
MLX_API array erfinv(const array& a, StreamOrDevice s = {});

/** Element-wise log-gamma (log of the absolute value of the gamma function). */
MLX_API array lgamma(const array& a, StreamOrDevice s = {});

/** Element-wise digamma (psi) function, the derivative of lgamma. */
MLX_API array digamma(const array& a, StreamOrDevice s = {});

/** Computes the expm1 function of the elements of an array. */
MLX_API array expm1(const array& a, StreamOrDevice s = {});

Expand Down
Loading
Loading