Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Source/MLX/Documentation.docc/Organization/arithmetic.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ Note: the `-` and `/` operators are not able to be linked here.
- ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)``
- ``gatherQuantizedMM(_:_:scales:biases:lhsIndices:rhsIndices:transpose:groupSize:bits:mode:sortedIndices:stream:)``
- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)``
- ``segmentedMM(_:_:segments:stream:)``
- ``inner(_:_:stream:)``
- ``outer(_:_:stream:)``
- ``tensordot(_:_:axes:stream:)-(MLXArray,MLXArray,Int,StreamOrDevice)``
Expand Down
4 changes: 4 additions & 0 deletions Source/MLX/Documentation.docc/Organization/cumulative.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ These are available as both methods on `MLXArray` and free functions. They each
- ``MLXArray/cumprod(reverse:inclusive:stream:)``
- ``MLXArray/cumsum(axis:reverse:inclusive:stream:)``
- ``MLXArray/cumsum(reverse:inclusive:stream:)``
- ``MLXArray/logCumsumExp(axis:reverse:inclusive:stream:)``
- ``MLXArray/logCumsumExp(reverse:inclusive:stream:)``

### Free Functions

Expand All @@ -40,3 +42,5 @@ These are available as both methods on `MLXArray` and free functions. They each
- ``cumprod(_:reverse:inclusive:stream:)``
- ``cumsum(_:axis:reverse:inclusive:stream:)``
- ``cumsum(_:reverse:inclusive:stream:)``
- ``logCumsumExp(_:axis:reverse:inclusive:stream:)``
- ``logCumsumExp(_:reverse:inclusive:stream:)``
76 changes: 76 additions & 0 deletions Source/MLX/FFT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,47 @@ public enum MLXFFT {
}
}

/// Shift the zero-frequency component to the center of the spectrum.
///
/// - Parameters:
/// - array: input array
/// - axes: axes over which to shift. If `nil`, all axes are shifted.
/// - stream: stream or device to evaluate on
/// - Returns: the shifted array
///
/// ### See Also
/// - <doc:MLXFFT>
public static func fftshift(
_ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
let axes = axes ?? Array(0 ..< array.ndim)
mlx_fft_fftshift(&result, array.ctx, axes.asInt32, axes.count, stream.ctx)
return MLXArray(result)
}

/// The inverse of ``fftshift(_:axes:stream:)``.
///
/// While identical to ``fftshift(_:axes:stream:)`` for even-length axes,
/// the behavior differs for odd-length axes.
///
/// - Parameters:
/// - array: input array
/// - axes: axes over which to shift. If `nil`, all axes are shifted.
/// - stream: stream or device to evaluate on
/// - Returns: the shifted array
///
/// ### See Also
/// - <doc:MLXFFT>
public static func ifftshift(
_ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
let axes = axes ?? Array(0 ..< array.ndim)
mlx_fft_ifftshift(&result, array.ctx, axes.asInt32, axes.count, stream.ctx)
return MLXArray(result)
}

} // MLXFFT

/// One dimensional discrete Fourier Transform.
Expand Down Expand Up @@ -631,3 +672,38 @@ public func irfftn(
) -> MLXArray {
MLXFFT.irfftn(array, s: s, axes: axes, stream: stream)
}

/// Shift the zero-frequency component to the center of the spectrum.
///
/// - Parameters:
/// - array: input array
/// - axes: axes over which to shift. If `nil`, all axes are shifted.
/// - stream: stream or device to evaluate on
/// - Returns: the shifted array
///
/// ### See Also
/// - <doc:MLXFFT>
public func fftshift(
_ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.fftshift(array, axes: axes, stream: stream)
}

