Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlx/backend/cpu/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1367,4 +1367,9 @@ void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}

void QQAddMM::eval_cpu(const std::vector<array>&, array&) {
  // No CPU kernel exists for QQAddMM; only the GPU backends implement it
  // (CUDA with CC 10.0+, and the Metal qmv path). Parameters are unnamed
  // since they are never inspected.
  throw std::runtime_error("[QQAddMM] Not implemented for CPU.");
}

} // namespace mlx::core
3 changes: 2 additions & 1 deletion mlx/backend/cuda/quantized/no_qqmm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ void qqmm_impl(
const array&,
const array&,
QuantizationMode,
const GemmScalars&) {
const GemmScalars&,
const std::optional<array>&) {
throw std::runtime_error(
"[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.");
}
Expand Down
88 changes: 88 additions & 0 deletions mlx/backend/cuda/quantized/qqmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,92 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
scalars);
}

void QQAddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("QQAddMM::eval_gpu");

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  auto& device = encoder.device();

  // inputs: [c, x, w, (scales_w), (global_scale_x, global_scale_w)]
  const array& c = inputs[0];
  // A uint32-packed w means it arrived pre-quantized and carries its own
  // scales as inputs[3]; otherwise w is quantized on the fly below.
  bool w_quantized = (inputs[2].dtype() == uint32);
  // size_t so comparisons against inputs.size() below are not mixed-sign.
  size_t base_size = w_quantized ? 4 : 3; // c + x + w + (scales_w if quantized)

  assert(
      inputs.size() == base_size ||
      (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));

  // Block-scaled matmul needs compute capability 10.0 or newer.
  auto cc = device.compute_capability_major() * 100 +
      device.compute_capability_minor() * 10;
  if (cc < 1000) {
    throw std::runtime_error(
        "[QQAddMM::eval_gpu] QQAddMM is only supported on GPUs with compute capability 10.0 or higher.");
  }

  // For nvfp4, global scales are optional but must be both present or both
  // absent. If present, they add 2 more inputs (global_scale_x, global_scale_w)
  bool has_global_scales =
      mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;

  // For nvfp4, get global scales from inputs if present
  std::optional<array> global_scale_x = std::nullopt;
  std::optional<array> global_scale_w = std::nullopt;
  if (has_global_scales) {
    global_scale_x = inputs[inputs.size() - 2];
    global_scale_w = inputs[inputs.size() - 1];
  }

  // Quantize inputs (or use pre-quantized)
  auto [x_q, scale_x_pre] = quantize_input(
      inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_x);
  auto [w_q, scale_w_pre] = !w_quantized
      ? quantize_input(
            inputs[2], encoder, s, mode_, bits_, group_size_, global_scale_w)
      : std::make_tuple(
            ensure_contiguous(inputs[2], encoder, s),
            ensure_contiguous(inputs[3], encoder, s));

  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  // Logical GEMM dims; K is unpacked from the 32-bit packed last dim of x_q.
  int M = x_q.shape(-2);
  int N = w_q.shape(-2); // transposed
  int K = x_q.shape(-1) * (32 / bits_);

  bool x_transposed = false;
  bool w_transposed = true; // always transposed
  int64_t lda = K;
  int64_t ldb = K;

  // Repack scales to tiled layout for tensor cores
  array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);
  array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);

  GemmScalars scalars;
  if (has_global_scales) {
    scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder);
  }

  // Ensure bias is row contiguous and pass it to qqmm_impl
  array bias = ensure_row_contiguous(c, encoder, s);

  qqmm_impl(
      encoder,
      M,
      N,
      K,
      x_transposed,
      lda,
      w_transposed,
      ldb,
      out,
      x_q,
      w_q,
      scale_x,
      scale_w,
      mode_,
      scalars,
      bias);
}

} // namespace mlx::core
11 changes: 10 additions & 1 deletion mlx/backend/cuda/quantized/qqmm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ void qqmm_impl(
const array& a_scale,
const array& b_scale,
QuantizationMode mode,
const GemmScalars& scalars) {
const GemmScalars& scalars,
const std::optional<array>& bias) {
std::string qmode = quantization_mode_to_string(mode);

CublasQQMM qqmm(
Expand All @@ -39,6 +40,14 @@ void qqmm_impl(
out.dtype(),
qmode);

// Note: Unlike regular GEMM, no complex64 check is needed here because
// quantized matmul only supports real floating types (float16, bfloat16,
// float32). The type constraint is enforced in validate_qqmm_inputs() in
// ops.cpp.
if (bias) {
qqmm.set_bias(encoder, *bias);
}

if (scalars.has_values()) {
qqmm.run(
encoder,
Expand Down
3 changes: 2 additions & 1 deletion mlx/backend/cuda/quantized/qqmm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void qqmm_impl(
const array& a_scale,
const array& b_scale,
QuantizationMode mode,
const GemmScalars& scalars = {});
const GemmScalars& scalars = {},
const std::optional<array>& bias = std::nullopt);

} // namespace mlx::core
55 changes: 55 additions & 0 deletions mlx/backend/metal/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,61 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}

void QQAddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);

auto mode = quantization_mode_to_string(mode_);

// inputs: [c, x, w, (scales_w)]
const array& c = inputs[0];
bool w_quantized = (inputs[2].dtype() == uint32);

// QMV case (M=1): supported with bias via dispatch_qmv
if (w_quantized && inputs[1].shape(-2) == 1) {
out.set_data(allocator::malloc(out.nbytes()));

bool donate_x = inputs[1].is_donatable();
array x = ensure_row_contiguous(inputs[1], d, s);
// If x is a copy it should be donatable
donate_x |= x.is_donatable();
auto xhat = donate_x
? x
: array(allocator::malloc(x.nbytes()), x.shape(), x.dtype());
quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s);

// Make sure the last two dims of w and scales are contiguous
array w = ensure_row_contiguous_matrix(inputs[2], d, s);
array scales = ensure_row_contiguous_matrix(inputs[3], d, s);

// Ensure bias is contiguous
array bias = ensure_row_contiguous(c, d, s);

bool non_batched = w.ndim() == 2;
int K = x.shape(-1);
int M = non_batched ? x.size() / K : x.shape(-2);
int N = out.shape(-1);

dispatch_qmv(
xhat,
w,
scales,
bias, // Pass bias to use the epilogue
out,
group_size_,
bits_,
M,
N,
K,
d,
s,
mode);
return;
} else {
throw std::runtime_error("[QQAddMM] NYI for the general case");
}
}

void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
Expand Down
76 changes: 76 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4525,6 +4525,82 @@ array qqmm(
return out;
}

array qqaddmm(
    array c,
    array in_x,
    array w,
    std::optional<array> scales_w /* = std::nullopt */,
    std::optional<int> group_size_ /* = std::nullopt */,
    std::optional<int> bits_ /* = std::nullopt */,
    const std::string& mode /* = "nvfp4" */,
    const std::optional<array> global_scale_x /* = std::nullopt */,
    const std::optional<array> global_scale_w /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto stream = to_stream(s);
  auto qmode = string_to_quantization_mode(mode, "qqaddmm");

  // cuBLAS block scaled matmul only supports nvfp4 and mxfp8
  if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) {
    std::ostringstream msg;
    msg << "[qqaddmm] Only 'nvfp4' and 'mxfp8' quantization modes are supported but '"
        << mode << "' was provided.";
    throw std::invalid_argument(msg.str());
  }

  // Global scales are only meaningful as a pair; a lone scale would be
  // silently dropped when assembling the primitive inputs below, so reject
  // it explicitly instead.
  if (global_scale_x.has_value() != global_scale_w.has_value()) {
    throw std::invalid_argument(
        "[qqaddmm] global_scale_x and global_scale_w must be provided together.");
  }

  auto [group_size, bits] =
      quantization_params_from_mode(qmode, group_size_, bits_);

  // Allow gemv: promote a vector x to a 1-row matrix, and flatten extra
  // leading batch dims when w is a plain matrix (undone before returning).
  auto x = in_x;
  if (x.ndim() == 1) {
    x = expand_dims(x, 0, s);
  } else if (w.ndim() == 2 && x.ndim() > 2) {
    x = flatten(x, 0, -2, s);
  }

  // Validate inputs (reuse qqmm validation)
  validate_qqmm_inputs(
      x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);

  // Validate and extract shapes
  auto [w_inner_dims, w_outer_dims] =
      extract_qqmm_dims(x, w, scales_w, group_size, bits);

  // Validate bias shape
  auto out_shape = x.shape();
  out_shape.back() = w_outer_dims;

  // Broadcast c to output shape (similar to addmm)
  auto c_broadcast_shape = broadcast_shapes(c.shape(), {out_shape.back()});
  c = broadcast_to(c, c_broadcast_shape, s);
  c = astype(c, x.dtype(), s);

  // Build inputs: [c, x, w, (scales_w), (global_scale_x, global_scale_w)]
  std::vector<array> inputs = {c, x, w};
  if (scales_w.has_value()) {
    inputs.push_back(*scales_w);
  }
  if (global_scale_x.has_value() && global_scale_w.has_value()) {
    inputs.push_back(*global_scale_x);
    inputs.push_back(*global_scale_w);
  }

  auto out = array(
      std::move(out_shape),
      x.dtype(),
      std::make_shared<QQAddMM>(stream, group_size, bits, qmode),
      std::move(inputs));

  // Restore the caller's leading shape that the gemv handling above changed.
  if (in_x.ndim() > 2) {
    auto orig_shape = in_x.shape();
    orig_shape.pop_back();
    out = unflatten(out, 0, std::move(orig_shape), s);
  } else if (in_x.ndim() == 1) {
    out = squeeze(out, 0, s);
  }
  return out;
}

array pack_and_quantize(
array& packed_w,
const array& scales,
Expand Down
14 changes: 14 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,20 @@ MLX_API array qqmm(
const std::optional<array> global_scale_w = std::nullopt,
StreamOrDevice s = {});

/** Compute D = C + (x @ w.T) with quantized x and w.
 *  Only "nvfp4" and "mxfp8" modes are supported; for nvfp4 the global
 *  scales are optional but expected together. */
MLX_API array qqaddmm(
    array c, // bias to add
    array x, // input activations
    array w, // maybe quantized weights (pre-quantized when packed uint32)
    const std::optional<array> w_scales = std::nullopt, // optional scales if w
                                                        // is quantized
    std::optional<int> group_size = std::nullopt,
    std::optional<int> bits = std::nullopt,
    const std::string& mode = "nvfp4",
    const std::optional<array> global_scale_x = std::nullopt,
    const std::optional<array> global_scale_w = std::nullopt,
    StreamOrDevice s = {});

/** Convert an E4M3 float8 to the given floating point dtype. */
MLX_API array from_fp8(array x, Dtype dtype, StreamOrDevice s = {});

Expand Down
Loading