diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift
index c640eb44..871ed948 100644
--- a/Source/MLX/Ops.swift
+++ b/Source/MLX/Ops.swift
@@ -1094,8 +1094,26 @@ public func depends(inputs: [MLXArray], dependencies: [MLXArray]) -> [MLXArray]
 /// Quantization reduces the precision of model weights to decrease memory usage and
 /// potentially improve inference speed. Different modes use different strategies for
 /// mapping full-precision values to lower-precision representations.
-public enum QuantizationMode: String, Codable, Sendable {
-    /// Affine (linear) quantization with scale and bias parameters.
+///
+/// The affine mode takes configurable `groupSize` and `bits` parameters, while
+/// the MX/NV microscaling formats (`mxfp4`, `mxfp8`, `nvfp4`) use fixed
+/// parameters determined by their respective format specifications.
+/// Use the `groupSize` and `bits` properties to read the effective values.
+///
+/// ## Usage
+///
+/// ```swift
+/// // Affine mode with default parameters (groupSize: 64, bits: 4)
+/// let defaultMode = QuantizationMode.affine()
+///
+/// // Affine mode with custom parameters
+/// let customMode = QuantizationMode.affine(groupSize: 32, bits: 8)
+///
+/// // MXFP4 mode (fixed parameters: groupSize = 32, bits = 4)
+/// let mxfp4Mode = QuantizationMode.mxfp4
+/// ```
+public enum QuantizationMode: Equatable, Sendable {
+    /// Affine (linear) quantization with configurable group size and bit width.
     ///
     /// This is the standard quantization approach where values are quantized using:
     /// ```
@@ -1103,24 +1121,123 @@ public enum QuantizationMode: String, Codable, Sendable {
     /// dequantized_value = quantized_value * scale + bias
     /// ```
     ///
-    /// The `scale` and `bias` parameters are computed per group of elements (typically 32 or 64 elements)
-    /// to minimize quantization error. This mode provides good compression with reasonable accuracy preservation
-    /// for most neural network weights.
-    case affine
+    /// The `scale` and `bias` parameters are computed per group of `groupSize` elements
+    /// to minimize quantization error. This mode provides good compression with reasonable
+    /// accuracy preservation for most neural network weights.
+    ///
+    /// - Parameters:
+    ///   - groupSize: Number of elements per quantization group. Default is 64.
+    ///   - bits: Number of bits per quantized element. Default is 4.
+    case affine(groupSize: Int, bits: Int)
 
     /// MX (Microscaling) FP4 quantization format.
     ///
+    /// Fixed parameters: groupSize = 32, bits = 4.
+    ///
     /// ### See Also
     /// - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     case mxfp4
 
     /// MX (Microscaling) FP8 quantization format.
     ///
+    /// Fixed parameters: groupSize = 32, bits = 8.
+    ///
     /// ### See Also
     /// - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     case mxfp8
 
+    /// NVIDIA FP4 quantization format.
+    ///
+    /// Fixed parameters: groupSize = 16, bits = 4.
     case nvfp4
+
+    /// Affine quantization with the default parameters (`groupSize: 64`, `bits: 4`).
+    ///
+    /// Enum cases cannot carry default associated values, so this factory keeps
+    /// the `.affine()` spelling used throughout the API compiling.
+    public static func affine() -> QuantizationMode {
+        .affine(groupSize: 64, bits: 4)
+    }
+
+    /// The mode name passed to the underlying C API.
+    var cName: String {
+        switch self {
+        case .affine:
+            return "affine"
+        case .mxfp4:
+            return "mxfp4"
+        case .mxfp8:
+            return "mxfp8"
+        case .nvfp4:
+            return "nvfp4"
+        }
+    }
+
+    /// Effective group size for this quantization mode.
+    ///
+    /// For `.affine` this is the associated value; the MX/NV formats fix it.
+    public var groupSize: Int {
+        switch self {
+        case .affine(let groupSize, _):
+            return groupSize
+        case .mxfp4, .mxfp8:
+            return 32
+        case .nvfp4:
+            // NVFP4 uses 16-element micro-blocks, unlike MXFP4's 32.
+            return 16
+        }
+    }
+
+    /// Number of bits per element for this quantization mode.
+    ///
+    /// For `.affine` this is the associated value; the MX/NV formats fix it.
+    public var bits: Int {
+        switch self {
+        case .affine(_, let bits):
+            return bits
+        case .mxfp4, .nvfp4:
+            return 4
+        case .mxfp8:
+            return 8
+        }
+    }
+}
+
+extension QuantizationMode: Codable {
+    /// Coding keys for the keyed (object) serialization format.
+    private enum CodingKeys: String, CodingKey {
+        case type
+        case groupSize
+        case bits
+    }
+
+    public init(from decoder: Decoder) throws {
+        // Backward compatibility: the previous `String` raw-value conformance
+        // encoded the mode as a bare string (e.g. "affine"); accept that form
+        // before falling back to the keyed representation.
+        if let single = try? decoder.singleValueContainer(),
+            let legacy = try? single.decode(String.self)
+        {
+            switch legacy {
+            case "affine":
+                self = .affine(groupSize: 64, bits: 4)
+                return
+            case "mxfp4":
+                self = .mxfp4
+                return
+            case "mxfp8":
+                self = .mxfp8
+                return
+            case "nvfp4":
+                self = .nvfp4
+                return
+            default:
+                break
+            }
+        }
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+        let type = try container.decode(String.self, forKey: .type)
+        switch type {
+        case "affine":
+            let groupSize = try container.decodeIfPresent(Int.self, forKey: .groupSize) ?? 64
+            let bits = try container.decodeIfPresent(Int.self, forKey: .bits) ?? 4
+            self = .affine(groupSize: groupSize, bits: bits)
+        case "mxfp4":
+            self = .mxfp4
+        case "mxfp8":
+            self = .mxfp8
+        case "nvfp4":
+            self = .nvfp4
+        default:
+            throw DecodingError.dataCorruptedError(
+                forKey: .type, in: container,
+                debugDescription: "Unknown QuantizationMode value: \(type)")
+        }
+    }
+
+    public func encode(to encoder: Encoder) throws {
+        var container = encoder.container(keyedBy: CodingKeys.self)
+        switch self {
+        case .affine(let groupSize, let bits):
+            try container.encode("affine", forKey: .type)
+            try container.encode(groupSize, forKey: .groupSize)
+            try container.encode(bits, forKey: .bits)
+        case .mxfp4:
+            try container.encode("mxfp4", forKey: .type)
+        case .mxfp8:
+            try container.encode("mxfp8", forKey: .type)
+        case .nvfp4:
+            try container.encode("nvfp4", forKey: .type)
+        }
+    }
 }
 
 /// Dequantize the matrix `w` using the provided `scales` and
@@ -1146,7 +1263,7 @@
 public func dequantized(
     _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
-    groupSize: Int? = nil, bits: Int? = nil, mode: QuantizationMode = .affine,
+    groupSize: Int? = nil, bits: Int? = nil, mode: QuantizationMode = .affine(),
     globalScale: MLXArray? = nil, dtype: DType?
= nil, stream: StreamOrDevice = .default @@ -1157,7 +1274,7 @@ public func dequantized( let dtype = mlx_optional_dtype(value: dtype?.cmlxDtype ?? MLX_FLOAT16, has_value: dtype != nil) mlx_dequantize( &result, w.ctx, - scales.ctx, (biases ?? .mlxNone).ctx, gs, bits, mode.rawValue, + scales.ctx, (biases ?? .mlxNone).ctx, gs, bits, mode.cName, (globalScale ?? .mlxNone).ctx, dtype, stream.ctx) @@ -1428,7 +1545,7 @@ public func gatherQuantizedMatmul( _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?, lhsIndices: MLXArray? = nil, rhsIndices: MLXArray? = nil, transpose: Bool = true, groupSize: Int? = nil, bits: Int? = nil, - mode: QuantizationMode = .affine, + mode: QuantizationMode = .affine(), sortedIndices: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { @@ -1469,7 +1586,7 @@ public func gatherQuantizedMM( _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?, lhsIndices: MLXArray? = nil, rhsIndices: MLXArray? = nil, transpose: Bool = true, groupSize: Int? = nil, bits: Int? = nil, - mode: QuantizationMode = .affine, + mode: QuantizationMode = .affine(), sortedIndices: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { @@ -1482,7 +1599,7 @@ public func gatherQuantizedMM( &result, x.ctx, w.ctx, scales.ctx, (biases ?? .mlxNone).ctx, (lhsIndices ?? .mlxNone).ctx, (rhsIndices ?? 
.mlxNone).ctx, transpose,
-            gs, bits, mode.rawValue, sortedIndices,
+            gs, bits, mode.cName, sortedIndices,
             stream.ctx)
 
     return MLXArray(result)
@@ -2196,7 +2313,7 @@ public func padded(
     mlx_pad(
         &result, array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx,
-        mode.rawValue.cString(using: .utf8), stream.ctx)
+        mode.rawValue.cString(using: .utf8), stream.ctx)
 
     return MLXArray(result)
 }
@@ -2227,7 +2344,7 @@ public func padded(
     mlx_pad(
         &result, array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx,
-        mode.rawValue.cString(using: .utf8), stream.ctx)
+        mode.rawValue.cString(using: .utf8), stream.ctx)
 
     return MLXArray(result)
 }
@@ -2353,7 +2470,7 @@ public func putAlong(
 public func quantized(
     _ w: MLXArray, groupSize: Int? = nil, bits: Int? = nil,
-    mode: QuantizationMode = .affine,
+    mode: QuantizationMode = .affine(),
     globalScale: MLXArray? = nil,
     stream: StreamOrDevice = .default
 ) -> (wq: MLXArray, scales: MLXArray, biases: MLXArray?) {
@@ -2364,7 +2481,7 @@ public func quantized(
     let bits = mlx_optional_int(value: Int32(bits ?? 0), has_value: bits != nil)
 
     mlx_quantize(
-        &r, w.ctx, gs, bits, mode.rawValue,
+        &r, w.ctx, gs, bits, mode.cName,
         (globalScale ?? .mlxNone).ctx,
         stream.ctx)
 
@@ -2379,7 +2496,7 @@ public func quantizedMatmul(
     _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
     transpose: Bool = true, groupSize: Int? = nil, bits: Int? = nil,
-    mode: QuantizationMode = .affine,
+    mode: QuantizationMode = .affine(),
     stream: StreamOrDevice = .default
 ) -> MLXArray {
     quantizedMM(
@@ -2413,7 +2530,7 @@ public func quantizedMM(
     _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
     transpose: Bool = true, groupSize: Int? = nil, bits: Int? = nil,
-    mode: QuantizationMode = .affine,
+    mode: QuantizationMode = .affine(),
     stream: StreamOrDevice = .default
 ) -> MLXArray {
     var result = mlx_array_new()
@@ -2425,7 +2542,7 @@ public func quantizedMM(
     &result, x.ctx, w.ctx, scales.ctx, (biases ??
.mlxNone).ctx, transpose, gs, bits,
-        mode.rawValue,
+        mode.cName,
         stream.ctx
     )
     return MLXArray(result)
@@ -2476,7 +2593,7 @@ public func quantizedQuantizedMM(
     &result, x.ctx, w.ctx, (scales ?? .mlxNone).ctx,
     gs, bits,
-    mode.rawValue,
+    mode.cName,
     (globalScaleX ?? .mlxNone).ctx,
     (globalScaleW ?? .mlxNone).ctx,
     stream.ctx
diff --git a/Source/MLXNN/Quantized.swift b/Source/MLXNN/Quantized.swift
index 076e91ca..908764a3 100644
--- a/Source/MLXNN/Quantized.swift
+++ b/Source/MLXNN/Quantized.swift
@@ -16,7 +16,7 @@ public protocol Quantizable {
 
 extension Quantizable {
     public func toQuantized(groupSize: Int, bits: Int) -> Module {
-        toQuantized(groupSize: groupSize, bits: bits, mode: .affine)
+        toQuantized(groupSize: groupSize, bits: bits, mode: .affine())
     }
 }
 
@@ -29,10 +29,10 @@ public protocol Quantized: Module {
 
 /// Quantize any ``Quantizable`` layer that is not already quantized.
 public func quantizeSingle(
-    layer: Module, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine
+    layer: Module, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine()
 ) -> Quantized? {
     if layer is Quantized {
         // already quantized
         nil
     } else if let quantizable = layer as? Quantizable {
         quantizable.toQuantized(groupSize: groupSize, bits: bits, mode: mode) as? Quantized
@@ -56,7 +56,7 @@ public func quantizeSingle(
 /// - ``quantize(model:filter:apply:)-(_,_,(Module,Int,Int,QuantizationMode)->Module?)``
 public func quantize(
     model: Module,
-    groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine,
+    groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine(),
     filter: (String, Module) -> Bool = { _, _ in true },
     apply: (Module, Int, Int, QuantizationMode) -> Module? = quantizeSingle(
         layer:groupSize:bits:mode:)
@@ -84,11 +84,11 @@ public func quantize(
     model: Module, groupSize: Int = 64, bits: Int = 4,
     filter: (String, Module) -> Bool = { _, _ in true },
     apply: (Module, Int, Int) -> Module?
= { - quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine) + quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine()) } ) { quantize( - model: model, groupSize: groupSize, bits: bits, mode: .affine, filter: filter, + model: model, groupSize: groupSize, bits: bits, mode: .affine(), filter: filter, apply: { l, g, b, n in apply(l, g, b) } ) } @@ -132,14 +132,14 @@ public func quantize( model: Module, filter: (String, Module) -> (groupSize: Int, bits: Int)?, apply: (Module, Int, Int) -> Module? = { - quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine) + quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine()) } ) { quantize( model: model, filter: { if let (g, b) = filter($0, $1) { - return (g, b, .affine) + return (g, b, .affine()) } else { return nil } @@ -167,7 +167,7 @@ open class QuantizedEmbedding: Embedding, Quantized { convenience public init( embeddingCount: Int, dimensions: Int, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { let scale = sqrt(1 / Float(dimensions)) let weight = MLXRandom.normal([embeddingCount, dimensions]) * scale @@ -177,14 +177,14 @@ open class QuantizedEmbedding: Embedding, Quantized { public convenience init( _ other: Embedding, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { self.init(weight: other.weight, groupSize: groupSize, bits: bits, mode: mode) } public init( weight: MLXArray, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { self.groupSize = groupSize self.bits = bits @@ -263,7 +263,7 @@ open class QuantizedLinear: Linear, Quantized { public convenience init( _ inputDimensions: Int, _ outputDimensions: Int, bias: Bool = true, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { let scale = sqrt(1 / Float(inputDimensions)) let weight = 
MLXRandom.uniform( @@ -283,7 +283,7 @@ open class QuantizedLinear: Linear, Quantized { /// - mode: quantization mode public convenience init( _ other: Linear, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { self.init( weight: other.weight, bias: other.bias, groupSize: groupSize, bits: bits, mode: mode) @@ -292,7 +292,7 @@ open class QuantizedLinear: Linear, Quantized { /// Initialize a ``QuantizedLinear`` with non-quantized weights and bias. public init( weight: MLXArray, bias: MLXArray?, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { self.groupSize = groupSize self.bits = bits @@ -316,7 +316,7 @@ open class QuantizedLinear: Linear, Quantized { public init( weight: MLXArray, bias: MLXArray? = nil, scales: MLXArray, biases: MLXArray?, groupSize: Int, bits: Int, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine() ) { self.groupSize = groupSize self.bits = bits