Fix QuantizedLinear for non-affine quantization modes (mxfp4) #384
Conversation
Source/MLXNN/Quantized.swift (Outdated)

```swift
if parameter == "biases" && mode != .affine {
    return
}
```
I am curious about this -- the init for QuantizedLinear should have left us with a biases = nil and we shouldn't hit this case. I wonder what happened here?
davidkoski left a comment:
Please look at my comment on updateMissing() -- I think it should be removed, but let me know if you think otherwise.
Nice find and fix!
`QuantizedLinear.init(weight:...)` called `MLX.quantized()` without forwarding the `mode` parameter, so weights were always quantized as affine regardless of the specified mode. This produced spurious biases for non-affine modes like mxfp4.

Fix: pass `mode: mode` to `MLX.quantized()`, matching `QuantizedEmbedding`, which already does this correctly.

Note: no `updateMissing()` override is needed — when `biases` is nil, `Module.build(value:)` wraps it as `.value(.other(...))`, and the `(.value(.other(_)), .none)` case in `update()` already breaks silently.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
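The forwarding bug described in the commit message can be sketched with a self-contained toy. This is not the MLX API: `quantized`, `buggyInitBiases`, and `fixedInitBiases` are illustrative stand-ins that only mimic the shape of the bug, where a defaulted `mode` parameter is accepted by the init but never passed on.

```swift
import Foundation

enum QuantizationMode { case affine, mxfp4 }

// Toy stand-in for MLX.quantized(): affine quantization stores per-group
// zero-point biases; mxfp4 does not.
func quantized(
    _ weight: [Float], groupSize: Int, bits: Int,
    mode: QuantizationMode = .affine
) -> (weight: [Float], scales: [Float], biases: [Float]?) {
    let groups = weight.count / groupSize
    let biases: [Float]? = mode == .affine ? Array(repeating: 0, count: groups) : nil
    return (weight, Array(repeating: 1, count: groups), biases)
}

// Buggy forwarding: `mode` is accepted but never passed on, so the
// default `.affine` always wins and spurious biases are created.
func buggyInitBiases(weight: [Float], mode: QuantizationMode) -> [Float]? {
    quantized(weight, groupSize: 4, bits: 4).biases
}

// Fixed forwarding, mirroring what QuantizedEmbedding already did.
func fixedInitBiases(weight: [Float], mode: QuantizationMode) -> [Float]? {
    quantized(weight, groupSize: 4, bits: 4, mode: mode).biases
}

let w = [Float](repeating: 0.5, count: 8)
print(buggyInitBiases(weight: w, mode: .mxfp4) != nil)  // prints true (spurious biases)
print(fixedInitBiases(weight: w, mode: .mxfp4) == nil)  // prints true (no biases for mxfp4)
```

The silent failure mode is that nothing crashes at quantization time; the wrong artifact (an unexpected biases tensor) only surfaces later, when loading or verifying parameters.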
Force-pushed from b68356b to 15532c3 (compare)
You're right — I traced through `update()` and removed the `updateMissing()` override. The mxfp4 nil-biases test stays to guard the fix. Thanks for the thorough review and for maintaining this project!
davidkoski left a comment:
Changes look good, thank you for figuring this out!

Proposed changes
`QuantizedLinear` fails to load mxfp4-quantized models with `.keyNotFound` for `"biases"`. Two bugs:

1. **Mode not forwarded during quantization** — the weight-based `init` calls `MLX.quantized(weight, groupSize:bits:)` without passing `mode`, so weights are always quantized as affine regardless of the mode specified. `QuantizedEmbedding` already passes `mode` correctly.
2. **Missing biases treated as error for non-affine modes** — `Module.update(parameters:, verify: .all)` calls `updateMissing()` for the `"biases"` key, which throws `.keyNotFound`. Non-affine modes (e.g. `.mxfp4`) don't produce biases, so this key is legitimately absent.

Fix

- Pass `mode: mode` to `MLX.quantized()` in `QuantizedLinear.init(weight:bias:groupSize:bits:mode:)`
- Override `updateMissing()` to skip `"biases"` when `mode != .affine`

Tests

- `testQuantizedLinearMxfp4DoesNotCreateAffineBiases` — verifies mxfp4 produces `nil` biases
- `testQuantizedLinearMxfp4ParametersRoundTripWithoutBiases` — verifies `update(parameters:, verify: .all)` succeeds without a `"biases"` key

Checklist

- Ran `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
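The round-trip scenario the second test covers can be sketched with a self-contained toy. `ToyQuantizedLinear` and its `update` are illustrative assumptions, not MLX's `Module` API: the point is only that an absent `"biases"` key should pass strict verification when the module's own `biases` is nil.

```swift
import Foundation

enum UpdateError: Error { case keyNotFound(String) }

// Toy module: biases is nil for non-affine modes such as mxfp4.
struct ToyQuantizedLinear {
    var weight: [Float]
    var biases: [Float]?

    // Strict update: every expected key must be present, except when the
    // module value itself is nil (a legitimately absent parameter).
    mutating func update(parameters: [String: [Float]]) throws {
        for key in ["weight", "biases"] {
            switch (key, parameters[key]) {
            case ("weight", let value?):
                weight = value
            case ("biases", let value?):
                biases = value
            case ("biases", nil) where biases == nil:
                continue // nil module value + absent key: not an error
            case (let missing, nil):
                throw UpdateError.keyNotFound(missing)
            default:
                continue
            }
        }
    }
}

var layer = ToyQuantizedLinear(weight: [1, 2, 3, 4], biases: nil)
try layer.update(parameters: ["weight": [5, 6, 7, 8]]) // succeeds, no "biases" key
print(layer.biases == nil) // prints true
```

Without the `where biases == nil` guard, the absent `"biases"` key falls through to the `keyNotFound` throw, which is the toy analogue of the failure the PR fixes.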