diff --git a/Source/MLXNN/Quantized.swift b/Source/MLXNN/Quantized.swift index 5432a8e0..076e91ca 100644 --- a/Source/MLXNN/Quantized.swift +++ b/Source/MLXNN/Quantized.swift @@ -299,7 +299,7 @@ open class QuantizedLinear: Linear, Quantized { self.mode = mode let (quantizedWeight, scales, biases) = MLX.quantized( - weight, groupSize: groupSize, bits: bits) + weight, groupSize: groupSize, bits: bits, mode: mode) self.scales = scales self.biases = biases diff --git a/Tests/MLXTests/QuantizationTests.swift b/Tests/MLXTests/QuantizationTests.swift index 9f1dbf0b..0edbd545 100644 --- a/Tests/MLXTests/QuantizationTests.swift +++ b/Tests/MLXTests/QuantizationTests.swift @@ -34,4 +34,9 @@ class QuantizationTests: XCTestCase { XCTAssertEqual( quantized3.describeExtra(0), "(embeddingCount=512, dimensions=1024)") } + + func testQuantizedLinearMxfp4DoesNotCreateAffineBiases() { + let quantized = QuantizedLinear(64, 64, groupSize: 32, bits: 4, mode: .mxfp4) + XCTAssertNil(quantized.biases) + } }