From fc38f3ca7c61d8d28aebbf3a242648e861fa4740 Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 6 Apr 2026 09:15:45 -0700 Subject: [PATCH 1/2] deprecate factory methods that have an implicit Float/.float32 type - fix #390 specifically this logically deprecates functions like this: static public func ones( _ shape: some Collection, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { MLX.ones(shape, type: type, stream: stream) } calling with no type will get a deprecation warning --- Source/MLX/Factory.swift | 358 +++++++++++++++++++- Tests/MLXTests/MLXArray+IndexingTests.swift | 4 +- Tests/MLXTests/MLXArray+InitTests.swift | 26 ++ Tests/MLXTests/MLXRandomTests.swift | 4 +- Tests/MLXTests/ModuleTests.swift | 7 +- Tests/MLXTests/OpsTests.swift | 2 +- Tests/MLXTests/OptimizerTests.swift | 4 +- Tests/MLXTests/SaveTests.swift | 14 +- 8 files changed, 391 insertions(+), 28 deletions(-) diff --git a/Source/MLX/Factory.swift b/Source/MLX/Factory.swift index 91354051..1568864a 100644 --- a/Source/MLX/Factory.swift +++ b/Source/MLX/Factory.swift @@ -23,12 +23,31 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` static public func zeros( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.zeros(shape, type: type, stream: stream) } + /// Construct an array of zeros with default `Float` type. + /// + /// - Parameters: + /// - shape: desired shape + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``zeros(_:type:stream:)`` + @available( + *, deprecated, message: "specify the type explicitly, e.g. zeros([...], type: Float.self)" + ) + static public func zeros( + _ shape: some Collection, + stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.zeros(shape, type: Float.self, stream: stream) + } + /// Construct an array of zeros with a given ``DType`` /// /// Example: @@ -91,12 +110,31 @@ extension MLXArray { /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` static public func ones( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.ones(shape, type: type, stream: stream) } + /// Construct an array of ones with default `Float` type. + /// + /// - Parameters: + /// - shape: desired shape + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``ones(_:type:stream:)`` + @available( + *, deprecated, message: "specify the type explicitly, e.g. ones([...], type: Float.self)" + ) + static public func ones( + _ shape: some Collection, + stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.ones(shape, type: Float.self, stream: stream) + } + /// Construct an array of ones with a given ``DType`` /// /// Example: @@ -161,12 +199,33 @@ extension MLXArray { /// - /// - ``identity(_:type:stream:)`` static public func eye( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.eye(n, m: m, k: k, type: type, stream: stream) } + /// Create an identity matrix or a general diagonal matrix with default `Float` type. + /// + /// - Parameters: + /// - n: number of rows in the output + /// - m: number of columns in the output -- equal to `n` if not specified + /// - k: index of the diagonal + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``eye(_:m:k:type:stream:)`` + @available( + *, deprecated, message: "specify the type explicitly, e.g. eye(..., type: Float.self)" + ) + static public func eye( + _ n: Int, m: Int? = nil, k: Int = 0, + stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.eye(n, m: m, k: k, type: Float.self, stream: stream) + } + /// Create an identity matrix or a general diagonal matrix given a ``DType``. /// /// Example: @@ -215,7 +274,7 @@ extension MLXArray { /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` static public func full( - _ shape: some Collection, values: MLXArray, type: (some HasDType).Type = Float.self, + _ shape: some Collection, values: MLXArray, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.full(shape, values: values, type: type, stream: stream) @@ -297,11 +356,29 @@ extension MLXArray { /// - /// - ``eye(_:m:k:type:stream:)`` static public func identity( - _ n: Int, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default + _ n: Int, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.identity(n, type: type, stream: stream) } + /// Create a square identity matrix with default `Float` type. + /// + /// - Parameters: + /// - n: number of rows and columns in the output + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``identity(_:type:stream:)`` + @available( + *, deprecated, message: "specify the type explicitly, e.g. identity(..., type: Float.self)" + ) + static public func identity( + _ n: Int, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.identity(n, type: Float.self, stream: stream) + } + /// Create a square identity matrix with a given ``DType``. /// /// Example: @@ -514,6 +591,74 @@ extension MLXArray { MLX.arange(start, stop, step: step, dtype: dtype, stream: stream) } + /// Generate values in the half-open interval `[0, stop)` inferring the dtype from the input type. + /// + /// This generic overload infers the output dtype from the Swift type of `stop`. + /// For example, `arange(Int16(10))` produces an `.int16` array. + /// + /// - Parameters: + /// - stop: stop value + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``arange(_:stream:)-(Int,_)`` + static public func arange( + _ stop: T, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.arange(stop, stream: stream) + } + + /// Generate values in the half-open interval `[start, stop)` spaced by `step`, inferring the dtype from the input type. + /// + /// - Parameters: + /// - start: start value + /// - stop: stop value + /// - step: step size (default: 1) + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + static public func arange( + _ start: T, _ stop: T, step: T = 1, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.arange(start, stop, step: step, stream: stream) + } + + /// Generate values in the half-open interval `[0, stop)` inferring the dtype from the input type (floating point version). + /// + /// This generic overload infers the output dtype from the Swift type of `stop`. + /// For example, `arange(Float16(5.0))` produces a `.float16` array. + /// + /// - Parameters: + /// - stop: stop value + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``arange(_:dtype:stream:)-(Double,_,_)`` + static public func arange( + _ stop: T, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.arange(stop, stream: stream) + } + + /// Generate values in the half-open interval `[start, stop)` spaced by `step`, inferring the dtype from the input type (floating point version). + /// + /// - Parameters: + /// - start: start value + /// - stop: stop value + /// - step: step size + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + static public func arange( + _ start: T, _ stop: T, step: T, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.arange(start, stop, step: step, stream: stream) + } + /// Repeat an array along a specified axis. /// /// > Deprected in favor of the more consistently named ``repeated(_:count:axis:stream:)`` @@ -612,12 +757,33 @@ extension MLXArray { /// ### See Also /// - static public func tri( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.tri(n, m: m, k: k, type: type, stream: stream) } + /// An array with ones at and below the given diagonal and zeros elsewhere, with default `Float` type. + /// + /// - Parameters: + /// - n: number of rows in the output + /// - m: number of columns in the output -- equal to `n` if not specified + /// - k: index of the diagonal + /// - stream: stream or device to evaluate on + /// + /// ### See Also + /// - + /// - ``tri(_:m:k:type:stream:)`` + @available( + *, deprecated, message: "specify the type explicitly, e.g. tri(..., type: Float.self)" + ) + static public func tri( + _ n: Int, m: Int? = nil, k: Int = 0, + stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.tri(n, m: m, k: k, type: Float.self, stream: stream) + } + /// An array with ones at and below the given diagonal and zeros elsewhere and a given ``DType``. /// /// Example: @@ -662,7 +828,7 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` public func zeros( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -670,6 +836,25 @@ public func zeros( return MLXArray(result) } +/// Construct an array of zeros with default `Float` type. +/// +/// - Parameters: +/// - shape: desired shape +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``zeros(_:type:stream:)`` +@available( + *, deprecated, message: "specify the type explicitly, e.g. zeros([...], type: Float.self)" +) +public func zeros( + _ shape: some Collection, + stream: StreamOrDevice = .default +) -> MLXArray { + zeros(shape, type: Float.self, stream: stream) +} + /// Construct an array of zeros with a given ``DType`` /// /// Example: @@ -736,7 +921,7 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` public func ones( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -744,6 +929,25 @@ public func ones( return MLXArray(result) } +/// Construct an array of ones with default `Float` type. +/// +/// - Parameters: +/// - shape: desired shape +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``ones(_:type:stream:)`` +@available( + *, deprecated, message: "specify the type explicitly, e.g. ones([...], type: Float.self)" +) +public func ones( + _ shape: some Collection, + stream: StreamOrDevice = .default +) -> MLXArray { + ones(shape, type: Float.self, stream: stream) +} + /// Construct an array of ones with a given ``DType`` /// /// Example: @@ -812,7 +1016,7 @@ public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLX /// - /// - ``identity(_:type:stream:)`` public func eye( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -820,6 +1024,25 @@ public func eye( return MLXArray(result) } +/// Create an identity matrix or a general diagonal matrix with default `Float` type. +/// +/// - Parameters: +/// - n: number of rows in the output +/// - m: number of columns in the output -- equal to `n` if not specified +/// - k: index of the diagonal +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``eye(_:m:k:type:stream:)`` +@available(*, deprecated, message: "specify the type explicitly, e.g. eye(..., type: Float.self)") +public func eye( + _ n: Int, m: Int? = nil, k: Int = 0, + stream: StreamOrDevice = .default +) -> MLXArray { + eye(n, m: m, k: k, type: Float.self, stream: stream) +} + /// Create an identity matrix or a general diagonal matrix given a ``DType``. /// /// Example: @@ -957,13 +1180,31 @@ public func full( /// - /// - ``eye(_:m:k:type:stream:)`` public func identity( - _ n: Int, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default + _ n: Int, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_identity(&result, n.int32, type.dtype.cmlxDtype, stream.ctx) return MLXArray(result) } +/// Create a square identity matrix with default `Float` type. +/// +/// - Parameters: +/// - n: number of rows and columns in the output +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``identity(_:type:stream:)`` +@available( + *, deprecated, message: "specify the type explicitly, e.g. identity(..., type: Float.self)" +) +public func identity( + _ n: Int, stream: StreamOrDevice = .default +) -> MLXArray { + identity(n, type: Float.self, stream: stream) +} + /// Create a square identity matrix with a given ``DType``. /// /// Example: @@ -1191,6 +1432,82 @@ public func arange( return MLXArray(result) } +/// Generate values in the half-open interval `[0, stop)` inferring the dtype from the input type. +/// +/// This generic overload infers the output dtype from the Swift type of `stop`. +/// For example, `arange(Int16(10))` produces an `.int16` array. +/// +/// - Parameters: +/// - stop: stop value +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``arange(_:stream:)-(Int,_)`` +public func arange( + _ stop: T, stream: StreamOrDevice = .default +) -> MLXArray { + var result = mlx_array_new() + mlx_arange(&result, 0, Double(stop), 1, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) +} + +/// Generate values in the half-open interval `[start, stop)` spaced by `step`, inferring the dtype from the input type. +/// +/// - Parameters: +/// - start: start value +/// - stop: stop value +/// - step: step size (default: 1) +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func arange( + _ start: T, _ stop: T, step: T = 1, stream: StreamOrDevice = .default +) -> MLXArray { + var result = mlx_array_new() + mlx_arange(&result, Double(start), Double(stop), Double(step), T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) +} + +/// Generate values in the half-open interval `[0, stop)` inferring the dtype from the input type (floating point version). +/// +/// This generic overload infers the output dtype from the Swift type of `stop`. +/// For example, `arange(Float16(5.0))` produces a `.float16` array. +/// +/// - Parameters: +/// - stop: stop value +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``arange(_:dtype:stream:)-(Double,_,_)`` +public func arange( + _ stop: T, stream: StreamOrDevice = .default +) -> MLXArray { + var result = mlx_array_new() + mlx_arange(&result, 0, Double(stop), 1, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) +} + +/// Generate values in the half-open interval `[start, stop)` spaced by `step`, inferring the dtype from the input type (floating point version). +/// +/// - Parameters: +/// - start: start value +/// - stop: stop value +/// - step: step size +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func arange( + _ start: T, _ stop: T, step: T, stream: StreamOrDevice = .default +) -> MLXArray { + var result = mlx_array_new() + mlx_arange(&result, Double(start), Double(stop), Double(step), T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) +} + /// Repeat an array along a specified axis. /// /// > Deprected in favor of the more consistently named ``repeated(_:count:axis:stream:)`` @@ -1293,7 +1610,7 @@ public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .de /// ### See Also /// - public func tri( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -1301,6 +1618,25 @@ public func tri( return MLXArray(result) } +/// An array with ones at and below the given diagonal and zeros elsewhere, with default `Float` type. +/// +/// - Parameters: +/// - n: number of rows in the output +/// - m: number of columns in the output -- equal to `n` if not specified +/// - k: index of the diagonal +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``tri(_:m:k:type:stream:)`` +@available(*, deprecated, message: "specify the type explicitly, e.g. tri(..., type: Float.self)") +public func tri( + _ n: Int, m: Int? = nil, k: Int = 0, + stream: StreamOrDevice = .default +) -> MLXArray { + tri(n, m: m, k: k, type: Float.self, stream: stream) +} + /// An array with ones at and below the given diagonal and zeros elsewhere and a given ``DType``. /// /// Example: diff --git a/Tests/MLXTests/MLXArray+IndexingTests.swift b/Tests/MLXTests/MLXArray+IndexingTests.swift index 68673ff5..85e7f407 100644 --- a/Tests/MLXTests/MLXArray+IndexingTests.swift +++ b/Tests/MLXTests/MLXArray+IndexingTests.swift @@ -643,8 +643,8 @@ class MLXArrayIndexingTests: XCTestCase { public func testSliceWithBroadcast() { // https://github.com/ml-explore/mlx-swift/issues/76 - let a = MLXArray.ones([2, 6, 6, 6]) - let b = MLXArray.zeros([3, 4, 4, 4]) + let a = MLXArray.ones([2, 6, 6, 6], type: Float.self) + let b = MLXArray.zeros([3, 4, 4, 4], type: Float.self) b[0, 0 ..< 4, 3, 0 ..< 4] = a[0, 1 ..< 5, 5, 1 ..< 5] diff --git a/Tests/MLXTests/MLXArray+InitTests.swift b/Tests/MLXTests/MLXArray+InitTests.swift index 19574da0..965c7e4f 100644 --- a/Tests/MLXTests/MLXArray+InitTests.swift +++ b/Tests/MLXTests/MLXArray+InitTests.swift @@ -201,6 +201,32 @@ class MLXArrayInitTests: XCTestCase { XCTAssertEqual(c.shape, [0]) } + func testArangeDTypeInference() { + // Generic integer overloads infer dtype from Swift type + XCTAssertEqual(arange(Int16(10)).dtype, .int16) + XCTAssertEqual(arange(UInt8(5)).dtype, .uint8) + XCTAssertEqual(arange(Int8(0), Int8(10), step: Int8(2)).dtype, .int8) + XCTAssertEqual(arange(Int64(10)).dtype, .int64) + XCTAssertEqual(arange(UInt32(8)).dtype, .uint32) + + // Generic floating point overload + XCTAssertEqual(arange(Float(5.0)).dtype, .float32) + + // Existing concrete overloads remain unchanged + XCTAssertEqual(arange(10).dtype, .int32) + XCTAssertEqual(arange(5.0).dtype, .float32) + + // Static method versions + XCTAssertEqual(MLXArray.arange(Int16(10)).dtype, .int16) + XCTAssertEqual(MLXArray.arange(UInt8(5)).dtype, .uint8) + + #if !arch(x86_64) + // Float16 test (not available on x86_64) + XCTAssertEqual(arange(Float16(5.0)).dtype, .float16) + XCTAssertEqual(MLXArray.arange(Float16(3.0)).dtype, .float16) + #endif + } + func testData() { let data = Data([1, 2, 3, 4]) let a = MLXArray(data, [2, 2], type: UInt8.self) diff --git a/Tests/MLXTests/MLXRandomTests.swift b/Tests/MLXTests/MLXRandomTests.swift index e3906339..3a058e21 100644 --- a/Tests/MLXTests/MLXRandomTests.swift +++ b/Tests/MLXTests/MLXRandomTests.swift @@ -136,7 +136,7 @@ class MLXRandomTests: XCTestCase { func testLogits() { let key = MLXRandom.key(0) - let logits = MLXArray.zeros([5, 20]) + let logits = MLXArray.zeros([5, 20], type: Float.self) let result = MLXRandom.categorical(logits, key: key) XCTAssertEqual(result.shape, [5]) @@ -149,7 +149,7 @@ class MLXRandomTests: XCTestCase { func testLogitsCount() { let key = MLXRandom.key(0) - let logits = MLXArray.zeros([5, 20]) + let logits = MLXArray.zeros([5, 20], type: Float.self) let result = MLXRandom.categorical(logits, count: 2, key: key) XCTAssertEqual(result.shape, [5, 2]) diff --git a/Tests/MLXTests/ModuleTests.swift b/Tests/MLXTests/ModuleTests.swift index 288da783..a0cec76b 100644 --- a/Tests/MLXTests/ModuleTests.swift +++ b/Tests/MLXTests/ModuleTests.swift @@ -357,8 +357,8 @@ class ModuleTests: XCTestCase { @ModuleInfo var d: Linear? override init() { - _a.wrappedValue = MLXArray.zeros([10]) - _b.wrappedValue = MLXArray.zeros([10]) + _a.wrappedValue = MLXArray.zeros([10], type: Float.self) + _b.wrappedValue = MLXArray.zeros([10], type: Float.self) _c.wrappedValue = Linear(10, 10) _d.wrappedValue = Linear(10, 10) } @@ -836,7 +836,8 @@ class ModuleTests: XCTestCase { let loraScale = 1 / sqrt(Float(inputDimensions)) self._loraA.wrappedValue = MLXRandom.uniform( low: -loraScale, high: loraScale, [inputDimensions, rank]) - self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) + self._loraB.wrappedValue = MLXArray.zeros( + [rank, outputDimensions], type: Float.self) super.init() diff --git a/Tests/MLXTests/OpsTests.swift b/Tests/MLXTests/OpsTests.swift index 46fd06cc..0a12edb9 100644 --- a/Tests/MLXTests/OpsTests.swift +++ b/Tests/MLXTests/OpsTests.swift @@ -87,7 +87,7 @@ class OpsTests: XCTestCase { } func testFlatten() { - let a = zeros([4, 5, 6, 7]) + let a = zeros([4, 5, 6, 7], type: Float.self) let b = flatten(a, startAxis: 1, endAxis: 2) let c = unflatten(b, axis: 1, shape: [5, 6]) assertEqual(a, c) diff --git a/Tests/MLXTests/OptimizerTests.swift b/Tests/MLXTests/OptimizerTests.swift index 74eda3b0..e5cc2f3c 100644 --- a/Tests/MLXTests/OptimizerTests.swift +++ b/Tests/MLXTests/OptimizerTests.swift @@ -14,8 +14,8 @@ class OptimizerTests: XCTestCase { } class ShapeModule: Module { - let first = [MLXArray.zeros([10]), MLXArray.zeros([1])] - let second = MLXArray.zeros([1]) + let first = [MLXArray.zeros([10], type: Float.self), MLXArray.zeros([1], type: Float.self)] + let second = MLXArray.zeros([1], type: Float.self) } func checkShape(optimizer: OptimizerBase) { diff --git a/Tests/MLXTests/SaveTests.swift b/Tests/MLXTests/SaveTests.swift index 6f700df4..07fd43b1 100644 --- a/Tests/MLXTests/SaveTests.swift +++ b/Tests/MLXTests/SaveTests.swift @@ -34,8 +34,8 @@ final class SaveTests: XCTestCase { ) let arrays: [String: MLXArray] = [ - "foo": MLX.ones([1, 2]), - "bar": MLX.zeros([2, 1]), + "foo": MLX.ones([1, 2], type: Float.self), + "bar": MLX.zeros([2, 1], type: Float.self), ] try MLX.save(arrays: arrays, url: safetensorsPath) @@ -54,7 +54,7 @@ final class SaveTests: XCTestCase { directoryHint: .notDirectory ) - let array = MLX.ones([2, 4]) + let array = MLX.ones([2, 4], type: Float.self) try MLX.save(array: array, url: path) @@ -65,8 +65,8 @@ final class SaveTests: XCTestCase { public func testSaveArraysData() throws { let arrays: [String: MLXArray] = [ - "foo": MLX.ones([1, 2]), - "bar": MLX.zeros([2, 1]), + "foo": MLX.ones([1, 2], type: Float.self), + "bar": MLX.zeros([2, 1], type: Float.self), ] let data = try saveToData(arrays: arrays) @@ -79,8 +79,8 @@ final class SaveTests: XCTestCase { public func testSaveArraysMetadataData() throws { let arrays: [String: MLXArray] = [ - "foo": MLX.ones([1, 2]), - "bar": MLX.zeros([2, 1]), + "foo": MLX.ones([1, 2], type: Float.self), + "bar": MLX.zeros([2, 1], type: Float.self), ] let metadata = [ "key": "value", From 8fb34518f50e5a53a5a230c7735479e4d313b8a0 Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 6 Apr 2026 10:07:57 -0700 Subject: [PATCH 2/2] fix doc links --- .../MLX/Documentation.docc/Articles/converting-python.md | 2 +- Source/MLX/Factory.swift | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Source/MLX/Documentation.docc/Articles/converting-python.md b/Source/MLX/Documentation.docc/Articles/converting-python.md index fa82f309..79bedce9 100644 --- a/Source/MLX/Documentation.docc/Articles/converting-python.md +++ b/Source/MLX/Documentation.docc/Articles/converting-python.md @@ -147,7 +147,7 @@ This is a mapping of `mx` free functions to their ``MLX`` counterparts. `all` | ``MLX/all(_:axes:keepDims:stream:)`` `allclose` | ``MLX/allClose(_:_:rtol:atol:equalNaN:stream:)`` `any` | ``MLX/any(_:axes:keepDims:stream:)`` -`arange` | ``MLX/arange(_:_:step:stream:)`` +`arange` | ``MLX/arange(_:stream:)-9st56`` `arccos` | ``MLX/acos(_:stream:)`` `arccosh` | ``MLX/acosh(_:stream:)`` `arcsin` | ``MLX/asin(_:stream:)`` diff --git a/Source/MLX/Factory.swift b/Source/MLX/Factory.swift index 1568864a..685f7679 100644 --- a/Source/MLX/Factory.swift +++ b/Source/MLX/Factory.swift @@ -465,7 +465,7 @@ extension MLXArray { /// /// ### See Also /// - - /// - ``arange(_:_:step:stream:)`` + /// - ``arange(_:_:step:stream:)-3wme1`` static public func arange(_ stop: Int, stream: StreamOrDevice = .default) -> MLXArray { MLX.arange(0, stop, stream: stream) } @@ -535,7 +535,7 @@ extension MLXArray { /// /// ### See Also /// - - /// - ``arange(_:_:step:stream:)`` + /// - ``arange(_:_:step:stream:)-3wme1`` static public func arange( _ start: Int, _ stop: Int, step: Int = 1, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { @@ -1295,7 +1295,7 @@ public func linspace( /// /// ### See Also /// - -/// - ``arange(_:_:step:stream:)`` +/// - ``arange(_:_:step:stream:)-563go`` public func arange(_ stop: Int, stream: StreamOrDevice = .default) -> MLXArray { var result = mlx_array_new() mlx_arange(&result, 0, Double(stop), 1, DType.int32.cmlxDtype, stream.ctx) @@ -1370,7 +1370,7 @@ public func arange(_ stop: Int, dtype: DType, stream: StreamOrDevice = .default) /// /// ### See Also /// - -/// - ``arange(_:_:step:stream:)`` +/// - ``arange(_:_:step:stream:)-563go`` public func arange( _ start: Int, _ stop: Int, step: Int = 1, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray {