From 9e468b23ce26003cc7d0d2bec10b48629a61074f Mon Sep 17 00:00:00 2001 From: David Koski Date: Fri, 3 Apr 2026 15:26:02 -0700 Subject: [PATCH 1/2] add missing functions - eig - eigenvalues and eigenvectors - eigh - eigenvalues/eigenvectors of symmetric/Hermitian matrix - eigvals - eigenvalues only - eigvalsh - eigenvalues of symmetric/Hermitian matrix - pinv - pseudo-inverse --- Source/MLX/Linalg.swift | 206 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) diff --git a/Source/MLX/Linalg.swift b/Source/MLX/Linalg.swift index c2e2813c..1d48f7c1 100644 --- a/Source/MLX/Linalg.swift +++ b/Source/MLX/Linalg.swift @@ -415,6 +415,116 @@ public enum MLXLinalg { return MLXArray(result) } + /// Compute the eigenvalues and eigenvectors of a square matrix. + /// + /// This function differs from `numpy.linalg.eig` in that the + /// return type is always complex even if the eigenvalues are all real. + /// + /// This function supports arrays with at least 2 dimensions. When the input + /// has more than two dimensions, the eigenvalues and eigenvectors are + /// computed for each matrix in the last two dimensions. + /// + /// - Parameters: + /// - array: input array + /// - stream: stream or device to evaluate on + /// - Returns: a tuple containing the eigenvalues and the normalized right + /// eigenvectors. The column `v[0..., i]` is the eigenvector corresponding + /// to the i-th eigenvalue. + public static func eig(_ array: MLXArray, stream: StreamOrDevice = .default) -> ( + MLXArray, MLXArray + ) { + var r0 = mlx_array_new() + var r1 = mlx_array_new() + mlx_linalg_eig(&r0, &r1, array.ctx, stream.ctx) + return (MLXArray(r0), MLXArray(r1)) + } + + /// Compute the eigenvalues of a square matrix. + /// + /// This function differs from `numpy.linalg.eigvals` in that the + /// return type is always complex even if the eigenvalues are all real. + /// + /// This function supports arrays with at least 2 dimensions. When the + /// input has more than two dimensions, the eigenvalues are computed for + /// each matrix in the last two dimensions. + /// + /// - Parameters: + /// - array: input array + /// - stream: stream or device to evaluate on + /// - Returns: the eigenvalues (not necessarily in order) + public static func eigvals(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() + mlx_linalg_eigvals(&result, array.ctx, stream.ctx) + return MLXArray(result) + } + + /// Compute the eigenvalues and eigenvectors of a complex Hermitian or + /// real symmetric matrix. + /// + /// This function supports arrays with at least 2 dimensions. When the input + /// has more than two dimensions, the eigenvalues and eigenvectors are + /// computed for each matrix in the last two dimensions. + /// + /// > The input matrix is assumed to be symmetric (or Hermitian). Only + /// the selected triangle is used. No checks for symmetry are performed. + /// + /// - Parameters: + /// - array: input array. Must be a real symmetric or complex Hermitian matrix. + /// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. + /// Default is `"L"`. + /// - stream: stream or device to evaluate on + /// - Returns: a tuple containing the eigenvalues in ascending order and + /// the normalized eigenvectors. The column `v[0..., i]` is the eigenvector + /// corresponding to the i-th eigenvalue. + public static func eigh( + _ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default + ) -> (MLXArray, MLXArray) { + var r0 = mlx_array_new() + var r1 = mlx_array_new() + mlx_linalg_eigh(&r0, &r1, array.ctx, UPLO, stream.ctx) + return (MLXArray(r0), MLXArray(r1)) + } + + /// Compute the eigenvalues of a complex Hermitian or real symmetric matrix. + /// + /// This function supports arrays with at least 2 dimensions. When the + /// input has more than two dimensions, the eigenvalues are computed for + /// each matrix in the last two dimensions. + /// + /// > The input matrix is assumed to be symmetric (or Hermitian). Only + /// the selected triangle is used. No checks for symmetry are performed. + /// + /// - Parameters: + /// - array: input array. Must be a real symmetric or complex Hermitian matrix. + /// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. + /// Default is `"L"`. + /// - stream: stream or device to evaluate on + /// - Returns: the eigenvalues in ascending order + public static func eigvalsh( + _ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + mlx_linalg_eigvalsh(&result, array.ctx, UPLO, stream.ctx) + return MLXArray(result) + } + + /// Compute the (Moore-Penrose) pseudo-inverse of a matrix. + /// + /// This function calculates a generalized inverse of a matrix using its + /// singular-value decomposition. This function supports arrays with at least + /// 2 dimensions. When the input has more than two dimensions, the inverse is + /// computed for each matrix in the last two dimensions of `array`. + /// + /// - Parameters: + /// - array: input array + /// - stream: stream or device to evaluate on + /// - Returns: `aplus` such that `a @ aplus @ a = a` + public static func pinv(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() + mlx_linalg_pinv(&result, array.ctx, stream.ctx) + return MLXArray(result) + } + } // MLXLinalg /// Matrix or vector norm. @@ -716,3 +826,99 @@ public func solveTriangular( { return MLXLinalg.solveTriangular(a, b, upper: upper, stream: stream) } + +/// Compute the eigenvalues and eigenvectors of a square matrix. +/// +/// This function differs from `numpy.linalg.eig` in that the +/// return type is always complex even if the eigenvalues are all real. +/// +/// This function supports arrays with at least 2 dimensions. When the input +/// has more than two dimensions, the eigenvalues and eigenvectors are +/// computed for each matrix in the last two dimensions. +/// +/// - Parameters: +/// - array: input array +/// - stream: stream or device to evaluate on +/// - Returns: a tuple containing the eigenvalues and the normalized right +/// eigenvectors. The column `v[0..., i]` is the eigenvector corresponding +/// to the i-th eigenvalue. +public func eig(_ array: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) { + return MLXLinalg.eig(array, stream: stream) +} + +/// Compute the eigenvalues of a square matrix. +/// +/// This function differs from `numpy.linalg.eigvals` in that the +/// return type is always complex even if the eigenvalues are all real. +/// +/// This function supports arrays with at least 2 dimensions. When the +/// input has more than two dimensions, the eigenvalues are computed for +/// each matrix in the last two dimensions. +/// +/// - Parameters: +/// - array: input array +/// - stream: stream or device to evaluate on +/// - Returns: the eigenvalues (not necessarily in order) +public func eigvals(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + return MLXLinalg.eigvals(array, stream: stream) +} + +/// Compute the eigenvalues and eigenvectors of a complex Hermitian or +/// real symmetric matrix. +/// +/// This function supports arrays with at least 2 dimensions. When the input +/// has more than two dimensions, the eigenvalues and eigenvectors are +/// computed for each matrix in the last two dimensions. +/// +/// > The input matrix is assumed to be symmetric (or Hermitian). Only +/// the selected triangle is used. No checks for symmetry are performed. +/// +/// - Parameters: +/// - array: input array. Must be a real symmetric or complex Hermitian matrix. +/// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. +/// Default is `"L"`. +/// - stream: stream or device to evaluate on +/// - Returns: a tuple containing the eigenvalues in ascending order and +/// the normalized eigenvectors. The column `v[0..., i]` is the eigenvector +/// corresponding to the i-th eigenvalue. +public func eigh( + _ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default +) -> (MLXArray, MLXArray) { + return MLXLinalg.eigh(array, UPLO: UPLO, stream: stream) +} + +/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix. +/// +/// This function supports arrays with at least 2 dimensions. When the +/// input has more than two dimensions, the eigenvalues are computed for +/// each matrix in the last two dimensions. +/// +/// > The input matrix is assumed to be symmetric (or Hermitian). Only +/// the selected triangle is used. No checks for symmetry are performed. +/// +/// - Parameters: +/// - array: input array. Must be a real symmetric or complex Hermitian matrix. +/// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. +/// Default is `"L"`. +/// - stream: stream or device to evaluate on +/// - Returns: the eigenvalues in ascending order +public func eigvalsh( + _ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default +) -> MLXArray { + return MLXLinalg.eigvalsh(array, UPLO: UPLO, stream: stream) +} + +/// Compute the (Moore-Penrose) pseudo-inverse of a matrix. +/// +/// This function calculates a generalized inverse of a matrix using its +/// singular-value decomposition. This function supports arrays with at least +/// 2 dimensions. When the input has more than two dimensions, the inverse is +/// computed for each matrix in the last two dimensions of `array`. +/// +/// - Parameters: +/// - array: input array +/// - stream: stream or device to evaluate on +/// - Returns: `aplus` such that `a @ aplus @ a = a` +public func pinv(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + return MLXLinalg.pinv(array, stream: stream) +} From 7efd7b8d96345d99cf683eb0a2f134577412195c Mon Sep 17 00:00:00 2001 From: David Koski Date: Fri, 3 Apr 2026 17:06:01 -0700 Subject: [PATCH 2/2] add remaining functions: - logcumsumexp - segmentedMm (segmented_mm) - permutation - fftshift - ifftshift --- .../Organization/arithmetic.md | 1 + .../Organization/cumulative.md | 4 + Source/MLX/FFT.swift | 76 +++++++++++++++++++ Source/MLX/IO.swift | 2 +- Source/MLX/MLXArray+Ops.swift | 36 +++++++++ Source/MLX/Ops+Array.swift | 40 ++++++++++ Source/MLX/Ops.swift | 21 +++++ Source/MLX/Random.swift | 38 ++++++++++ 8 files changed, 217 insertions(+), 1 deletion(-) diff --git a/Source/MLX/Documentation.docc/Organization/arithmetic.md b/Source/MLX/Documentation.docc/Organization/arithmetic.md index 37c37e0c..7023b95a 100644 --- a/Source/MLX/Documentation.docc/Organization/arithmetic.md +++ b/Source/MLX/Documentation.docc/Organization/arithmetic.md @@ -216,6 +216,7 @@ Note: the `-` and `/` operators are not able to be linked here. - ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` - ``gatherQuantizedMM(_:_:scales:biases:lhsIndices:rhsIndices:transpose:groupSize:bits:mode:sortedIndices:stream:)`` - ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)`` +- ``segmentedMM(_:_:segments:stream:)`` - ``inner(_:_:stream:)`` - ``outer(_:_:stream:)`` - ``tensordot(_:_:axes:stream:)-(MLXArray,MLXArray,Int,StreamOrDevice)`` diff --git a/Source/MLX/Documentation.docc/Organization/cumulative.md b/Source/MLX/Documentation.docc/Organization/cumulative.md index 8789726f..3d20ae6d 100644 --- a/Source/MLX/Documentation.docc/Organization/cumulative.md +++ b/Source/MLX/Documentation.docc/Organization/cumulative.md @@ -29,6 +29,8 @@ These are available as both methods on `MLXArray` and free functions. They each - ``MLXArray/cumprod(reverse:inclusive:stream:)`` - ``MLXArray/cumsum(axis:reverse:inclusive:stream:)`` - ``MLXArray/cumsum(reverse:inclusive:stream:)`` +- ``MLXArray/logCumsumExp(axis:reverse:inclusive:stream:)`` +- ``MLXArray/logCumsumExp(reverse:inclusive:stream:)`` ### Free Functions @@ -40,3 +42,5 @@ These are available as both methods on `MLXArray` and free functions. They each - ``cumprod(_:reverse:inclusive:stream:)`` - ``cumsum(_:axis:reverse:inclusive:stream:)`` - ``cumsum(_:reverse:inclusive:stream:)`` +- ``logCumsumExp(_:axis:reverse:inclusive:stream:)`` +- ``logCumsumExp(_:reverse:inclusive:stream:)`` diff --git a/Source/MLX/FFT.swift b/Source/MLX/FFT.swift index 40fdd699..988385b4 100644 --- a/Source/MLX/FFT.swift +++ b/Source/MLX/FFT.swift @@ -381,6 +381,47 @@ public enum MLXFFT { } } + /// Shift the zero-frequency component to the center of the spectrum. + /// + /// - Parameters: + /// - array: input array + /// - axes: axes over which to shift. If `nil`, all axes are shifted. + /// - stream: stream or device to evaluate on + /// - Returns: the shifted array + /// + /// ### See Also + /// - + public static func fftshift( + _ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let axes = axes ?? Array(0 ..< array.ndim) + mlx_fft_fftshift(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) + } + + /// The inverse of ``fftshift(_:axes:stream:)``. + /// + /// While identical to ``fftshift(_:axes:stream:)`` for even-length axes, + /// the behavior differs for odd-length axes. + /// + /// - Parameters: + /// - array: input array + /// - axes: axes over which to shift. If `nil`, all axes are shifted. + /// - stream: stream or device to evaluate on + /// - Returns: the shifted array + /// + /// ### See Also + /// - + public static func ifftshift( + _ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let axes = axes ?? Array(0 ..< array.ndim) + mlx_fft_ifftshift(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) + } + } // MLXFFT /// One dimensional discrete Fourier Transform. @@ -631,3 +672,38 @@ public func irfftn( ) -> MLXArray { MLXFFT.irfftn(array, s: s, axes: axes, stream: stream) } + +/// Shift the zero-frequency component to the center of the spectrum. +/// +/// - Parameters: +/// - array: input array +/// - axes: axes over which to shift. If `nil`, all axes are shifted. +/// - stream: stream or device to evaluate on +/// - Returns: the shifted array +/// +/// ### See Also +/// - +public func fftshift( + _ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default +) -> MLXArray { + MLXFFT.fftshift(array, axes: axes, stream: stream) +} + +/// The inverse of ``fftshift(_:axes:stream:)``. +/// +/// While identical to ``fftshift(_:axes:stream:)`` for even-length axes, +/// the behavior differs for odd-length axes. +/// +/// - Parameters: +/// - array: input array +/// - axes: axes over which to shift. If `nil`, all axes are shifted. +/// - stream: stream or device to evaluate on +/// - Returns: the shifted array +/// +/// ### See Also +/// - +public func ifftshift( + _ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default +) -> MLXArray { + MLXFFT.ifftshift(array, axes: axes, stream: stream) +} diff --git a/Source/MLX/IO.swift b/Source/MLX/IO.swift index 519096dc..aee73a56 100644 --- a/Source/MLX/IO.swift +++ b/Source/MLX/IO.swift @@ -287,7 +287,7 @@ public func saveToData( let writer = new_mlx_io_writer_dataIO() defer { mlx_io_writer_free(writer) } - _ = evalLock.withLock { + _ = try withError { _ = evalLock.withLock { mlx_save_safetensors_writer(writer, mlx_arrays, mlx_metadata) } diff --git a/Source/MLX/MLXArray+Ops.swift b/Source/MLX/MLXArray+Ops.swift index 4d950f3d..f940d297 100644 --- a/Source/MLX/MLXArray+Ops.swift +++ b/Source/MLX/MLXArray+Ops.swift @@ -1984,6 +1984,42 @@ extension MLXArray { return MLXArray(result) } + /// Return the cumulative logsumexp of the elements along the given axis. + /// + /// - Parameters: + /// - axis: axis to reduce over + /// - reverse: reverse the reduction + /// - inclusive: include the initial value + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + public func logCumsumExp( + axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + mlx_logcumsumexp(&result, ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) + } + + /// Return the cumulative logsumexp of the elements of the flattened array. + /// + /// - Parameters: + /// - reverse: reverse the reduction + /// - inclusive: include the initial value + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + public func logCumsumExp( + reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default + ) -> MLXArray { + let flat = self.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_logcumsumexp(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) + } + /// A `log-sum-exp` reduction over the given axes. /// /// The log-sum-exp reduction is a numerically stable version of: diff --git a/Source/MLX/Ops+Array.swift b/Source/MLX/Ops+Array.swift index b0b0ddbd..4b90ffbf 100644 --- a/Source/MLX/Ops+Array.swift +++ b/Source/MLX/Ops+Array.swift @@ -828,6 +828,46 @@ public func log1p(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr return MLXArray(result) } +/// Return the cumulative logsumexp of the elements along the given axis. +/// +/// - Parameters: +/// - array: input array +/// - axis: axis to reduce over +/// - reverse: reverse the reduction +/// - inclusive: include the initial value +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func logCumsumExp( + _ array: MLXArray, + axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default +) -> MLXArray { + var result = mlx_array_new() + mlx_logcumsumexp(&result, array.ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) +} + +/// Return the cumulative logsumexp of the elements of the flattened array. +/// +/// - Parameters: +/// - array: input array +/// - reverse: reverse the reduction +/// - inclusive: include the initial value +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func logCumsumExp( + _ array: MLXArray, + reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default +) -> MLXArray { + let flat = array.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_logcumsumexp(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) +} + /// A `log-sum-exp` reduction over the given axes. /// /// The log-sum-exp reduction is a numerically stable version of: diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index c640eb44..ec83c044 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -1488,6 +1488,27 @@ public func gatherQuantizedMM( return MLXArray(result) } +/// Perform a matrix multiplication but segment the inner dimension and +/// save the result for each segment separately. +/// +/// - Parameters: +/// - a: array of shape `MxK` +/// - b: array of shape `KxN` +/// - segments: offsets into the inner dimension for each segment +/// - stream: stream or device to evaluate on +/// - Returns: result per segment of shape `MxN` +/// ### See Also +/// - +public func segmentedMM( + _ a: MLXArray, _ b: MLXArray, + segments: MLXArray, + stream: StreamOrDevice = .default +) -> MLXArray { + var result = mlx_array_new() + mlx_segmented_mm(&result, a.ctx, b.ctx, segments.ctx, stream.ctx) + return MLXArray(result) +} + /// Element-wise greater than. /// /// Greater than on two arrays with . diff --git a/Source/MLX/Random.swift b/Source/MLX/Random.swift index e267500f..a89c5d2c 100644 --- a/Source/MLX/Random.swift +++ b/Source/MLX/Random.swift @@ -357,6 +357,44 @@ public enum MLXRandom { return MLXArray(result) } + /// Generate a random permutation of integers from `0` to `max`. + /// + /// - Parameters: + /// - max: max value to permute + /// - key: PRNG key + /// - stream: stream + /// - Returns: permuted array + public static func permutation( + _ max: Int, + key: (some RandomStateOrKey)? = MLXArray?.none, + stream: StreamOrDevice = .default + ) -> MLXArray { + let key = resolve(key: key) + var result = mlx_array_new() + mlx_random_permutation_arange(&result, max.int32, key.ctx, stream.ctx) + return MLXArray(result) + } + + /// Generate a random permutation of the entries of an array. + /// + /// - Parameters: + /// - array: array to permute + /// - axis: axis along which to permute + /// - key: PRNG key + /// - stream: stream + /// - Returns: permuted array + public static func permutation( + _ array: MLXArray, + axis: Int = 0, + key: (some RandomStateOrKey)? = MLXArray?.none, + stream: StreamOrDevice = .default + ) -> MLXArray { + let key = resolve(key: key) + var result = mlx_array_new() + mlx_random_permutation(&result, array.ctx, axis.int32, key.ctx, stream.ctx) + return MLXArray(result) + } + /// Generate random integers from the given interval using a `RangeExpression`. /// /// The values are sampled with equal probability from the integers in