/// The inverse of ``fftshift(_:axes:stream:)``.
///
/// While identical to ``fftshift(_:axes:stream:)`` for even-length axes,
/// the behavior differs for odd-length axes.
///
/// - Parameters:
/// - array: input array
/// - axes: axes over which to shift. If `nil`, all axes are shifted.
/// - stream: stream or device to evaluate on
/// - Returns: the shifted array
///
/// ### See Also
/// - <doc:MLXFFT>
public func ifftshift(
_ array: MLXArray, axes: [Int]? = nil, stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.ifftshift(array, axes: axes, stream: stream)
}
2 changes: 1 addition & 1 deletion Source/MLX/IO.swift
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ public func saveToData(
let writer = new_mlx_io_writer_dataIO()
defer { mlx_io_writer_free(writer) }

_ = evalLock.withLock {
_ = try withError {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Did I miss that in the refactoring PR a couple weeks ago 🤔

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may have missed it, but I wrote it :-)

_ = evalLock.withLock {
mlx_save_safetensors_writer(writer, mlx_arrays, mlx_metadata)
}
Expand Down
206 changes: 206 additions & 0 deletions Source/MLX/Linalg.swift
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,116 @@ public enum MLXLinalg {
return MLXArray(result)
}

/// Compute the eigenvalues and eigenvectors of a square matrix.
///
/// This function differs from `numpy.linalg.eig` in that the
/// return type is always complex even if the eigenvalues are all real.
///
/// This function supports arrays with at least 2 dimensions. When the input
/// has more than two dimensions, the eigenvalues and eigenvectors are
/// computed for each matrix in the last two dimensions.
///
/// - Parameters:
/// - array: input array
/// - stream: stream or device to evaluate on
/// - Returns: a tuple containing the eigenvalues and the normalized right
/// eigenvectors. The column `v[0..., i]` is the eigenvector corresponding
/// to the i-th eigenvalue.
public static func eig(_ array: MLXArray, stream: StreamOrDevice = .default) -> (
MLXArray, MLXArray
) {
var r0 = mlx_array_new()
var r1 = mlx_array_new()
mlx_linalg_eig(&r0, &r1, array.ctx, stream.ctx)
return (MLXArray(r0), MLXArray(r1))
}

/// Compute the eigenvalues of a square matrix.
///
/// This function differs from `numpy.linalg.eigvals` in that the
/// return type is always complex even if the eigenvalues are all real.
///
/// This function supports arrays with at least 2 dimensions. When the
/// input has more than two dimensions, the eigenvalues are computed for
/// each matrix in the last two dimensions.
///
/// - Parameters:
/// - array: input array
/// - stream: stream or device to evaluate on
/// - Returns: the eigenvalues (not necessarily in order)
public static func eigvals(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray {
var result = mlx_array_new()
mlx_linalg_eigvals(&result, array.ctx, stream.ctx)
return MLXArray(result)
}

/// Compute the eigenvalues and eigenvectors of a complex Hermitian or
/// real symmetric matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input
/// has more than two dimensions, the eigenvalues and eigenvectors are
/// computed for each matrix in the last two dimensions.
///
/// > The input matrix is assumed to be symmetric (or Hermitian). Only
/// the selected triangle is used. No checks for symmetry are performed.
///
/// - Parameters:
/// - array: input array. Must be a real symmetric or complex Hermitian matrix.
/// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix.
/// Default is `"L"`.
/// - stream: stream or device to evaluate on
/// - Returns: a tuple containing the eigenvalues in ascending order and
/// the normalized eigenvectors. The column `v[0..., i]` is the eigenvector
/// corresponding to the i-th eigenvalue.
public static func eigh(
_ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default
) -> (MLXArray, MLXArray) {
var r0 = mlx_array_new()
var r1 = mlx_array_new()
mlx_linalg_eigh(&r0, &r1, array.ctx, UPLO, stream.ctx)
return (MLXArray(r0), MLXArray(r1))
}

/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
///
/// This function supports arrays with at least 2 dimensions. When the
/// input has more than two dimensions, the eigenvalues are computed for
/// each matrix in the last two dimensions.
///
/// > The input matrix is assumed to be symmetric (or Hermitian). Only
/// the selected triangle is used. No checks for symmetry are performed.
///
/// - Parameters:
/// - array: input array. Must be a real symmetric or complex Hermitian matrix.
/// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix.
/// Default is `"L"`.
/// - stream: stream or device to evaluate on
/// - Returns: the eigenvalues in ascending order
public static func eigvalsh(
_ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_linalg_eigvalsh(&result, array.ctx, UPLO, stream.ctx)
return MLXArray(result)
}

/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
///
/// This function calculates a generalized inverse of a matrix using its
/// singular-value decomposition. This function supports arrays with at least
/// 2 dimensions. When the input has more than two dimensions, the inverse is
/// computed for each matrix in the last two dimensions of `array`.
///
/// - Parameters:
/// - array: input array
/// - stream: stream or device to evaluate on
/// - Returns: `aplus` such that `a @ aplus @ a = a`
public static func pinv(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray {
var result = mlx_array_new()
mlx_linalg_pinv(&result, array.ctx, stream.ctx)
return MLXArray(result)
}

} // MLXLinalg

/// Matrix or vector norm.
Expand Down Expand Up @@ -716,3 +826,99 @@ public func solveTriangular(
{
return MLXLinalg.solveTriangular(a, b, upper: upper, stream: stream)
}

/// Compute the eigenvalues and eigenvectors of a square matrix.
///
/// This function differs from `numpy.linalg.eig` in that the
/// return type is always complex even if the eigenvalues are all real.
///
/// This function supports arrays with at least 2 dimensions. When the input
/// has more than two dimensions, the eigenvalues and eigenvectors are
/// computed for each matrix in the last two dimensions.
///
/// - Parameters:
/// - array: input array
/// - stream: stream or device to evaluate on
/// - Returns: a tuple containing the eigenvalues and the normalized right
/// eigenvectors. The column `v[0..., i]` is the eigenvector corresponding
/// to the i-th eigenvalue.
public func eig(_ array: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) {
return MLXLinalg.eig(array, stream: stream)
}

/// Compute the eigenvalues of a square matrix.
///
/// This function differs from `numpy.linalg.eigvals` in that the
/// return type is always complex even if the eigenvalues are all real.
///
/// This function supports arrays with at least 2 dimensions. When the
/// input has more than two dimensions, the eigenvalues are computed for
/// each matrix in the last two dimensions.
///
/// - Parameters:
/// - array: input array
/// - stream: stream or device to evaluate on
/// - Returns: the eigenvalues (not necessarily in order)
public func eigvals(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray {
return MLXLinalg.eigvals(array, stream: stream)
}

/// Compute the eigenvalues and eigenvectors of a complex Hermitian or
/// real symmetric matrix.
///
/// This function supports arrays with at least 2 dimensions. When the input
/// has more than two dimensions, the eigenvalues and eigenvectors are
/// computed for each matrix in the last two dimensions.
///
/// > The input matrix is assumed to be symmetric (or Hermitian). Only
/// the selected triangle is used. No checks for symmetry are performed.
///
/// - Parameters:
/// - array: input array. Must be a real symmetric or complex Hermitian matrix.
/// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix.
/// Default is `"L"`.
/// - stream: stream or device to evaluate on
/// - Returns: a tuple containing the eigenvalues in ascending order and
/// the normalized eigenvectors. The column `v[0..., i]` is the eigenvector
/// corresponding to the i-th eigenvalue.
public func eigh(
_ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default
) -> (MLXArray, MLXArray) {
return MLXLinalg.eigh(array, UPLO: UPLO, stream: stream)
}

/// Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
///
/// This function supports arrays with at least 2 dimensions. When the
/// input has more than two dimensions, the eigenvalues are computed for
/// each matrix in the last two dimensions.
///
/// > The input matrix is assumed to be symmetric (or Hermitian). Only
/// the selected triangle is used. No checks for symmetry are performed.
///
/// - Parameters:
/// - array: input array. Must be a real symmetric or complex Hermitian matrix.
/// - UPLO: whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix.
/// Default is `"L"`.
/// - stream: stream or device to evaluate on
/// - Returns: the eigenvalues in ascending order
public func eigvalsh(
_ array: MLXArray, UPLO: String = "L", stream: StreamOrDevice = .default
) -> MLXArray {
return MLXLinalg.eigvalsh(array, UPLO: UPLO, stream: stream)
}

/// Compute the (Moore-Penrose) pseudo-inverse of a matrix.
///
/// This function calculates a generalized inverse of a matrix using its
/// singular-value decomposition. This function supports arrays with at least
/// 2 dimensions. When the input has more than two dimensions, the inverse is
/// computed for each matrix in the last two dimensions of `array`.
///
/// - Parameters:
/// - array: input array
/// - stream: stream or device to evaluate on
/// - Returns: `aplus` such that `a @ aplus @ a = a`
public func pinv(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray {
return MLXLinalg.pinv(array, stream: stream)
}
36 changes: 36 additions & 0 deletions Source/MLX/MLXArray+Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1984,6 +1984,42 @@ extension MLXArray {
return MLXArray(result)
}

/// Return the cumulative logsumexp of the elements along the given axis.
///
/// - Parameters:
/// - axis: axis to reduce over
/// - reverse: reverse the reduction
/// - inclusive: include the initial value
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - <doc:cumulative>
public func logCumsumExp(
axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_logcumsumexp(&result, ctx, axis.int32, reverse, inclusive, stream.ctx)
return MLXArray(result)
}

/// Return the cumulative logsumexp of the elements of the flattened array.
///
/// - Parameters:
/// - reverse: reverse the reduction
/// - inclusive: include the initial value
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - <doc:cumulative>
public func logCumsumExp(
reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default
) -> MLXArray {
let flat = self.reshaped([-1], stream: stream)
var result = mlx_array_new()
mlx_logcumsumexp(&result, flat.ctx, 0, reverse, inclusive, stream.ctx)
return MLXArray(result)
}

/// A `log-sum-exp` reduction over the given axes.
///
/// The log-sum-exp reduction is a numerically stable version of:
Expand Down
Loading
Loading