Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 136 additions & 19 deletions Source/MLX/Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1094,33 +1094,150 @@ public func depends(inputs: [MLXArray], dependencies: [MLXArray]) -> [MLXArray]
/// Quantization mode used when compressing neural-network weights.
///
/// Quantization reduces the precision of model weights to decrease memory usage and
/// potentially improve inference speed. Different modes use different strategies for
/// mapping full-precision values to lower-precision representations.
///
/// ## Usage
///
/// ```swift
/// // Affine mode with default parameters (groupSize: 64, bits: 4)
/// let mode = QuantizationMode.affine()
///
/// // Affine mode with custom parameters
/// let mode = QuantizationMode.affine(groupSize: 32, bits: 8)
///
/// // MXFP4 mode (fixed parameters: groupSize = 32, bits = 4)
/// let mode = QuantizationMode.mxfp4
/// ```
public enum QuantizationMode: Equatable, Sendable {
    /// Affine (linear) quantization with configurable group size and bit width.
    ///
    /// This is the standard quantization approach where values are quantized using:
    /// ```
    /// quantized_value = round((value - bias) / scale)
    /// dequantized_value = quantized_value * scale + bias
    /// ```
    ///
    /// The `scale` and `bias` parameters are computed per group of `groupSize` elements
    /// to minimize quantization error. This mode provides good compression with reasonable
    /// accuracy preservation for most neural network weights.
    ///
    /// - Parameters:
    ///   - groupSize: Number of elements per quantization group. Default is 64.
    ///   - bits: Number of bits per quantized element. Default is 4.
    case affine(groupSize: Int = 64, bits: Int = 4)

    /// MX (Microscaling) FP4 quantization format.
    ///
    /// Fixed parameters: groupSize = 32, bits = 4.
    ///
    /// ### See Also
    /// - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    case mxfp4

    /// MX (Microscaling) FP8 quantization format.
    ///
    /// Fixed parameters: groupSize = 32, bits = 8.
    ///
    /// ### See Also
    /// - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    case mxfp8

    /// NVIDIA FP4 quantization format.
    ///
    /// Fixed parameters: groupSize = 16, bits = 4 (16-element micro-blocks with
    /// FP8 E4M3 scales and a tensor-level global scale).
    case nvfp4

    /// The mode name passed to the underlying C API.
    var cName: String {
        switch self {
        case .affine:
            return "affine"
        case .mxfp4:
            return "mxfp4"
        case .mxfp8:
            return "mxfp8"
        case .nvfp4:
            return "nvfp4"
        }
    }

    /// Effective group size for this quantization mode.
    ///
    /// For `.affine` this is the associated value; for the MX/NV formats the
    /// value is fixed by the format specification.
    public var groupSize: Int {
        switch self {
        case .affine(let groupSize, _):
            return groupSize
        case .mxfp4, .mxfp8:
            // MX formats use 32-element blocks per the OCP MX v1.0 spec.
            return 32
        case .nvfp4:
            // NVFP4 uses 16-element micro-blocks (not 32 as for the MX formats).
            return 16
        }
    }

    /// Number of bits per quantized element for this mode.
    ///
    /// For `.affine` this is the associated value; for the MX/NV formats the
    /// value is fixed by the format specification.
    public var bits: Int {
        switch self {
        case .affine(_, let bits):
            return bits
        case .mxfp4, .nvfp4:
            return 4
        case .mxfp8:
            return 8
        }
    }
}

extension QuantizationMode: Codable {
    /// Keys used for JSON serialization.
    private enum CodingKeys: String, CodingKey {
        case type
        case groupSize
        case bits
    }

    /// Decodes a mode from a keyed container.
    ///
    /// The `type` key selects the case. For `"affine"`, the optional `groupSize`
    /// and `bits` keys override the defaults (64 and 4, matching `.affine()`);
    /// the other formats carry no parameters.
    ///
    /// - Throws: `DecodingError.dataCorrupted` if `type` is not a known mode name.
    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        let type = try container.decode(String.self, forKey: .type)
        switch type {
        case "affine":
            // Missing keys fall back to the same defaults as `.affine()`.
            let groupSize = try container.decodeIfPresent(Int.self, forKey: .groupSize) ?? 64
            let bits = try container.decodeIfPresent(Int.self, forKey: .bits) ?? 4
            self = .affine(groupSize: groupSize, bits: bits)
        case "mxfp4":
            self = .mxfp4
        case "mxfp8":
            self = .mxfp8
        case "nvfp4":
            self = .nvfp4
        default:
            throw DecodingError.dataCorruptedError(
                forKey: .type, in: container,
                debugDescription: "Unknown QuantizationMode value: \(type)")
        }
    }

    /// Encodes the mode. `groupSize` and `bits` are written only for `.affine`;
    /// the MX/NV formats have fixed parameters and encode only their type name.
    public func encode(to encoder: Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        switch self {
        case .affine(let groupSize, let bits):
            try container.encode("affine", forKey: .type)
            try container.encode(groupSize, forKey: .groupSize)
            try container.encode(bits, forKey: .bits)
        case .mxfp4:
            try container.encode("mxfp4", forKey: .type)
        case .mxfp8:
            try container.encode("mxfp8", forKey: .type)
        case .nvfp4:
            try container.encode("nvfp4", forKey: .type)
        }
    }
}

/// Dequantize the matrix `w` using the provided `scales` and
Expand All @@ -1146,7 +1263,7 @@ public enum QuantizationMode: String, Codable, Sendable {
public func dequantized(
_ w: MLXArray,
scales: MLXArray, biases: MLXArray?,
groupSize: Int? = nil, bits: Int? = nil, mode: QuantizationMode = .affine,
groupSize: Int? = nil, bits: Int? = nil, mode: QuantizationMode = .affine(),
globalScale: MLXArray? = nil,
dtype: DType? = nil,
stream: StreamOrDevice = .default
Expand All @@ -1157,7 +1274,7 @@ public func dequantized(
let dtype = mlx_optional_dtype(value: dtype?.cmlxDtype ?? MLX_FLOAT16, has_value: dtype != nil)
mlx_dequantize(
&result, w.ctx,
scales.ctx, (biases ?? .mlxNone).ctx, gs, bits, mode.rawValue,
scales.ctx, (biases ?? .mlxNone).ctx, gs, bits, mode.cName,
(globalScale ?? .mlxNone).ctx,
dtype,
stream.ctx)
Expand Down Expand Up @@ -1428,7 +1545,7 @@ public func gatherQuantizedMatmul(
_ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
lhsIndices: MLXArray? = nil, rhsIndices: MLXArray? = nil,
transpose: Bool = true, groupSize: Int? = nil, bits: Int? = nil,
mode: QuantizationMode = .affine,
mode: QuantizationMode = .affine(),
sortedIndices: Bool = false,
stream: StreamOrDevice = .default
) -> MLXArray {
Expand Down Expand Up @@ -1469,7 +1586,7 @@ public func gatherQuantizedMM(
_ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
lhsIndices: MLXArray? = nil, rhsIndices: MLXArray? = nil,
transpose: Bool = true, groupSize: Int? = nil, bits: Int? = nil,
mode: QuantizationMode = .affine,
mode: QuantizationMode = .affine(),
sortedIndices: Bool = false,
stream: StreamOrDevice = .default
) -> MLXArray {
Expand All @@ -1482,7 +1599,7 @@ public func gatherQuantizedMM(
&result,
x.ctx, w.ctx, scales.ctx, (biases ?? .mlxNone).ctx, (lhsIndices ?? .mlxNone).ctx,
(rhsIndices ?? .mlxNone).ctx, transpose,
gs, bits, mode.rawValue, sortedIndices,
gs, bits, mode.cName, sortedIndices,
stream.ctx)

return MLXArray(result)
Expand Down Expand Up @@ -2196,7 +2313,7 @@ public func padded(
mlx_pad(
&result,
array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx,
mode.rawValue.cString(using: .utf8), stream.ctx)
mode.cName.cString(using: .utf8), stream.ctx)
return MLXArray(result)
}

Expand Down Expand Up @@ -2227,7 +2344,7 @@ public func padded(
mlx_pad(
&result,
array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx,
mode.rawValue.cString(using: .utf8), stream.ctx)
mode.cName.cString(using: .utf8), stream.ctx)
return MLXArray(result)
}

Expand Down Expand Up @@ -2353,7 +2470,7 @@ public func putAlong(
public func quantized(
_ w: MLXArray,
groupSize: Int? = nil, bits: Int? = nil,
mode: QuantizationMode = .affine,
mode: QuantizationMode = .affine(),
globalScale: MLXArray? = nil,
stream: StreamOrDevice = .default
) -> (wq: MLXArray, scales: MLXArray, biases: MLXArray?) {
Expand All @@ -2364,7 +2481,7 @@ public func quantized(
let bits = mlx_optional_int(value: Int32(bits ?? 0), has_value: bits != nil)

mlx_quantize(
&r, w.ctx, gs, bits, mode.rawValue,
&r, w.ctx, gs, bits, mode.cName,
(globalScale ?? .mlxNone).ctx,
stream.ctx)

Expand All @@ -2379,7 +2496,7 @@ public func quantizedMatmul(
_ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
transpose: Bool = true,
groupSize: Int? = nil, bits: Int? = nil,
mode: QuantizationMode = .affine,
mode: QuantizationMode = .affine(),
Comment on lines 2498 to +2499
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is perhaps missing the point of #285 -- specifically the enum would carry (either implicitly or explicitly) the groupSize and bits. nvfp4 is a unusual one because it has different behavior in different contexts.

The idea is good but I am not sure how to implement this just yet.

stream: StreamOrDevice = .default
) -> MLXArray {
quantizedMM(
Expand Down Expand Up @@ -2413,7 +2530,7 @@ public func quantizedMM(
_ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?,
transpose: Bool = true,
groupSize: Int? = nil, bits: Int? = nil,
mode: QuantizationMode = .affine,
mode: QuantizationMode = .affine(),
stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
Expand All @@ -2425,7 +2542,7 @@ public func quantizedMM(
&result,
x.ctx, w.ctx, scales.ctx, (biases ?? .mlxNone).ctx,
transpose, gs, bits,
mode.rawValue,
mode.cName,
stream.ctx
)
return MLXArray(result)
Expand Down Expand Up @@ -2476,7 +2593,7 @@ public func quantizedQuantizedMM(
&result,
x.ctx, w.ctx, (scales ?? .mlxNone).ctx,
gs, bits,
mode.rawValue,
mode.cName,
(globalScaleX ?? .mlxNone).ctx,
(globalScaleW ?? .mlxNone).ctx,
stream.ctx
Expand Down
30 changes: 15 additions & 15 deletions Source/MLXNN/Quantized.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public protocol Quantizable {

extension Quantizable {
public func toQuantized(groupSize: Int, bits: Int) -> Module {
toQuantized(groupSize: groupSize, bits: bits, mode: .affine)
toQuantized(groupSize: groupSize, bits: bits, mode: .affine())
}
}

Expand All @@ -29,10 +29,10 @@ public protocol Quantized: Module {

/// Quantize any ``Quantizable`` layer that is not already quantized.
public func quantizeSingle(
layer: Module, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine
layer: Module, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine()
) -> Quantized? {
if layer is Quantized {
// already quantized
// already quantized — leave it as-is
nil
} else if let quantizable = layer as? Quantizable {
quantizable.toQuantized(groupSize: groupSize, bits: bits, mode: mode) as? Quantized
Expand All @@ -56,7 +56,7 @@ public func quantizeSingle(
/// - ``quantize(model:filter:apply:)-(_,_,(Module,Int,Int,QuantizationMode)->Module?)``
public func quantize(
model: Module,
groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine,
groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine(),
filter: (String, Module) -> Bool = { _, _ in true },
apply: (Module, Int, Int, QuantizationMode) -> Module? = quantizeSingle(
layer:groupSize:bits:mode:)
Expand Down Expand Up @@ -84,11 +84,11 @@ public func quantize(
model: Module, groupSize: Int = 64, bits: Int = 4,
filter: (String, Module) -> Bool = { _, _ in true },
apply: (Module, Int, Int) -> Module? = {
quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine)
quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine())
}
) {
quantize(
model: model, groupSize: groupSize, bits: bits, mode: .affine, filter: filter,
model: model, groupSize: groupSize, bits: bits, mode: .affine(), filter: filter,
apply: { l, g, b, n in apply(l, g, b) }
)
}
Expand Down Expand Up @@ -132,14 +132,14 @@ public func quantize(
model: Module,
filter: (String, Module) -> (groupSize: Int, bits: Int)?,
apply: (Module, Int, Int) -> Module? = {
quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine)
quantizeSingle(layer: $0, groupSize: $1, bits: $2, mode: .affine())
}
) {
quantize(
model: model,
filter: {
if let (g, b) = filter($0, $1) {
return (g, b, .affine)
return (g, b, .affine())
} else {
return nil
}
Expand Down Expand Up @@ -167,7 +167,7 @@ open class QuantizedEmbedding: Embedding, Quantized {

convenience public init(
embeddingCount: Int, dimensions: Int, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
let scale = sqrt(1 / Float(dimensions))
let weight = MLXRandom.normal([embeddingCount, dimensions]) * scale
Expand All @@ -177,14 +177,14 @@ open class QuantizedEmbedding: Embedding, Quantized {

public convenience init(
_ other: Embedding, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
self.init(weight: other.weight, groupSize: groupSize, bits: bits, mode: mode)
}

public init(
weight: MLXArray, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
self.groupSize = groupSize
self.bits = bits
Expand Down Expand Up @@ -263,7 +263,7 @@ open class QuantizedLinear: Linear, Quantized {
public convenience init(
_ inputDimensions: Int, _ outputDimensions: Int,
bias: Bool = true, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
let scale = sqrt(1 / Float(inputDimensions))
let weight = MLXRandom.uniform(
Expand All @@ -283,7 +283,7 @@ open class QuantizedLinear: Linear, Quantized {
/// - mode: quantization mode
public convenience init(
_ other: Linear, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
self.init(
weight: other.weight, bias: other.bias, groupSize: groupSize, bits: bits, mode: mode)
Expand All @@ -292,7 +292,7 @@ open class QuantizedLinear: Linear, Quantized {
/// Initialize a ``QuantizedLinear`` with non-quantized weights and bias.
public init(
weight: MLXArray, bias: MLXArray?, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
self.groupSize = groupSize
self.bits = bits
Expand All @@ -316,7 +316,7 @@ open class QuantizedLinear: Linear, Quantized {
public init(
weight: MLXArray, bias: MLXArray? = nil, scales: MLXArray, biases: MLXArray?,
groupSize: Int, bits: Int,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine()
) {
self.groupSize = groupSize
self.bits = bits
Expand Down