Merged
12 changes: 12 additions & 0 deletions .github/workflows/pull_request.yml
@@ -69,6 +69,18 @@ jobs:
run: |
pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1)

docs:
needs: lint
if: github.repository == 'ml-explore/mlx-swift'
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
with:
submodules: recursive

- name: Verify documentation
run: scripts/verify-docs.sh

mac_build_and_test:
needs: lint
if: github.repository == 'ml-explore/mlx-swift'
6 changes: 3 additions & 3 deletions Source/MLX/Documentation.docc/Articles/converting-python.md
@@ -173,7 +173,7 @@ This is a mapping of `mx` free functions to their ``MLX`` counterparts.
`cummin` | ``MLX/cummin(_:axis:reverse:inclusive:stream:)``
`cumprod` | ``MLX/cumprod(_:axis:reverse:inclusive:stream:)``
`cumsum` | ``MLX/cumsum(_:axis:reverse:inclusive:stream:)``
-`dequantize` | ``MLX/dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)``
+`dequantize` | ``MLX/dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)``
`divide` | ``MLX/divide(_:_:stream:)``
`equal` | ``MLX/equal(_:_:stream:)``
`erf` | ``MLX/erf(_:stream:)``
@@ -215,8 +215,8 @@ This is a mapping of `mx` free functions to their ``MLX`` counterparts.
`partition` | ``MLX/partitioned(_:kth:axis:stream:)``
`power` | ``MLX/pow(_:_:stream:)-(MLXArray,MLXArray,_)``
`prod` | ``MLX/product(_:axes:keepDims:stream:)``
-`qqmm` | ``MLX/quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:stream:)``
-`quantize` | ``MLX/quantized(_:groupSize:bits:mode:stream:)``
+`qqmm` | ``MLX/quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)``
+`quantize` | ``MLX/quantized(_:groupSize:bits:mode:globalScale:stream:)``
`quantized_matmul` | ``MLX/quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)``
`reciprocal` | ``MLX/reciprocal(_:stream:)``
`remainder` | ``MLX/remainder(_:_:stream:)``
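The renamed quantization entry points above round-trip as sketched below. This is a minimal sketch, not the canonical usage: the `groupSize`/`bits` values, the tuple return of `quantized`, and the shapes are assumptions based on the existing MLX Swift API, and the new `globalScale`/`mode` parameters are left at their defaults.

```swift
import MLX

// Quantize a weight matrix to 4-bit values in groups of 64.
let w = MLXRandom.normal([512, 512])
let (wq, scales, biases) = quantized(w, groupSize: 64, bits: 4)

// Dequantize to recover an approximation of the original weights.
let wHat = dequantized(
    wq, scales: scales, biases: biases, groupSize: 64, bits: 4)

// The round trip is lossy; the error is bounded by the quantization step.
print(abs(w - wHat).max())
```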
9 changes: 2 additions & 7 deletions Source/MLX/Documentation.docc/Articles/vmap.md
@@ -1,6 +1,6 @@
# Vectorization

-Automatic vectorization with ``vmap(_:inAxes:outAxes:)``.
+Automatic vectorization with `vmap`.

`vmap` transforms a function so that it operates independently over a batch
axis. This is convenient for evaluating a function over many inputs without
@@ -41,11 +41,6 @@ Here `x` is mapped over its first axis while `y` is used as a broadcast value.

## Nested Mapping

-You can nest calls to ``vmap(_:inAxes:outAxes:)`` to map over multiple axes.
+You can nest calls to `vmap` to map over multiple axes.
Each nested `vmap` introduces another batch dimension in the result.

-## Topics
-
-### Functions
-
-- ``vmap(_:inAxes:outAxes:)``
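The batching behavior the vmap article describes can be sketched as follows. This assumes the single-array `vmap(_:inAxes:outAxes:)` overload with integer axis arguments; the function and shapes are illustrative, not from the PR.

```swift
import MLX

// A function defined on a single vector.
func norm2(_ x: MLXArray) -> MLXArray {
    sqrt(sum(x * x))
}

// vmap lifts it to operate independently over a batch axis (axis 0),
// producing one result per row without an explicit loop.
let batchedNorm2 = vmap(norm2, inAxes: 0, outAxes: 0)

let xs = MLXRandom.normal([8, 3])  // batch of 8 vectors
let ns = batchedNorm2(xs)          // shape [8]
```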
2 changes: 1 addition & 1 deletion Source/MLX/Documentation.docc/Organization/arithmetic.md
@@ -215,7 +215,7 @@ Note: the `-` and `/` operators are not able to be linked here.
- ``addMM(_:_:_:alpha:beta:stream:)``
- ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)``
- ``gatherQuantizedMM(_:_:scales:biases:lhsIndices:rhsIndices:transpose:groupSize:bits:mode:sortedIndices:stream:)``
-- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:stream:)``
+- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)``
- ``inner(_:_:stream:)``
- ``outer(_:_:stream:)``
- ``tensordot(_:_:axes:stream:)-(MLXArray,MLXArray,Int,StreamOrDevice)``
6 changes: 3 additions & 3 deletions Source/MLX/Documentation.docc/free-functions.md
@@ -194,10 +194,10 @@ operations as methods for convenience.

### Quantization

-- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:stream:)``
-- ``quantized(_:groupSize:bits:mode:stream:)``
+- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)``
+- ``quantized(_:groupSize:bits:mode:globalScale:stream:)``
- ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)``
-- ``dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)``
+- ``dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)``

### Evaluation and Transformation

2 changes: 2 additions & 0 deletions Source/MLX/ErrorHandler.swift
@@ -5,6 +5,8 @@ import Foundation
/// - Parameters:
/// - handler: An error handler. Pass nil to reset to the default error handler. Pass
/// ``fatalErrorHandler`` to make the error handler call `fatalError` for improved Xcode debugging.
/// - data: Optional pointer to user data passed to the handler callback.
/// - dtor: Optional destructor called on `data` when the handler is replaced or the process exits.
@available(*, deprecated, message: "please use withErrorHandler() or withError()")
public func setErrorHandler(
_ handler: (@convention(c) (UnsafePointer<CChar>?, UnsafeMutableRawPointer?) -> Void)?,
1 change: 0 additions & 1 deletion Source/MLX/Linalg.swift
@@ -9,7 +9,6 @@ public enum MLXLinalg {
///
/// ### See Also
/// - ``norm(_:ord:axes:keepDims:stream:)``
-/// - ``norm(_:ord:axes:keepDims:stream:)-8zljj``
/// - ``MLXLinalg``
public enum NormKind: String, Sendable {
/// Frobenius norm
2 changes: 1 addition & 1 deletion Source/MLX/MLXArray.swift
@@ -602,7 +602,7 @@ public final class MLXArray {
/// update values.
///
/// ### See Also
-/// - ``subscript(indices:stream:)``
+/// - ``ArrayAt``
/// - ``ArrayAtIndices``
public var at: ArrayAt { ArrayAt(array: self) }
}
11 changes: 7 additions & 4 deletions Source/MLX/Ops.swift
@@ -1141,7 +1141,7 @@ public enum QuantizationMode: String, Codable, Sendable {
/// - stream: Stream or device to evaluate on
///
/// ### See Also
-/// - ``quantized(_:groupSize:bits:mode:stream:)``
+/// - ``quantized(_:groupSize:bits:mode:globalScale:stream:)``
/// - ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)``
public func dequantized(
_ w: MLXArray,
@@ -2348,7 +2348,7 @@ public func putAlong(
/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html)
///
/// ### See Also
-/// - ``dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)``
+/// - ``dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)``
/// - ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)``
public func quantized(
_ w: MLXArray,
@@ -2407,8 +2407,8 @@ public func quantizedMatmul(
/// - stream: Stream or device to evaluate on
///
/// ### See Also
-/// - ``dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)``
-/// - ``quantized(_:groupSize:bits:mode:stream:)``
+/// - ``dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)``
+/// - ``quantized(_:groupSize:bits:mode:globalScale:stream:)``
public func quantizedMM(
_ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
transpose: Bool = true,
@@ -2644,6 +2644,7 @@ public func softMax(
/// - Parameters:
/// - array: input array
/// - axes: axes to compute the softmax over
/// - precise: if true, perform the computation in higher precision for improved numerical accuracy
/// - stream: stream or device to evaluate on
///
/// ### See Also
@@ -2678,6 +2679,7 @@ public func softMax(
/// - Parameters:
/// - array: input array
/// - axis: axis to compute the softmax over
/// - precise: if true, perform the computation in higher precision for improved numerical accuracy
/// - stream: stream or device to evaluate on
///
/// ### See Also
@@ -2710,6 +2712,7 @@ public func softMax(_ array: MLXArray, precise: Bool = false, stream: StreamOrDe
///
/// - Parameters:
/// - array: input array
/// - precise: if true, perform the computation in higher precision for improved numerical accuracy
/// - stream: stream or device to evaluate on
///
/// ### See Also
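The `precise` parameter documented in the softMax changes above can be exercised as sketched here. This is a hedged sketch: the `softMax(_:axis:precise:stream:)` overload and the float16 behavior are assumptions based on the signatures visible in this diff.

```swift
import MLX

// Logits in float16 can overflow `exp` during softmax.
let logits = MLXRandom.normal([4, 10]).asType(.float16)

// Standard softmax over the last axis.
let p = softMax(logits, axis: -1)

// `precise: true` performs the computation in higher precision,
// trading a little speed for numerical stability.
let pStable = softMax(logits, axis: -1, precise: true)
```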
2 changes: 2 additions & 0 deletions Source/MLX/Random.swift
@@ -766,6 +766,8 @@ public enum MLXRandom {
/// - dtype: type of the output
/// - loc: mean of the distribution
/// - scale: scale "b" of the distribution
/// - key: optional PRNG key
/// - stream: stream or device to evaluate on
public static func laplace(
_ shape: some Collection<Int> = [], dtype: DType, loc: Float = 0, scale: Float = 1,
key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default
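The newly documented `key` and `stream` parameters of `laplace` can be used as sketched below. The `MLXRandom.key(_:)` helper and the parameter defaults are assumptions drawn from the existing MLX Swift random API, not from this PR.

```swift
import MLX

// Reproducible draws from a Laplace(loc: 0, scale: 1) distribution.
let key = MLXRandom.key(42)
let samples = MLXRandom.laplace(
    [1024], dtype: .float32, loc: 0, scale: 1, key: key)

// Without a key, the global random state is advanced instead.
let more = MLXRandom.laplace([16], dtype: .float32)
```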
5 changes: 5 additions & 0 deletions Source/MLX/WiredMemory.swift
@@ -209,6 +209,11 @@ public struct WiredMemoryManagerConfiguration: Sendable, Hashable {
/// - Parameters:
/// - shrinkThresholdRatio: Minimum fractional drop to allow shrinking.
/// - shrinkCooldown: Minimum time between shrink attempts while active.
/// - policyOnlyWhenUnsupported: If true, policy admission and limit calculations still run
/// even when wired memory control is unsupported.
/// - baselineOverride: Optional baseline to use instead of the cached limit.
/// - useRecommendedWorkingSetWhenUnsupported: If true and wired memory is unsupported,
/// attempt to use Metal's recommended working set size as the baseline.
public init(
shrinkThresholdRatio: Double = 0.25,
shrinkCooldown: TimeInterval = 1.0,
12 changes: 6 additions & 6 deletions Source/MLXFFT/FFT.swift
@@ -204,9 +204,9 @@ public func rfft2(
///
/// - Parameters:
/// - array: input array
-/// - n: size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to
-/// match `n / 2 + 1`. If not specified `array.dim(axis) / 2 + 1` will be used.
-/// - axis: axis along which to perform the FFT
+/// - s: sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to
+/// match the sizes from `s`. If not specified the sizes of `axes` in the input will be used.
+/// - axes: axes along which to perform the FFT
/// - stream: stream or device to evaluate on
/// - Returns: inverse of ``rfft2(_:s:axes:stream:)``
///
@@ -253,9 +253,9 @@ public func rfftn(
///
/// - Parameters:
/// - array: input array
-/// - n: size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to
-/// match `n / 2 + 1`. If not specified `array.dim(axis) / 2 + 1` will be used.
-/// - axis: axis along which to perform the FFT
+/// - s: sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to
+/// match the sizes from `s`. If not specified the sizes of `axes` in the input will be used.
+/// - axes: axes along which to perform the FFT
/// - stream: stream or device to evaluate on
/// - Returns: inverse of ``rfftn(_:s:axes:stream:)``
///
2 changes: 1 addition & 1 deletion Source/MLXFast/MLXFastKernel.swift
@@ -4,7 +4,7 @@ import Cmlx
import MLX

/// Container for a kernel created by
-/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)``
+/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:)``
///
/// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs:
///
20 changes: 10 additions & 10 deletions Source/MLXLinalg/Linalg.swift
@@ -50,7 +50,7 @@ public let deprecationWarning: Void = ()
/// - Returns: output containing the norm(s)
///
/// ### See Also
-/// - ``norm(_:ord:axes:keepDims:stream:)``
+/// - `norm(_:ord:axes:keepDims:stream:)` (Double variant)
@available(*, deprecated, message: "norm is now available in the main MLX module")
@_disfavoredOverload
public func norm(
@@ -100,7 +100,7 @@ public func norm(
/// - Returns: output containing the norm(s)
///
/// ### See Also
-/// - ``norm(_:ord:axes:keepDims:stream:)-4dwwp``
+/// - `norm(_:ord:axes:keepDims:stream:)` (NormKind variant)
@available(*, deprecated, message: "norm is now available in the main MLX module")
@_disfavoredOverload
public func norm(
@@ -112,7 +112,7 @@ public func norm(

/// Matrix or vector norm.
///
-/// See ``norm(_:ord:axes:keepDims:stream:)-4dwwp``
+/// See `norm(_:ord:axes:keepDims:stream:)` (NormKind variant)
@available(*, deprecated, message: "norm is now available in the main MLX module")
@_disfavoredOverload
public func norm(
@@ -124,7 +124,7 @@ public func norm(

/// Matrix or vector norm.
///
-/// See ``norm(_:ord:axes:keepDims:stream:)``
+/// See `norm(_:ord:axes:keepDims:stream:)` (Double variant)
@available(*, deprecated, message: "norm is now available in the main MLX module")
@_disfavoredOverload
public func norm(
@@ -136,7 +136,7 @@ public func norm(

/// Matrix or vector norm.
///
-/// See ``norm(_:ord:axes:keepDims:stream:)-4dwwp``
+/// See `norm(_:ord:axes:keepDims:stream:)` (NormKind variant)
@available(*, deprecated, message: "norm is now available in the main MLX module")
@_disfavoredOverload
public func norm(
@@ -148,7 +148,7 @@ public func norm(

/// Matrix or vector norm.
///
-/// See ``norm(_:ord:axes:keepDims:stream:)``
+/// See `norm(_:ord:axes:keepDims:stream:)` (Double variant)
@available(*, deprecated, message: "norm is now available in the main MLX module")
@_disfavoredOverload
public func norm(
@@ -286,11 +286,11 @@ public func cross(_ a: MLXArray, _ b: MLXArray, axis: Int = -1, stream: StreamOr
return MLXLinalg.cross(a, b, axis: axis, stream: stream)
}

-/// Compute the LU factorization of the given matrix ``A``.
+/// Compute the LU factorization of the given matrix `A`.
///
-/// Note, unlike the default behavior of ``scipy.linalg.lu``, the pivots
-/// are indices. To reconstruct the input use ``L[P] @ U`` for 2
-/// dimensions or ``takeAlong(L, P[.ellipsis, .newAxis], axis: -2) @ U``
+/// Note, unlike the default behavior of `scipy.linalg.lu`, the pivots
+/// are indices. To reconstruct the input use `L[P] @ U` for 2
+/// dimensions or `takeAlong(L, P[.ellipsis, .newAxis], axis: -2) @ U`
/// for more than 2 dimensions.
///
/// To construct the full permutation matrix do:
6 changes: 3 additions & 3 deletions Source/MLXNN/ConvolutionTransposed.swift
@@ -8,7 +8,7 @@ import MLX
/// ### See Also
/// - ``ConvTransposed2d``
/// - ``ConvTransposed3d``
-/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:dilation:groups:bias:)``
+/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:outputPadding:dilation:groups:bias:)``
open class ConvTransposed1d: Module, UnaryLayer {

public let weight: MLXArray
@@ -83,7 +83,7 @@ open class ConvTransposed1d: Module, UnaryLayer {
/// ### See Also
/// - ``ConvTransposed1d``
/// - ``ConvTransposed3d``
-/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:dilation:groups:bias:)``
+/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:outputPadding:dilation:groups:bias:)``
open class ConvTransposed2d: Module, UnaryLayer {

public let weight: MLXArray
@@ -158,7 +158,7 @@ open class ConvTransposed2d: Module, UnaryLayer {
/// ### See Also
/// - ``ConvTransposed1d``
/// - ``ConvTransposed2d``
-/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:dilation:groups:bias:)``
+/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:outputPadding:dilation:groups:bias:)``
open class ConvTransposed3d: Module, UnaryLayer {

public let weight: MLXArray
4 changes: 2 additions & 2 deletions Source/MLXNN/Documentation.docc/Module.md
@@ -10,7 +10,7 @@
- ``Module/parameters()``
- ``Module/trainableParameters()``
- ``Module/update(parameters:)``
-- ``Module/update(parameters:verify:)``
+- ``Module/update(parameters:verify:path:modulePath:)``

### Layers (sub-modules)

@@ -20,7 +20,7 @@
- ``Module/modules()``
- ``Module/namedModules()``
- ``Module/update(modules:)``
-- ``Module/update(modules:verify:)``
+- ``Module/update(modules:verify:path:modulePath:)``
- ``Module/visit(modules:)``

### Traversal
6 changes: 3 additions & 3 deletions Source/MLXNN/Documentation.docc/custom-layers.md
@@ -86,7 +86,7 @@ these in swift.

MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
-performed by ``Module/update(parameters:verify:)``.
+performed by ``Module/update(parameters:verify:path:modulePath:)``.

See also <doc:training>.

@@ -140,7 +140,7 @@ The ``ModuleInfo`` and ``ParameterInfo`` provide two important features for modu
instance variables:

- both property wrappers allow replacement keys to be specified
-- the ``ModuleInfo`` allows ``Module/update(modules:verify:)`` to replace the module
+- the ``ModuleInfo`` allows ``Module/update(modules:verify:path:modulePath:)`` to replace the module

Replacement keys are important because many times models and weights are defined
in terms of their python implementation. For example
@@ -196,7 +196,7 @@ public class FeedForward : Module {
}
```

-The `ModuleInfo` provides a hook for ``QuantizedLinear`` and ``Module/update(modules:verify:)`` to
+The `ModuleInfo` provides a hook for ``QuantizedLinear`` and ``Module/update(modules:verify:path:modulePath:)`` to
replace the contents of `w1`, etc. with a new compatible `Model` after it is created.

Note that `MLXArray` is settable without any ``ParameterInfo`` -- it has an `update()` method.
2 changes: 1 addition & 1 deletion Source/MLXNN/Linear.swift
@@ -31,7 +31,7 @@ public class Identity: Module, UnaryLayer {
///
/// ## Using In A Module
///
-/// > Use `@ModuleInfo` with all your `Linear` module uses so that ``Module/update(modules:verify:)`` can
+/// > Use `@ModuleInfo` with all your `Linear` module uses so that ``Module/update(modules:verify:path:modulePath:)`` can
/// replace the modules, e.g. via ``QuantizedLinear/quantize(model:groupSize:bits:predicate:)``.
///
/// For example: