diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 3c3643a320..f4d9bedafb 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -1367,4 +1367,9 @@ void QQMatmul::eval_cpu(const std::vector& inputs, array& out) { } } +void QQAddMM::eval_cpu(const std::vector& inputs, array& out) { + // QQAddMM requires GPU support (CUDA CC 10.0+ or Metal qmv case) + throw std::runtime_error("[QQAddMM] Not implemented for CPU."); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/no_qqmm_impl.cpp b/mlx/backend/cuda/quantized/no_qqmm_impl.cpp index 375755bc3d..7d078b8cde 100644 --- a/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +++ b/mlx/backend/cuda/quantized/no_qqmm_impl.cpp @@ -18,7 +18,8 @@ void qqmm_impl( const array&, const array&, QuantizationMode, - const GemmScalars&) { + const GemmScalars&, + const std::optional&) { throw std::runtime_error( "[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher."); } diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index afce96e3ce..fef6544191 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -183,4 +183,92 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { scalars); } +void QQAddMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("QQAddMM::eval_gpu"); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + auto& device = encoder.device(); + + // inputs: [c, x, w, (scales_w), (global_scale_x, global_scale_w)] + const array& c = inputs[0]; + bool w_quantized = (inputs[2].dtype() == uint32); + int base_size = w_quantized ? 4 : 3; // c + x + w + (scales_w if quantized) + + assert( + inputs.size() == base_size || + (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2)); + + auto cc = device.compute_capability_major() * 100 + + device.compute_capability_minor() * 10; + if (cc < 1000) { + throw std::runtime_error( + "[QQAddMM::eval_gpu] QQAddMM is only supported on GPUs with compute capability 10.0 or higher."); + } + + // For nvfp4, global scales are optional but must be both present or both + // absent. If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; + + // For nvfp4, get global scales from inputs if present + std::optional global_scale_x = std::nullopt; + std::optional global_scale_w = std::nullopt; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; + } + + // Quantize inputs (or use pre-quantized) + auto [x_q, scale_x_pre] = quantize_input( + inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_x); + auto [w_q, scale_w_pre] = !w_quantized + ? quantize_input( + inputs[2], encoder, s, mode_, bits_, group_size_, global_scale_w) + : std::make_tuple( + ensure_contiguous(inputs[2], encoder, s), + ensure_contiguous(inputs[3], encoder, s)); + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + int M = x_q.shape(-2); + int N = w_q.shape(-2); // transposed + int K = x_q.shape(-1) * (32 / bits_); + + bool x_transposed = false; + bool w_transposed = true; // always transposed + int64_t lda = K; + int64_t ldb = K; + + // Repack scales to tiled layout for tensor cores + array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); + array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + + GemmScalars scalars; + if (has_global_scales) { + scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder); + } + + // Ensure bias is row contiguous and pass it to qqmm_impl + array bias = ensure_row_contiguous(c, encoder, s); + + qqmm_impl( + encoder, + M, + N, + K, + x_transposed, + lda, + w_transposed, + ldb, + out, + x_q, + w_q, + scale_x, + scale_w, + mode_, + scalars, + bias); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_impl.cpp b/mlx/backend/cuda/quantized/qqmm_impl.cpp index ccf736cf64..5835ab64cb 100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.cpp +++ b/mlx/backend/cuda/quantized/qqmm_impl.cpp @@ -20,7 +20,8 @@ void qqmm_impl( const array& a_scale, const array& b_scale, QuantizationMode mode, - const GemmScalars& scalars) { + const GemmScalars& scalars, + const std::optional& bias) { std::string qmode = quantization_mode_to_string(mode); CublasQQMM qqmm( @@ -39,6 +40,14 @@ void qqmm_impl( out.dtype(), qmode); + // Note: Unlike regular GEMM, no complex64 check is needed here because + // quantized matmul only supports real floating types (float16, bfloat16, + // float32). The type constraint is enforced in validate_qqmm_inputs() in + // ops.cpp. + if (bias) { + qqmm.set_bias(encoder, *bias); + } + if (scalars.has_values()) { qqmm.run( encoder, diff --git a/mlx/backend/cuda/quantized/qqmm_impl.h b/mlx/backend/cuda/quantized/qqmm_impl.h index ab2b74c19f..724ff2ffe3 100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.h +++ b/mlx/backend/cuda/quantized/qqmm_impl.h @@ -32,6 +32,7 @@ void qqmm_impl( const array& a_scale, const array& b_scale, QuantizationMode mode, - const GemmScalars& scalars = {}); + const GemmScalars& scalars = {}, + const std::optional& bias = std::nullopt); } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index bd197937c5..e5191b8084 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1553,6 +1553,61 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { } } +void QQAddMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto mode = quantization_mode_to_string(mode_); + + // inputs: [c, x, w, (scales_w)] + const array& c = inputs[0]; + bool w_quantized = (inputs[2].dtype() == uint32); + + // QMV case (M=1): supported with bias via dispatch_qmv + if (w_quantized && inputs[1].shape(-2) == 1) { + out.set_data(allocator::malloc(out.nbytes())); + + bool donate_x = inputs[1].is_donatable(); + array x = ensure_row_contiguous(inputs[1], d, s); + // If x is a copy it should be donatable + donate_x |= x.is_donatable(); + auto xhat = donate_x + ? x + : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype()); + quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s); + + // Make sure the last two dims of w and scales are contiguous + array w = ensure_row_contiguous_matrix(inputs[2], d, s); + array scales = ensure_row_contiguous_matrix(inputs[3], d, s); + + // Ensure bias is contiguous + array bias = ensure_row_contiguous(c, d, s); + + bool non_batched = w.ndim() == 2; + int K = x.shape(-1); + int M = non_batched ? x.size() / K : x.shape(-2); + int N = out.shape(-1); + + dispatch_qmv( + xhat, + w, + scales, + bias, // Pass bias to use the epilogue + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode); + return; + } else { + throw std::runtime_error("[QQAddMM] NYI for the general case"); + } +} + void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c7af8834fe..2b868cc961 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4525,6 +4525,82 @@ array qqmm( return out; } +array qqaddmm( + array c, + array in_x, + array w, + std::optional scales_w /* = std::nullopt */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, + const std::string& mode /* = "nvfp4" */, + const std::optional global_scale_x /* = std::nullopt */, + const std::optional global_scale_w /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto stream = to_stream(s); + auto qmode = string_to_quantization_mode(mode, "qqaddmm"); + + // cuBLAS block scaled matmul only supports nvfp4 and mxfp8 + if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) { + std::ostringstream msg; + msg << "[qqaddmm] Only 'nvfp4' and 'mxfp8' quantization modes are supported but '" + << mode << "' was provided."; + throw std::invalid_argument(msg.str()); + } + + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); + + // Allow gemv + auto x = in_x; + if (x.ndim() == 1) { + x = expand_dims(x, 0, s); + } else if (w.ndim() == 2 && x.ndim() > 2) { + x = flatten(x, 0, -2, s); + } + + // Validate inputs (reuse qqmm validation) + validate_qqmm_inputs( + x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode); + + // Validate and extract shapes + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims(x, w, scales_w, group_size, bits); + + // Validate bias shape + auto out_shape = x.shape(); + out_shape.back() = w_outer_dims; + + // Broadcast c to output shape (similar to addmm) + auto c_broadcast_shape = broadcast_shapes(c.shape(), {out_shape.back()}); + c = broadcast_to(c, c_broadcast_shape, s); + c = astype(c, x.dtype(), s); + + // Build inputs: [c, x, w, (scales_w), (global_scale_x, global_scale_w)] + std::vector inputs = {c, x, w}; + if (scales_w.has_value()) { + inputs.push_back(*scales_w); + } + if (global_scale_x.has_value() && global_scale_w.has_value()) { + inputs.push_back(*global_scale_x); + inputs.push_back(*global_scale_w); + } + + auto out = array( + std::move(out_shape), + x.dtype(), + std::make_shared(stream, group_size, bits, qmode), + std::move(inputs)); + + if (in_x.ndim() > 2) { + auto orig_shape = in_x.shape(); + orig_shape.pop_back(); + out = unflatten(out, 0, std::move(orig_shape), s); + } else if (in_x.ndim() == 1) { + out = squeeze(out, 0, s); + } + return out; +} + array pack_and_quantize( array& packed_w, const array& scales, diff --git a/mlx/ops.h b/mlx/ops.h index 74032c01e0..dbaef0062f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1430,6 +1430,20 @@ MLX_API array qqmm( const std::optional global_scale_w = std::nullopt, StreamOrDevice s = {}); +/** Compute D = C + (x @ w.T) with quantized x and w */ +MLX_API array qqaddmm( + array c, // bias to add + array x, // input activations + array w, // maybe quantized weights + const std::optional w_scales = std::nullopt, // optional scales if w + // is quantized + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + const std::optional global_scale_x = std::nullopt, + const std::optional global_scale_w = std::nullopt, + StreamOrDevice s = {}); + /** Convert an E4M3 float8 to the given floating point dtype. */ MLX_API array from_fp8(array x, Dtype dtype, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 92e54f9991..238b9605da 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3547,6 +3547,98 @@ std::vector QQMatmul::jvp( throw std::runtime_error("QQMM::jvp NYI"); } +bool QQAddMM::is_equivalent(const Primitive& other) const { + const QQAddMM& qm_other = static_cast(other); + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + mode_ == qm_other.mode_; +} + +std::vector QQAddMM::output_shapes(const std::vector& inputs) { + // inputs: [c, x, w, ...] + auto out_shape = inputs[1].shape(); + int w_outer_dims = inputs[2].shape(-2); + out_shape.back() = w_outer_dims; + return {std::move(out_shape)}; +} + +std::vector QQAddMM::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + // primals: [c, x, w, (global_scale_x, global_scale_w if nvfp4)] + // For qqaddmm(c, x, w) = c + x @ w.T: + // grad_c = cotan (summed to match c's shape) + // grad_x = cotan @ w (same as qqmm) + // grad_w = cotan.T @ x (same as qqmm) + bool is_nvfp4 = mode_ == QuantizationMode::Nvfp4; + assert(primals.size() == 3 || (is_nvfp4 && primals.size() == 5)); + + std::vector vjps; + auto& cotan = cotangents[0]; + auto& s = stream(); + auto qmode = quantization_mode_to_string(mode_); + + std::optional cotan_amax = (primals.size() == 5) + ? std::make_optional(astype(max(abs(cotan, s), s), float32, s)) + : std::nullopt; + + auto get_primal_scale = [&](int idx) { + return (primals.size() == 5) ? std::make_optional(primals[idx]) + : std::nullopt; + }; + + for (auto arg : argnums) { + if (arg == 0) { // gradient wrt c (bias) + // grad_c = sum(cotan) along batch dimensions to match c's shape + auto grad_c = cotan; + // Sum along all but the last dimension to match bias shape + if (cotan.ndim() > 1) { + std::vector axes; + for (int i = 0; i < cotan.ndim() - 1; i++) { + axes.push_back(i); + } + grad_c = sum(cotan, axes, false, s); + } + vjps.push_back(grad_c); + } else if (arg == 1) { // gradient wrt x + // Same as QQMatmul: grad_x = cotan @ w + vjps.push_back(qqmm( + cotan, + swapaxes(primals[2], -1, -2, s), + {}, + group_size_, + bits_, + qmode, + cotan_amax, + get_primal_scale(4), // global_scale_w + s)); + } else if (arg == 2) { // gradient wrt w + // Same as QQMatmul: grad_w = cotan.T @ x + vjps.push_back(qqmm( + swapaxes(cotan, -1, -2, s), + swapaxes(primals[1], -1, -2, s), + {}, + group_size_, + bits_, + qmode, + cotan_amax, + get_primal_scale(3), // global_scale_x + s)); + } else { + vjps.push_back(zeros_like(primals[arg], s)); + } + } + return vjps; +} + +std::vector QQAddMM::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + throw std::runtime_error("QQAddMM::jvp NYI"); +} + std::pair, std::vector> GatherQMM::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 4091aafcfb..6658b81120 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1673,6 +1673,35 @@ class QQMatmul : public UnaryPrimitive { QuantizationMode mode_; }; +class QQAddMM : public UnaryPrimitive { + public: + explicit QQAddMM( + Stream stream, + int group_size, + int bits, + QuantizationMode mode) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_NAME(QQAddMM) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(group_size_, bits_, mode_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; +}; + class GatherQMM : public UnaryPrimitive { public: explicit GatherQMM( diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 57e7c88898..eed5774e33 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -321,11 +321,11 @@ class QQLinear(Module): :obj:`QQLinear` also provides the class method :meth:`from_linear` to convert :class:`mlx.nn.Linear` layers to :obj:`QQLinear` layers. - Note: This layer does not support a bias term yet. - Args: input_dims (int): The dimensionality of the input features. output_dims (int): The dimensionality of the output features. + bias (bool, optional): If set to ``False`` then the layer will not use + a bias. Default: ``False``. group_size (Optional[int]): The group size to use for the quantized weight. See :func:`~mlx.core.quantize`. Default: ``None``. bits (Optional[int]): The bit width to use for the quantized weight. @@ -339,6 +339,7 @@ def __init__( self, input_dims: int, output_dims: int, + bias: bool = False, group_size: int = None, bits: int = None, mode: str = "nvfp4", @@ -357,12 +358,15 @@ def __init__( ) self._quantized = False + if bias: + self.bias = mx.zeros((output_dims,)) + def _extra_repr(self): out_dims, in_dims = self.weight.shape if self.weight.dtype == mx.uint32: in_dims = (in_dims * 32) // self.bits return ( - f"input_dims={in_dims}, output_dims={out_dims}, " + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" ) @@ -397,7 +401,19 @@ def _set_training_mode(self, mode: bool): self.quantize() def __call__(self, x): - x = mx.qqmm( + if "bias" in self: + # Use qqaddmm to fuse bias addition into the matmul epilogue + return mx.qqaddmm( + self["bias"], + x, + self["weight"], + scales=self.get("scales"), + group_size=self.group_size, + bits=self.bits, + mode=self.mode, + ) + + return mx.qqmm( x, self["weight"], scales=self.get("scales"), @@ -405,7 +421,6 @@ def __call__(self, x): bits=self.bits, mode=self.mode, ) - return x @classmethod def from_linear( @@ -417,10 +432,19 @@ def from_linear( ): """Create a :obj:`QQLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape # (N,K) - if linear_layer.get("bias") is not None: - raise NotImplementedError("QQLinear does not support bias yet.") - ql = cls(input_dims, output_dims, group_size, bits, mode=mode) + has_bias = linear_layer.get("bias") is not None + ql = cls( + input_dims, + output_dims, + bias=False, + group_size=group_size, + bits=bits, + mode=mode, + ) ql.weight = linear_layer.weight - ql.train(linear_layer.training) + if has_bias: + ql.bias = linear_layer.bias + + ql.train(linear_layer.training) return ql diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a4ce55f8b3..35b4c39320 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5610,6 +5610,64 @@ void init_ops(nb::module_& m) { array: The result of the multiplication of quantized ``x`` with quantized ``w``. needed). )pbdoc"); + m.def( + "qqaddmm", + &mx::qqaddmm, + nb::arg(), // c (bias) + nb::arg(), // x + nb::arg(), // w_q + "scales"_a = nb::none(), // scales w + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "nvfp4", + "global_scale_x"_a = nb::none(), + "global_scale_w"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def qqaddmm(c: array, x: array, w: array, scales: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', global_scale_x: Optional[array] = None, global_scale_w: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute ``c + x @ w.T`` where ``x`` and ``w`` are quantized on the fly. + + This operation fuses bias addition with quantized matrix multiplication, + using the hardware bias epilogue when available (CUDA compute capability + 10.0+) for better performance than separate ``qqmm`` followed by ``add``. + + If ``w`` is quantized, ``scales`` must be provided, and ``group_size``, + ``bits``, and ``mode`` must match the parameters that were used to quantize + ``w``. + + Notes: + Currently only supported on CUDA with compute capability 10.0 or higher. + On other devices, an error will be raised. + + If ``w`` is expected to receive gradients, it must be provided in + non-quantized form. + + ``global_scale_x`` and ``global_scale_w`` are only used for ``nvfp4`` quantization. + + Args: + c (array): The bias array to add to the result. + x (array): Input array. + w (array): Weight matrix. If quantized, it is packed in unsigned integers. + scales (array, optional): The scales to use per ``group_size`` elements of + ``w`` if ``w`` is quantized. Default: ``None``. + group_size (int, optional): Number of elements in ``x`` and ``w`` that + share a scale. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): Number of bits used to represent each element of + ``x`` and ``w``. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + mode (str, optional): The quantization mode. Default: ``"nvfp4"``. + Supported modes are ``nvfp4`` and ``mxfp8``. See the + :ref:`table of quantization modes ` for details. + global_scale_x (array, optional): The per-input float32 scale used for x + with ``"nvfp4"`` quantization. Default: ``None``. + global_scale_w (array, optional): The per-input float32 scale used for w + with ``"nvfp4"`` quantization. Default: ``None``. + Returns: + array: The result of ``c + x @ w.T`` with quantized ``x`` and ``w``. + )pbdoc"); m.def( "from_fp8", &mx::from_fp8, diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index ebdfe580ec..1d08879bff 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -232,6 +232,37 @@ def test_quantize_freeze(self): size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0) self.assertTrue(size > 0) + def test_qqlinear_bias(self): + # Test QQLinear with bias=True + layer = nn.QQLinear(32, 64, bias=True, mode="nvfp4") + self.assertIn("bias", layer) + self.assertEqual(layer.bias.shape, (64,)) + + # Test QQLinear with bias=False (default) + layer_no_bias = nn.QQLinear(32, 64, bias=False, mode="nvfp4") + self.assertNotIn("bias", layer_no_bias) + + # Test _extra_repr shows bias info + self.assertIn("bias=True", layer._extra_repr()) + self.assertIn("bias=False", layer_no_bias._extra_repr()) + + def test_qqlinear_from_linear_with_bias(self): + # Test converting Linear with bias to QQLinear + linear = nn.Linear(32, 64, bias=True) + mx.eval(linear.parameters()) + + qqlinear = nn.QQLinear.from_linear(linear, mode="nvfp4") + self.assertIn("bias", qqlinear) + self.assertEqual(qqlinear.bias.shape, (64,)) + + # Test that bias values are copied + self.assertTrue(mx.array_equal(qqlinear.bias, linear.bias)) + + # Test converting Linear without bias + linear_no_bias = nn.Linear(32, 64, bias=False) + qqlinear_no_bias = nn.QQLinear.from_linear(linear_no_bias, mode="nvfp4") + self.assertNotIn("bias", qqlinear_no_bias) + def test_quantized_sharded_linear_construction(self): input_dims, output_dims = 1536, 1024 for bits in [2, 3, 4, 5, 6, 8]: