From adb4986858ddb839ef9a1347d712c76a9b0cd847 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Thu, 2 Apr 2026 10:17:01 +0200 Subject: [PATCH 1/3] Add doc comment verification script and CI step --- .github/workflows/pull_request.yml | 12 ++++++++++++ scripts/verify-docs.sh | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100755 scripts/verify-docs.sh diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 82795c6f..c00c2c19 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -69,6 +69,18 @@ jobs: run: | pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1) + docs: + needs: lint + if: github.repository == 'ml-explore/mlx-swift' + runs-on: [self-hosted, macos] + steps: + - uses: actions/checkout@v6 + with: + submodules: recursive + + - name: Verify documentation + run: scripts/verify-docs.sh + mac_build_and_test: needs: lint if: github.repository == 'ml-explore/mlx-swift' diff --git a/scripts/verify-docs.sh b/scripts/verify-docs.sh new file mode 100755 index 00000000..74494904 --- /dev/null +++ b/scripts/verify-docs.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Verify documentation builds without warnings + +set -e +cd "$(dirname "$0")/.." + +export MLX_SWIFT_BUILD_DOC=1 + +TARGETS=("MLX" "MLXRandom" "MLXNN" "MLXOptimizers" "MLXFFT" "MLXLinalg" "MLXFast") +FAILED=0 + +for TARGET in "${TARGETS[@]}"; do + echo "Building documentation for $TARGET..." + if ! swift package generate-documentation --target "$TARGET" --warnings-as-errors; then + FAILED=1 + fi + echo "" +done + +if [ "$FAILED" -ne 0 ]; then + echo "Documentation build failed with warnings." + exit 1 +fi + +echo "All documentation builds passed." From 1548f810b7c8fee001fc11727c5639f4d96b170e Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Thu, 2 Apr 2026 10:39:31 +0200 Subject: [PATCH 2/3] Fix doc comments --- .../Articles/converting-python.md | 6 +++--- .../MLX/Documentation.docc/Articles/vmap.md | 9 ++------- .../Organization/arithmetic.md | 2 +- .../MLX/Documentation.docc/free-functions.md | 6 +++--- Source/MLX/ErrorHandler.swift | 2 ++ Source/MLX/Linalg.swift | 1 - Source/MLX/MLXArray.swift | 2 +- Source/MLX/Ops.swift | 11 ++++++---- Source/MLX/Random.swift | 2 ++ Source/MLX/WiredMemory.swift | 5 +++++ Source/MLXFFT/FFT.swift | 12 +++++------ Source/MLXFast/MLXFastKernel.swift | 2 +- Source/MLXLinalg/Linalg.swift | 20 +++++++++---------- Source/MLXNN/ConvolutionTransposed.swift | 6 +++--- Source/MLXNN/Documentation.docc/Module.md | 4 ++-- .../MLXNN/Documentation.docc/custom-layers.md | 6 +++--- Source/MLXNN/Linear.swift | 2 +- Source/MLXNN/Module.swift | 20 +++++++++++-------- Source/MLXRandom/Random.swift | 13 ++++++++++++ 19 files changed, 77 insertions(+), 54 deletions(-) diff --git a/Source/MLX/Documentation.docc/Articles/converting-python.md b/Source/MLX/Documentation.docc/Articles/converting-python.md index 1b5e6a33..fa82f309 100644 --- a/Source/MLX/Documentation.docc/Articles/converting-python.md +++ b/Source/MLX/Documentation.docc/Articles/converting-python.md @@ -173,7 +173,7 @@ This is a mapping of `mx` free functions to their ``MLX`` counterparts. `cummin` | ``MLX/cummin(_:axis:reverse:inclusive:stream:)`` `cumprod` | ``MLX/cumprod(_:axis:reverse:inclusive:stream:)`` `cumsum` | ``MLX/cumsum(_:axis:reverse:inclusive:stream:)`` -`dequantize` | ``MLX/dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)`` +`dequantize` | ``MLX/dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)`` `divide` | ``MLX/divide(_:_:stream:)`` `equal` | ``MLX/equal(_:_:stream:)`` `erf` | ``MLX/erf(_:stream:)`` @@ -215,8 +215,8 @@ This is a mapping of `mx` free functions to their ``MLX`` counterparts. `partition` | ``MLX/partitioned(_:kth:axis:stream:)`` `power` | ``MLX/pow(_:_:stream:)-(MLXArray,MLXArray,_)`` `prod` | ``MLX/product(_:axes:keepDims:stream:)`` -`qqmm` | ``MLX/quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:stream:)`` -`quantize` | ``MLX/quantized(_:groupSize:bits:mode:stream:)`` +`qqmm` | ``MLX/quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)`` +`quantize` | ``MLX/quantized(_:groupSize:bits:mode:globalScale:stream:)`` `quantized_matmul` | ``MLX/quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` `reciprocal` | ``MLX/reciprocal(_:stream:)`` `remainder` | ``MLX/remainder(_:_:stream:)`` diff --git a/Source/MLX/Documentation.docc/Articles/vmap.md b/Source/MLX/Documentation.docc/Articles/vmap.md index 6d9e1579..719e631f 100644 --- a/Source/MLX/Documentation.docc/Articles/vmap.md +++ b/Source/MLX/Documentation.docc/Articles/vmap.md @@ -1,6 +1,6 @@ # Vectorization -Automatic vectorization with ``vmap(_:inAxes:outAxes:)``. +Automatic vectorization with `vmap`. `vmap` transforms a function so that it operates independently over a batch axis. This is convenient for evaluating a function over many inputs without @@ -41,11 +41,6 @@ Here `x` is mapped over its first axis while `y` is used as a broadcast value. ## Nested Mapping -You can nest calls to ``vmap(_:inAxes:outAxes:)`` to map over multiple axes. +You can nest calls to `vmap` to map over multiple axes. Each nested `vmap` introduces another batch dimension in the result. -## Topics - -### Functions - -- ``vmap(_:inAxes:outAxes:)`` diff --git a/Source/MLX/Documentation.docc/Organization/arithmetic.md b/Source/MLX/Documentation.docc/Organization/arithmetic.md index 7a434ef3..37c37e0c 100644 --- a/Source/MLX/Documentation.docc/Organization/arithmetic.md +++ b/Source/MLX/Documentation.docc/Organization/arithmetic.md @@ -215,7 +215,7 @@ Note: the `-` and `/` operators are not able to be linked here. - ``addMM(_:_:_:alpha:beta:stream:)`` - ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` - ``gatherQuantizedMM(_:_:scales:biases:lhsIndices:rhsIndices:transpose:groupSize:bits:mode:sortedIndices:stream:)`` -- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:stream:)`` +- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)`` - ``inner(_:_:stream:)`` - ``outer(_:_:stream:)`` - ``tensordot(_:_:axes:stream:)-(MLXArray,MLXArray,Int,StreamOrDevice)`` diff --git a/Source/MLX/Documentation.docc/free-functions.md b/Source/MLX/Documentation.docc/free-functions.md index 98d1ab22..e392e428 100644 --- a/Source/MLX/Documentation.docc/free-functions.md +++ b/Source/MLX/Documentation.docc/free-functions.md @@ -194,10 +194,10 @@ operations as methods for convenience. ### Quantization -- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:stream:)`` -- ``quantized(_:groupSize:bits:mode:stream:)`` +- ``quantizedQuantizedMM(_:_:scales:groupSize:bits:mode:globalScaleX:globalScaleW:stream:)`` +- ``quantized(_:groupSize:bits:mode:globalScale:stream:)`` - ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` -- ``dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)`` +- ``dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)`` ### Evaluation and Transformation diff --git a/Source/MLX/ErrorHandler.swift b/Source/MLX/ErrorHandler.swift index 98fa689b..4d0daab8 100644 --- a/Source/MLX/ErrorHandler.swift +++ b/Source/MLX/ErrorHandler.swift @@ -5,6 +5,8 @@ import Foundation /// - Parameters: /// - handler: An error handler. Pass nil to reset to the default error handler. Pass /// ``fatalErrorHandler`` to make the error handler call `fatalError` for improved Xcode debugging. +/// - data: Optional pointer to user data passed to the handler callback. +/// - dtor: Optional destructor called on `data` when the handler is replaced or the process exits. @available(*, deprecated, message: "please use withErrorHandler() or withError()") public func setErrorHandler( _ handler: (@convention(c) (UnsafePointer?, UnsafeMutableRawPointer?) -> Void)?, diff --git a/Source/MLX/Linalg.swift b/Source/MLX/Linalg.swift index 8ef07557..c2e2813c 100644 --- a/Source/MLX/Linalg.swift +++ b/Source/MLX/Linalg.swift @@ -9,7 +9,6 @@ public enum MLXLinalg { /// /// ### See Also /// - ``norm(_:ord:axes:keepDims:stream:)`` - /// - ``norm(_:ord:axes:keepDims:stream:)-8zljj`` /// - ``MLXLinalg`` public enum NormKind: String, Sendable { /// Frobenius norm diff --git a/Source/MLX/MLXArray.swift b/Source/MLX/MLXArray.swift index 121955b4..44014aad 100644 --- a/Source/MLX/MLXArray.swift +++ b/Source/MLX/MLXArray.swift @@ -602,7 +602,7 @@ public final class MLXArray { /// update values. /// /// ### See Also - /// - ``subscript(indices:stream:)`` + /// - ``ArrayAt`` /// - ``ArrayAtIndices`` public var at: ArrayAt { ArrayAt(array: self) } } diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index 829da4ab..c640eb44 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -1141,7 +1141,7 @@ public enum QuantizationMode: String, Codable, Sendable { /// - stream: Stream or device to evaluate on /// /// ### See Also -/// - ``quantized(_:groupSize:bits:mode:stream:)`` +/// - ``quantized(_:groupSize:bits:mode:globalScale:stream:)`` /// - ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` public func dequantized( _ w: MLXArray, @@ -2348,7 +2348,7 @@ public func putAlong( /// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html) /// /// ### See Also -/// - ``dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)`` +/// - ``dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)`` /// - ``quantizedMM(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` public func quantized( _ w: MLXArray, @@ -2407,8 +2407,8 @@ public func quantizedMatmul( /// - stream: Stream or device to evaluate on /// /// ### See Also -/// - ``dequantized(_:scales:biases:groupSize:bits:mode:dtype:stream:)`` -/// - ``quantized(_:groupSize:bits:mode:stream:)`` +/// - ``dequantized(_:scales:biases:groupSize:bits:mode:globalScale:dtype:stream:)`` +/// - ``quantized(_:groupSize:bits:mode:globalScale:stream:)`` public func quantizedMM( _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?, transpose: Bool = true, @@ -2644,6 +2644,7 @@ public func softMax( /// - Parameters: /// - array: input array /// - axes: axes to compute the softmax over +/// - precise: if true, compute a more precise softmax by scaling the input /// - stream: stream or device to evaluate on /// /// ### See Also @@ -2678,6 +2679,7 @@ public func softMax( /// - Parameters: /// - array: input array /// - axis: axis to compute the softmax over +/// - precise: if true, compute a more precise softmax by scaling the input /// - stream: stream or device to evaluate on /// /// ### See Also @@ -2710,6 +2712,7 @@ public func softMax(_ array: MLXArray, precise: Bool = false, stream: StreamOrDe /// /// - Parameters: /// - array: input array +/// - precise: if true, compute a more precise softmax by scaling the input /// - stream: stream or device to evaluate on /// /// ### See Also diff --git a/Source/MLX/Random.swift b/Source/MLX/Random.swift index 1f9b32a5..e267500f 100644 --- a/Source/MLX/Random.swift +++ b/Source/MLX/Random.swift @@ -766,6 +766,8 @@ public enum MLXRandom { /// - dtype: type of the output /// - loc: mean of the distribution /// - scale: scale "b" of the distribution + /// - key: optional PRNG key + /// - stream: stream or device to evaluate on public static func laplace( _ shape: some Collection = [], dtype: DType, loc: Float = 0, scale: Float = 1, key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default diff --git a/Source/MLX/WiredMemory.swift b/Source/MLX/WiredMemory.swift index 29f5081f..fff688e3 100644 --- a/Source/MLX/WiredMemory.swift +++ b/Source/MLX/WiredMemory.swift @@ -209,6 +209,11 @@ public struct WiredMemoryManagerConfiguration: Sendable, Hashable { /// - Parameters: /// - shrinkThresholdRatio: Minimum fractional drop to allow shrinking. /// - shrinkCooldown: Minimum time between shrink attempts while active. + /// - policyOnlyWhenUnsupported: If true, policy admission and limit calculations still run + /// even when wired memory control is unsupported. + /// - baselineOverride: Optional baseline to use instead of the cached limit. + /// - useRecommendedWorkingSetWhenUnsupported: If true and wired memory is unsupported, + /// attempt to use Metal's recommended working set size as the baseline. public init( shrinkThresholdRatio: Double = 0.25, shrinkCooldown: TimeInterval = 1.0, diff --git a/Source/MLXFFT/FFT.swift b/Source/MLXFFT/FFT.swift index f91409cc..56bf2f3b 100644 --- a/Source/MLXFFT/FFT.swift +++ b/Source/MLXFFT/FFT.swift @@ -204,9 +204,9 @@ public func rfft2( /// /// - Parameters: /// - array: input array -/// - n: size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to -/// match `n / 2 + 1`. If not specified `array.dim(axis) / 2 + 1` will be used. -/// - axis: axis along which to perform the FFT +/// - s: sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to +/// match the sizes from `s`. If not specified the sizes of `axes` in the input will be used. +/// - axes: axes along which to perform the FFT /// - stream: stream or device to evaluate on /// - Returns: inverse of ``rfft2(_:s:axes:stream:)`` /// @@ -253,9 +253,9 @@ public func rfftn( /// /// - Parameters: /// - array: input array -/// - n: size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to -/// match `n / 2 + 1`. If not specified `array.dim(axis) / 2 + 1` will be used. -/// - axis: axis along which to perform the FFT +/// - s: sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to +/// match the sizes from `s`. If not specified the sizes of `axes` in the input will be used. +/// - axes: axes along which to perform the FFT /// - stream: stream or device to evaluate on /// - Returns: inverse of ``rfftn(_:s:axes:stream:)`` /// diff --git a/Source/MLXFast/MLXFastKernel.swift b/Source/MLXFast/MLXFastKernel.swift index cb95bf10..635d6396 100644 --- a/Source/MLXFast/MLXFastKernel.swift +++ b/Source/MLXFast/MLXFastKernel.swift @@ -4,7 +4,7 @@ import Cmlx import MLX /// Container for a kernel created by -/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)`` +/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:)`` /// /// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs: /// diff --git a/Source/MLXLinalg/Linalg.swift b/Source/MLXLinalg/Linalg.swift index f6d63df0..7a287bf6 100644 --- a/Source/MLXLinalg/Linalg.swift +++ b/Source/MLXLinalg/Linalg.swift @@ -50,7 +50,7 @@ public let deprecationWarning: Void = () /// - Returns: output containing the norm(s) /// /// ### See Also -/// - ``norm(_:ord:axes:keepDims:stream:)`` +/// - `norm(_:ord:axes:keepDims:stream:)` (Double variant) @available(*, deprecated, message: "norm is now available in the main MLX module") @_disfavoredOverload public func norm( @@ -100,7 +100,7 @@ public func norm( /// - Returns: output containing the norm(s) /// /// ### See Also -/// - ``norm(_:ord:axes:keepDims:stream:)-4dwwp`` +/// - `norm(_:ord:axes:keepDims:stream:)` (NormKind variant) @available(*, deprecated, message: "norm is now available in the main MLX module") @_disfavoredOverload public func norm( @@ -112,7 +112,7 @@ public func norm( /// Matrix or vector norm. /// -/// See ``norm(_:ord:axes:keepDims:stream:)-4dwwp`` +/// See `norm(_:ord:axes:keepDims:stream:)` (NormKind variant) @available(*, deprecated, message: "norm is now available in the main MLX module") @_disfavoredOverload public func norm( @@ -124,7 +124,7 @@ public func norm( /// Matrix or vector norm. /// -/// See ``norm(_:ord:axes:keepDims:stream:)`` +/// See `norm(_:ord:axes:keepDims:stream:)` (Double variant) @available(*, deprecated, message: "norm is now available in the main MLX module") @_disfavoredOverload public func norm( @@ -136,7 +136,7 @@ public func norm( /// Matrix or vector norm. /// -/// See ``norm(_:ord:axes:keepDims:stream:)-4dwwp`` +/// See `norm(_:ord:axes:keepDims:stream:)` (NormKind variant) @available(*, deprecated, message: "norm is now available in the main MLX module") @_disfavoredOverload public func norm( @@ -148,7 +148,7 @@ public func norm( /// Matrix or vector norm. /// -/// See ``norm(_:ord:axes:keepDims:stream:)`` +/// See `norm(_:ord:axes:keepDims:stream:)` (Double variant) @available(*, deprecated, message: "norm is now available in the main MLX module") @_disfavoredOverload public func norm( @@ -286,11 +286,11 @@ public func cross(_ a: MLXArray, _ b: MLXArray, axis: Int = -1, stream: StreamOr return MLXLinalg.cross(a, b, axis: axis, stream: stream) } -/// Compute the LU factorization of the given matrix ``A``. +/// Compute the LU factorization of the given matrix `A`. /// -/// Note, unlike the default behavior of ``scipy.linalg.lu``, the pivots -/// are indices. To reconstruct the input use ``L[P] @ U`` for 2 -/// dimensions or ``takeAlong(L, P[.ellipsis, .newAxis], axis: -2) @ U`` +/// Note, unlike the default behavior of `scipy.linalg.lu`, the pivots +/// are indices. To reconstruct the input use `L[P] @ U` for 2 +/// dimensions or `takeAlong(L, P[.ellipsis, .newAxis], axis: -2) @ U` /// for more than 2 dimensions. /// /// To construct the full permuation matrix do: diff --git a/Source/MLXNN/ConvolutionTransposed.swift b/Source/MLXNN/ConvolutionTransposed.swift index 389d34e3..dbba317f 100644 --- a/Source/MLXNN/ConvolutionTransposed.swift +++ b/Source/MLXNN/ConvolutionTransposed.swift @@ -8,7 +8,7 @@ import MLX /// ### See Also /// - ``ConvTransposed2d`` /// - ``ConvTransposed3d`` -/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:dilation:groups:bias:)`` +/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:outputPadding:dilation:groups:bias:)`` open class ConvTransposed1d: Module, UnaryLayer { public let weight: MLXArray @@ -83,7 +83,7 @@ open class ConvTransposed1d: Module, UnaryLayer { /// ### See Also /// - ``ConvTransposed1d`` /// - ``ConvTransposed3d`` -/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:dilation:groups:bias:)`` +/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:outputPadding:dilation:groups:bias:)`` open class ConvTransposed2d: Module, UnaryLayer { public let weight: MLXArray @@ -158,7 +158,7 @@ open class ConvTransposed2d: Module, UnaryLayer { /// ### See Also /// - ``ConvTransposed1d`` /// - ``ConvTransposed2d`` -/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:dilation:groups:bias:)`` +/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:outputPadding:dilation:groups:bias:)`` open class ConvTransposed3d: Module, UnaryLayer { public let weight: MLXArray diff --git a/Source/MLXNN/Documentation.docc/Module.md b/Source/MLXNN/Documentation.docc/Module.md index 850b8cf5..ae940c4e 100644 --- a/Source/MLXNN/Documentation.docc/Module.md +++ b/Source/MLXNN/Documentation.docc/Module.md @@ -10,7 +10,7 @@ - ``Module/parameters()`` - ``Module/trainableParameters()`` - ``Module/update(parameters:)`` -- ``Module/update(parameters:verify:)`` +- ``Module/update(parameters:verify:path:modulePath:)`` ### Layers (sub-modules) @@ -20,7 +20,7 @@ - ``Module/modules()`` - ``Module/namedModules()`` - ``Module/update(modules:)`` -- ``Module/update(modules:verify:)`` +- ``Module/update(modules:verify:path:modulePath:)`` - ``Module/visit(modules:)`` ### Traversal diff --git a/Source/MLXNN/Documentation.docc/custom-layers.md b/Source/MLXNN/Documentation.docc/custom-layers.md index 0b25662b..3de3cb23 100644 --- a/Source/MLXNN/Documentation.docc/custom-layers.md +++ b/Source/MLXNN/Documentation.docc/custom-layers.md @@ -86,7 +86,7 @@ these in swift. MLX modules allow accessing and updating individual parameters. However, most times we need to update large subsets of a module's parameters. This action is -performed by ``Module/update(parameters:verify:)``. +performed by ``Module/update(parameters:verify:path:modulePath:)``. See also . @@ -140,7 +140,7 @@ The ``ModuleInfo`` and ``ParameterInfo`` provide two important features for modu instance variables: - both property wrappers allow replacement keys to be specified -- the ``ModuleInfo`` allows ``Module/update(modules:verify:)`` to replace the module +- the ``ModuleInfo`` allows ``Module/update(modules:verify:path:modulePath:)`` to replace the module Replacement keys are important because many times models and weights are defined in terms of their python implementation. For example @@ -196,7 +196,7 @@ public class FeedForward : Module { } ``` -The `ModuleInfo` provides a hook for ``QuantizedLinear`` and ``Module/update(modules:verify:)`` to +The `ModuleInfo` provides a hook for ``QuantizedLinear`` and ``Module/update(modules:verify:path:modulePath:)`` to replace the contents of `w1`, etc. with a new compatible `Model` after it is created. Note that `MLXArray` is settable without any ``ParameterInfo`` -- it has an `update()` method. diff --git a/Source/MLXNN/Linear.swift b/Source/MLXNN/Linear.swift index 29bce458..dd52c422 100644 --- a/Source/MLXNN/Linear.swift +++ b/Source/MLXNN/Linear.swift @@ -31,7 +31,7 @@ public class Identity: Module, UnaryLayer { /// /// ## Using In A Module /// -/// > Use `@ModuleInfo` with all your `Linear` module uses so that ``Module/update(modules:verify:)`` can +/// > Use `@ModuleInfo` with all your `Linear` module uses so that ``Module/update(modules:verify:path:modulePath:)`` can /// replace the modules, e.g. via ``QuantizedLinear/quantize(model:groupSize:bits:predicate:)``. /// /// For example: diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift index 75e027bd..7cba9b12 100644 --- a/Source/MLXNN/Module.swift +++ b/Source/MLXNN/Module.swift @@ -79,7 +79,7 @@ public typealias ModuleItem = NestedItem /// /// > Please read for more information about implementing custom layers /// including how to override the module and parameter keys and allowing dynamic updates of the -/// module structure to occur via ``update(modules:verify:)``. +/// module structure to occur via ``update(modules:verify:path:modulePath:)``. /// /// ### Training /// @@ -379,7 +379,7 @@ open class Module { isLeaf: Self.isLeafModuleNoChildren) } - /// Options for verifying ``update(parameters:verify:)`` and ``update(modules:verify:)``. + /// Options for verifying ``update(parameters:verify:path:modulePath:)`` and ``update(modules:verify:path:modulePath:)``. public struct VerifyUpdate: OptionSet, Sendable { public init(rawValue: Int) { self.rawValue = rawValue @@ -398,7 +398,7 @@ open class Module { static public let none = VerifyUpdate([]) } - /// A non-throwing version of ``update(parameters:verify:)``. + /// A non-throwing version of ``update(parameters:verify:path:modulePath:)``. /// /// This passes `verify: .none`. Note that there may still be `fatalErrors()` if /// for example an `MLXArray` is set on a `Module`. @@ -434,6 +434,8 @@ open class Module { /// - parameters: replacement parameters in the same format that ``parameters()`` /// or ``mapParameters(map:isLeaf:)`` provides /// - verify: options for verifying parameters + /// - path: the key path used for error reporting during recursive updates + /// - modulePath: the module type path used for error reporting during recursive updates /// /// ### See Also /// - @@ -441,7 +443,7 @@ open class Module { /// - ``apply(filter:map:)`` /// - ``parameters()`` /// - ``mapParameters(map:isLeaf:)`` - /// - ``update(modules:verify:)`` + /// - ``update(modules:verify:path:modulePath:)`` @discardableResult open func update( parameters: ModuleParameters, verify: VerifyUpdate, path: [String] = [], @@ -587,7 +589,7 @@ open class Module { update(parameters: filterMap(filter: filter, map: Self.mapParameters(map: map))) } - /// A non-throwing version of ``update(modules:verify:)``. + /// A non-throwing version of ``update(modules:verify:path:modulePath:)``. /// /// This passes `verify: .none`. Note that there may still be `fatalErrors()` if /// for example an `Module` is set on a `MLXArray`. @@ -630,11 +632,13 @@ open class Module { /// - Parameters: /// - modules: replacement modules in the same format as ``children()`` or ``leafModules()`` /// - verify: options for verifying parameters + /// - path: the key path used for error reporting during recursive updates + /// - modulePath: the module type path used for error reporting during recursive updates /// /// ### See Also /// - /// - ``update(modules:)`` - /// - ``update(parameters:verify:)`` + /// - ``update(parameters:verify:path:modulePath:)`` /// - ``children()`` /// - ``leafModules()`` /// - ``QuantizedLinear/quantize(model:groupSize:bits:predicate:)`` @@ -1454,7 +1458,7 @@ private protocol TypeErasedSetterProvider { } /// ModuleInfo can provde information about child modules and act as an -/// update point for ``Module/update(modules:verify:)``. +/// update point for ``Module/update(modules:verify:path:modulePath:)``. /// /// The keys for modules and parameters are usually named after their instance variables, /// but `feed_forward` would not be a very Swifty variable name: @@ -1503,7 +1507,7 @@ private protocol TypeErasedSetterProvider { /// } /// ``` /// -/// The `ModuleInfo` provides a hook for ``QuantizedLinear`` and ``Module/update(modules:verify:)`` /// to +/// The `ModuleInfo` provides a hook for ``QuantizedLinear`` and ``Module/update(modules:verify:path:modulePath:)`` to /// replace the contents of `w1`, etc. with a new compatible `Model` after it is created. /// /// ### See Also diff --git a/Source/MLXRandom/Random.swift b/Source/MLXRandom/Random.swift index da4b6a38..26eac50e 100644 --- a/Source/MLXRandom/Random.swift +++ b/Source/MLXRandom/Random.swift @@ -158,6 +158,7 @@ public func uniform( /// - loc: mean of the distribution /// - scale: standard deviation of the distribution /// - key: PRNG key +/// - stream: stream or device to evaluate on @available(*, deprecated, message: "normal is now available in the main MLX module") @_disfavoredOverload public func normal( @@ -189,6 +190,7 @@ public func normal( /// - loc: mean of the distribution /// - scale: standard deviation of the distribution /// - key: PRNG key +/// - stream: stream or device to evaluate on @available(*, deprecated, message: "normal is now available in the main MLX module") @_disfavoredOverload public func normal( @@ -215,6 +217,7 @@ public func normal( /// shapes of `mean` and `covariance`. /// - dtype: DType of the result /// - key: PRNG key +/// - stream: stream or device to evaluate on @available(*, deprecated, message: "multivariateNormal is now available in the main MLX module") @_disfavoredOverload public func multivariateNormal( @@ -496,6 +499,10 @@ public func gumbel( /// /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). +/// - axis: the axis which specifies the distribution +/// - shape: the shape of the output +/// - key: PRNG key +/// - stream: stream or device to evaluate on @available(*, deprecated, message: "categorical is now available in the main MLX module") @_disfavoredOverload public func categorical( @@ -521,6 +528,10 @@ public func categorical( /// /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). +/// - axis: the axis which specifies the distribution +/// - count: number of samples to draw +/// - key: PRNG key +/// - stream: stream or device to evaluate on @available(*, deprecated, message: "categorical is now available in the main MLX module") @_disfavoredOverload public func categorical( @@ -537,6 +548,8 @@ public func categorical( /// - dtype: type of the output /// - loc: mean of the distribution /// - scale: scale "b" of the distribution +/// - key: PRNG key +/// - stream: stream or device to evaluate on @available(*, deprecated, message: "laplace is now available in the main MLX module") @_disfavoredOverload public func laplace( From 6a520c110c929688c5d0bf901d70563e801371a5 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Fri, 3 Apr 2026 18:39:15 +0200 Subject: [PATCH 3/3] Discover doc verification targets dynamically and report all failures --- scripts/verify-docs.sh | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/scripts/verify-docs.sh b/scripts/verify-docs.sh index 74494904..e67666b9 100755 --- a/scripts/verify-docs.sh +++ b/scripts/verify-docs.sh @@ -1,21 +1,36 @@ #!/bin/bash # Verify documentation builds without warnings -set -e cd "$(dirname "$0")/.." export MLX_SWIFT_BUILD_DOC=1 -TARGETS=("MLX" "MLXRandom" "MLXNN" "MLXOptimizers" "MLXFFT" "MLXLinalg" "MLXFast") +# Discover library product targets from Package.swift, skipping test/macro/executable targets +TARGETS=$(swift package dump-package | python3 -c " +import json, sys +pkg = json.load(sys.stdin) +targets = set() +for p in pkg['products']: + if p['type'].get('library') is not None: + targets.update(p['targets']) +for t in sorted(targets): + print(t) +") + +if [ -z "$TARGETS" ]; then + echo "No targets found." + exit 1 +fi + FAILED=0 -for TARGET in "${TARGETS[@]}"; do +while IFS= read -r TARGET; do echo "Building documentation for $TARGET..." if ! swift package generate-documentation --target "$TARGET" --warnings-as-errors; then FAILED=1 fi echo "" -done +done <<< "$TARGETS" if [ "$FAILED" -ne 0 ]; then echo "Documentation build failed with warnings."