diff --git a/Source/MLX/MLXArray+NestedInit.swift b/Source/MLX/MLXArray+NestedInit.swift new file mode 100644 index 00000000..adb155f7 --- /dev/null +++ b/Source/MLX/MLXArray+NestedInit.swift @@ -0,0 +1,307 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx +import Foundation + +// MARK: - Protocole de conversion de tableaux imbriqués + +/// Protocole permettant de convertir un tableau Swift imbriqué en MLXArray. +/// +/// Les types scalaires conformes à ``HasDType`` servent de feuilles, +/// et `Array` se conforme récursivement pour +/// gérer n'importe quelle profondeur d'imbrication. +/// +/// ### Exemple +/// ```swift +/// let matrix = MLXArray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) +/// // shape: [2, 3], dtype: .float32 +/// +/// let cube = MLXArray([[[1, 0], [0, 1]], [[2, 0], [0, 2]]]) +/// // shape: [2, 2, 2], dtype: .int32 +/// ``` +public protocol MLXNestedArray { + /// Type scalaire des feuilles du tableau imbriqué. + associatedtype ScalarType: HasDType + + /// Forme (shape) du tableau imbriqué à ce niveau. + var mlxShape: [Int] { get } + + /// Valeurs aplaties dans l'ordre row-major (C-order). + func mlxFlattenedValues() -> [ScalarType] +} + +// MARK: - Conformance récursive de Array + +extension Array: MLXNestedArray where Element: MLXNestedArray { + public typealias ScalarType = Element.ScalarType + + /// Calcule la forme en ajoutant la dimension actuelle devant la shape de l'élément. + /// + /// Précondition : tous les éléments ont la même shape (tableau rectangulaire). + public var mlxShape: [Int] { + guard let first = first else { + // tableau vide — on retourne [0] suivi de la shape interne de zéro + return [0] + } + return [count] + first.mlxShape + } + + /// Aplatit récursivement tous les éléments en un tableau 1D de scalaires. + public func mlxFlattenedValues() -> [ScalarType] { + flatMap { $0.mlxFlattenedValues() } + } +} + +// MARK: - Conformances scalaires + +extension Bool: MLXNestedArray { + public typealias ScalarType = Bool + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [Bool] { [self] } +} + +extension Int32: MLXNestedArray { + public typealias ScalarType = Int32 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [Int32] { [self] } +} + +extension Int64: MLXNestedArray { + public typealias ScalarType = Int64 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [Int64] { [self] } +} + +extension UInt8: MLXNestedArray { + public typealias ScalarType = UInt8 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [UInt8] { [self] } +} + +extension UInt16: MLXNestedArray { + public typealias ScalarType = UInt16 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [UInt16] { [self] } +} + +extension UInt32: MLXNestedArray { + public typealias ScalarType = UInt32 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [UInt32] { [self] } +} + +extension Float32: MLXNestedArray { + public typealias ScalarType = Float32 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [Float32] { [self] } +} + +extension Float64: MLXNestedArray { + public typealias ScalarType = Float64 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [Float64] { [self] } +} + +#if !arch(x86_64) + extension Float16: MLXNestedArray { + public typealias ScalarType = Float16 + public var mlxShape: [Int] { [] } + public func mlxFlattenedValues() -> [Float16] { [self] } + } +#endif + +// MARK: - Initialiseur MLXArray pour tableaux imbriqués + +extension MLXArray { + + /// Crée un ``MLXArray`` multi-dimensionnel à partir d'un tableau Swift imbriqué. + /// + /// Reproduit le comportement ergonomique de `mx.array([[1, 2], [3, 4]])` en Python. + /// La shape est déduite automatiquement de la structure d'imbrication. + /// + /// ```swift + /// // Tableau 2D : shape [2, 3], dtype .float32 + /// let matrix = MLXArray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + /// + /// // Tableau 3D : shape [2, 2, 2], dtype .int32 + /// let cube = MLXArray([[[1, 0], [0, 1]], [[2, 0], [0, 2]]]) + /// ``` + /// + /// - Note: Le tableau doit être rectangulaire (tous les sous-tableaux à chaque + /// profondeur ont la même taille). La conformité est vérifiée par precondition. + /// + /// - Parameter nested: Tableau Swift imbriqué dont les feuilles sont des scalaires + /// conformes à ``HasDType``. + /// + /// ### See Also + /// - + public convenience init(_ nested: N) where N: Collection, N.Element: MLXNestedArray { + // Validation : tous les sous-tableaux à ce niveau doivent avoir la même shape + let shape = nested.mlxShape + let flatValues = nested.mlxFlattenedValues() + + // Vérifie la cohérence de la shape avec le nombre d'éléments aplatis + let expectedCount = shape.isEmpty ? 1 : shape.reduce(1, *) + precondition( + flatValues.count == expectedCount, + "Tableau imbriqué irrégulier : shape \(shape) attend \(expectedCount) éléments, \(flatValues.count) trouvés. " + + "Vérifiez que tous les sous-tableaux ont la même longueur." + ) + + self.init(flatValues, shape) + } + + /// Crée un ``MLXArray`` 2D à partir d'un tableau de tableaux. + /// + /// Surcharge dédiée aux tableaux 2D (le cas d'usage le plus fréquent), + /// offrant une meilleure inférence de type au call-site. + /// + /// ```swift + /// let matrix = MLXArray([[1, 2, 3], [4, 5, 6]]) + /// // shape: [2, 3], dtype: .int32 + /// + /// let floatMatrix = MLXArray([[0.5, 1.5], [2.5, 3.5]]) + /// // shape: [2, 2], dtype: .float32 + /// ``` + /// + /// - Parameter rows: Tableau 2D — chaque élément est une ligne. + /// + /// ### See Also + /// - + /// - ``init(_:)-([[MLXNestedArray]])`` + public convenience init(_ rows: [[T]]) { + let rowCount = rows.count + + guard rowCount > 0 else { + // tableau vide : shape [0] + self.init([T](), [0]) + return + } + + let colCount = rows[0].count + + // Vérifie que toutes les lignes ont la même largeur + precondition( + rows.allSatisfy { $0.count == colCount }, + "Tableau 2D irrégulier : toutes les lignes doivent avoir la même longueur (\(colCount) éléments attendus)." + ) + + let flat = rows.flatMap { $0 } + self.init(flat, [rowCount, colCount]) + } + + /// Crée un ``MLXArray`` 2D à partir d'un tableau de tableaux d'`Int`. + /// + /// Produit un tableau de dtype `.int32` (le comportement par défaut pour `Int` dans MLX Swift). + /// + /// ```swift + /// let a = MLXArray([[7, 8], [9, 10]]) + /// // shape: [2, 2], dtype: .int32 + /// ``` + /// + /// - Parameter rows: Tableau 2D d'entiers. + /// + /// ### See Also + /// - + public convenience init(_ rows: [[Int]]) { + let rowCount = rows.count + + guard rowCount > 0 else { + self.init([Int32](), [0]) + return + } + + let colCount = rows[0].count + + precondition( + rows.allSatisfy { $0.count == colCount }, + "Tableau 2D irrégulier : toutes les lignes doivent avoir la même longueur (\(colCount) éléments attendus)." + ) + + precondition( + rows.joined().allSatisfy { (Int(Int32.min)...Int(Int32.max)).contains($0) }, + "Valeur hors limites pour Int32 — utilisez [[Int32]] si les valeurs dépassent Int32.max." + ) + + let flat = rows.flatMap { $0 }.map { Int32($0) } + self.init(flat, [rowCount, colCount]) + } + + /// Crée un ``MLXArray`` 3D à partir d'un tableau de tableaux de tableaux. + /// + /// ```swift + /// let cube = MLXArray([[[1, 0], [0, 1]], [[2, 0], [0, 2]]]) + /// // shape: [2, 2, 2], dtype: .int32 + /// ``` + /// + /// - Parameter slices: Tableau 3D — chaque élément est une matrice 2D. + /// + /// ### See Also + /// - + public convenience init(_ slices: [[[T]]]) { + let depth = slices.count + + guard depth > 0 else { + self.init([T](), [0]) + return + } + + let rowCount = slices[0].count + let colCount = slices[0].first?.count ?? 0 + + precondition( + slices.allSatisfy { $0.count == rowCount }, + "Tableau 3D irrégulier : toutes les tranches doivent avoir \(rowCount) lignes." + ) + precondition( + slices.allSatisfy { $0.allSatisfy { $0.count == colCount } }, + "Tableau 3D irrégulier : toutes les lignes doivent avoir \(colCount) colonnes." + ) + + let flat = slices.flatMap { $0.flatMap { $0 } } + self.init(flat, [depth, rowCount, colCount]) + } + + /// Crée un ``MLXArray`` 3D à partir d'un tableau de tableaux de tableaux d'`Int`. + /// + /// Produit un tableau de dtype `.int32`. + /// + /// ```swift + /// let cube = MLXArray([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + /// // shape: [2, 2, 2], dtype: .int32 + /// ``` + /// + /// - Parameter slices: Tableau 3D d'entiers. + /// + /// ### See Also + /// - + public convenience init(_ slices: [[[Int]]]) { + let depth = slices.count + + guard depth > 0 else { + self.init([Int32](), [0]) + return + } + + let rowCount = slices[0].count + let colCount = slices[0].first?.count ?? 0 + + precondition( + slices.allSatisfy { $0.count == rowCount }, + "Tableau 3D irrégulier : toutes les tranches doivent avoir \(rowCount) lignes." + ) + precondition( + slices.allSatisfy { $0.allSatisfy { $0.count == colCount } }, + "Tableau 3D irrégulier : toutes les lignes doivent avoir \(colCount) colonnes." + ) + + let allValues = slices.joined().joined() + precondition( + allValues.allSatisfy { (Int(Int32.min)...Int(Int32.max)).contains($0) }, + "Valeur hors limites pour Int32 — utilisez [[[Int32]]] si les valeurs dépassent Int32.max." + ) + + let flat = slices.flatMap { $0.flatMap { $0 } }.map { Int32($0) } + self.init(flat, [depth, rowCount, colCount]) + } +} diff --git a/Tests/MLXTests/MLXArray+NestedInitTests.swift b/Tests/MLXTests/MLXArray+NestedInitTests.swift new file mode 100644 index 00000000..c615457a --- /dev/null +++ b/Tests/MLXTests/MLXArray+NestedInitTests.swift @@ -0,0 +1,145 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import XCTest + +@testable import MLX + +// Tests pour les initialiseurs de tableaux imbriqués (issue #161) +class MLXArrayNestedInitTests: XCTestCase { + + override class func setUp() { + setDefaultDevice() + } + + // MARK: - Tableaux 2D avec types génériques + + func testInit2DFloat() { + // Tableau 2D de Float32 + let matrix = MLXArray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] as [[Float32]]) + XCTAssertEqual(matrix.shape, [2, 3]) + XCTAssertEqual(matrix.dtype, .float32) + XCTAssertEqual(matrix.size, 6) + } + + func testInit2DInt32() { + // Tableau 2D de Int32 + let matrix = MLXArray([[1, 2], [3, 4]] as [[Int32]]) + XCTAssertEqual(matrix.shape, [2, 2]) + XCTAssertEqual(matrix.dtype, .int32) + XCTAssertEqual(matrix.size, 4) + } + + func testInit2DInt() { + // Tableau 2D d'Int (surcharge dédiée, produit du .int32) + let matrix = MLXArray([[7, 8], [9, 10]]) + XCTAssertEqual(matrix.shape, [2, 2]) + XCTAssertEqual(matrix.dtype, .int32) + XCTAssertEqual(matrix.size, 4) + } + + func testInit2DIntValues() { + // Vérifie que les valeurs sont correctement stockées + let matrix = MLXArray([[1, 2], [3, 4]]) + let expected = MLXArray([1, 2, 3, 4] as [Int32], [2, 2]) + assertEqual(matrix, expected) + } + + func testInit2DFloatValues() { + // Vérifie que les valeurs float sont correctement stockées + let matrix = MLXArray([[1.0, 2.0], [3.0, 4.0]] as [[Float32]]) + let expected = MLXArray([1.0, 2.0, 3.0, 4.0] as [Float32], [2, 2]) + assertEqual(matrix, expected) + } + + func testInit2DNonSquare() { + // Matrice non carrée + let matrix = MLXArray([[1, 2, 3], [4, 5, 6]] as [[Int32]]) + XCTAssertEqual(matrix.shape, [2, 3]) + XCTAssertEqual(matrix.ndim, 2) + XCTAssertEqual(matrix.dim(0), 2) + XCTAssertEqual(matrix.dim(1), 3) + } + + func testInit2DFloat64() { + // Tableau 2D de Double (Float64) + let matrix = MLXArray([[1.5, 2.5], [3.5, 4.5]] as [[Double]]) + XCTAssertEqual(matrix.shape, [2, 2]) + XCTAssertEqual(matrix.dtype, .float64) + } + + func testInit2DBool() { + // Tableau 2D de Bool + let matrix = MLXArray([[true, false], [false, true]]) + XCTAssertEqual(matrix.shape, [2, 2]) + XCTAssertEqual(matrix.dtype, .bool) + } + + // MARK: - Tableaux 3D + + func testInit3DInt() { + // Cube 3D d'Int (produit du .int32) + let cube = MLXArray([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + XCTAssertEqual(cube.shape, [2, 2, 2]) + XCTAssertEqual(cube.dtype, .int32) + XCTAssertEqual(cube.size, 8) + } + + func testInit3DFloat() { + // Cube 3D de Float32 + let cube = MLXArray([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] as [[[Float32]]]) + XCTAssertEqual(cube.shape, [2, 2, 2]) + XCTAssertEqual(cube.dtype, .float32) + XCTAssertEqual(cube.size, 8) + } + + func testInit3DIntValues() { + // Vérifie les valeurs pour le cas 3D + let cube = MLXArray([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + let expected = MLXArray([1, 2, 3, 4, 5, 6, 7, 8] as [Int32], [2, 2, 2]) + assertEqual(cube, expected) + } + + func testInit3DNonCubic() { + // Forme 3D non cubique : [2, 3, 4] + let tensor = MLXArray([ + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], + ]) + XCTAssertEqual(tensor.shape, [2, 3, 4]) + XCTAssertEqual(tensor.ndim, 3) + XCTAssertEqual(tensor.size, 24) + } + + // MARK: - Tableaux vides + + func testInit2DEmptyRows() { + // Tableau avec 0 lignes + let empty = MLXArray([[Int32]]()) + XCTAssertEqual(empty.shape, [0]) + } + + func testInit3DEmpty() { + // Tableau 3D vide + let empty = MLXArray([[[Int32]]]()) + XCTAssertEqual(empty.shape, [0]) + } + + // MARK: - Compatibilité avec l'API existante + + func testNestedInitCompatibleWithExisting1D() { + // S'assure que la surcharge 2D n'interfère pas avec le init 1D existant + let oneDim = MLXArray([1, 2, 3, 4]) + XCTAssertEqual(oneDim.shape, [4]) + XCTAssertEqual(oneDim.ndim, 1) + } + + func testNestedInitIndexing() { + // Vérifie l'accès par index après construction imbriquée + let matrix = MLXArray([[10, 20], [30, 40]]) + let row0 = matrix[0] + XCTAssertEqual(row0.shape, [2]) + let val = matrix[0, 1].item(Int32.self) + XCTAssertEqual(val, 20) + } +}