From b9f5b2b5ce1c678fbac8d2a97cf6c5d95350c927 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 24 Mar 2026 14:28:34 +0800 Subject: [PATCH 1/4] feat: add GatherQMM implementation for quantized gather matmul --- mlx/backend/cuda/primitives.cpp | 1 - mlx/backend/cuda/quantized/qmm/qmm.h | 13 + mlx/backend/cuda/quantized/qmm/qmv.cu | 316 +++++++++++++++++++++-- mlx/backend/cuda/quantized/quantized.cpp | 89 +++++++ 4 files changed, 395 insertions(+), 24 deletions(-) diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 98dca5708f..6f8fa5ee1d 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -25,7 +25,6 @@ namespace mlx::core { } NO_GPU(BlockMaskedMM) -NO_GPU(GatherQMM) NO_GPU_MULTI(LUF) NO_GPU_MULTI(QRF) NO_GPU_MULTI(SVD) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index c96e4f28bc..82e159d8c3 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -100,4 +100,17 @@ void qmv( QuantizationMode mode, cu::CommandEncoder& encoder); +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder); + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index c43e783b4e..8bce00a141 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -1,7 +1,10 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" +#include "mlx/backend/cuda/quantized/quantized_utils.h" #include "mlx/dtype_utils.h" #include @@ -182,7 +185,6 @@ dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) { } template < - int rows_per_block, int elems_per_thread, int group_size, bool has_bias, @@ -190,30 +192,16 @@ template < typename T, typename Q, typename S> -__global__ void qmv_kernel( +__device__ __forceinline__ void qmv_impl( const T* x, const Q* w, const S* scales, const T* biases, T* out, + int row, int n, - int k, - bool broadcast_w) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - // The row that this warp handles. - int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); - if (row >= n) { - return; - } - - // Advance pointers of x/out. - int m = grid.dim_blocks().y; - int l = block.group_index().z; - x += block.group_index().y * k + m * k * l; - out += block.group_index().y * n + m * n * l; + int k) { + auto warp = cg::tiled_partition(cg::this_thread_block()); // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would // move past 2 elements for 4-bit Q. @@ -224,11 +212,10 @@ __global__ void qmv_kernel( int groups_per_row = k / group_size; // Advance w/scales/biases to current row. - int w_batch = broadcast_w ? 0 : l; - w += (static_cast(row) + n * w_batch) * w_step(k); - scales += (static_cast(row) + n * w_batch) * groups_per_row; + w += static_cast(row) * w_step(k); + scales += static_cast(row) * groups_per_row; if constexpr (has_bias) { - biases += (static_cast(row) + n * w_batch) * groups_per_row; + biases += static_cast(row) * groups_per_row; } // Accumulations of current row. @@ -274,6 +261,165 @@ __global__ void qmv_kernel( } } +template < + int rows_per_block, + int elems_per_thread, + int group_size, + bool has_bias, + bool has_residue_k, + typename T, + typename Q, + typename S> +__global__ void qmv_kernel( + const T* x, + const Q* w, + const S* scales, + const T* biases, + T* out, + int n, + int k, + bool broadcast_w) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // The row that this warp handles. + int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); + if (row >= n) { + return; + } + + // Advance pointers of x/out for M and batch dimensions. + int m = grid.dim_blocks().y; + int l = block.group_index().z; + x += block.group_index().y * k + m * k * l; + out += block.group_index().y * n + m * n * l; + + // Advance w/scales/biases for batch dimension. + constexpr int bits = cute::sizeof_bits_v; + auto w_step = [&](int idx) { return idx * cuda::std::min(8, bits) / 8; }; + int groups_per_row = k / group_size; + int w_batch = broadcast_w ? 0 : l; + w += static_cast(n) * w_batch * w_step(k); + scales += static_cast(n) * w_batch * groups_per_row; + if constexpr (has_bias) { + biases += static_cast(n) * w_batch * groups_per_row; + } + + // Row-level compute: dequantize, FMA, reduce, write. + qmv_impl( + x, w, scales, biases, out, row, n, k); +} + +template < + int rows_per_block, + int elems_per_thread, + int group_size, + bool has_bias, + bool has_residue_k, + typename T, + typename Q, + typename S> +__global__ void gather_qmv_kernel( + const T* x, + const uint32_t* w, + const S* scales, + const T* biases, + T* out, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + int n, + int k, + int x_batch_ndims, + const __grid_constant__ Shape x_batch_shape, + const __grid_constant__ Strides x_batch_strides, + int w_batch_ndims, + const __grid_constant__ Shape w_batch_shape, + const __grid_constant__ Strides w_batch_strides, + const __grid_constant__ Strides s_batch_strides, + const __grid_constant__ Strides b_batch_strides, + int index_ndims, + const __grid_constant__ Shape index_shape, + const __grid_constant__ Strides lhs_index_strides, + const __grid_constant__ Strides rhs_index_strides, + int output_stride) { + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // The row that this warp handles. + int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); + if (row >= n) { + return; + } + + // Gather: look up batch indices. + uint32_t batch_idx = block.group_index().z; + uint32_t x_idx, w_idx; + if (index_ndims == 1) { + x_idx = lhs_indices[batch_idx * lhs_index_strides[0]]; + w_idx = rhs_indices[batch_idx * rhs_index_strides[0]]; + } else { + auto [lhs_off, rhs_off] = elem_to_loc( + batch_idx, + index_shape.data(), + lhs_index_strides.data(), + rhs_index_strides.data(), + index_ndims); + x_idx = lhs_indices[lhs_off]; + w_idx = rhs_indices[rhs_off]; + } + + // Offset x using gathered index. + if (x_batch_ndims == 1) { + x += x_idx * x_batch_strides[0]; + } else { + x += elem_to_loc( + x_idx, x_batch_shape.data(), x_batch_strides.data(), x_batch_ndims); + } + + // Offset w/scales/biases using gathered index. + if (w_batch_ndims == 1) { + w += w_idx * w_batch_strides[0]; + scales += w_idx * s_batch_strides[0]; + if constexpr (has_bias) { + biases += w_idx * b_batch_strides[0]; + } + } else { + if constexpr (has_bias) { + auto [w_off, s_off, b_off] = elem_to_loc( + w_idx, + w_batch_shape.data(), + w_batch_strides.data(), + s_batch_strides.data(), + b_batch_strides.data(), + w_batch_ndims); + w += w_off; + scales += s_off; + biases += b_off; + } else { + auto [w_off, s_off] = elem_to_loc( + w_idx, + w_batch_shape.data(), + w_batch_strides.data(), + s_batch_strides.data(), + w_batch_ndims); + w += w_off; + scales += s_off; + } + } + + // Offset output for this batch element. + out += batch_idx * output_stride; + + // Advance pointers for M dimension (block.group_index().y). + x += block.group_index().y * k; + out += block.group_index().y * n; + + // Reinterpret w as Q* for sub-byte access, then run shared compute. + qmv_impl( + x, reinterpret_cast(w), scales, biases, out, row, n, k); +} + template < int group_size, bool has_bias, @@ -433,4 +579,128 @@ void qmv( }); } +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder) { + const char* tag = "[gather_qmm]"; + int m = out.shape(-2); + int n = out.shape(-1); + int k = x.shape(-1); + int B = out.size() / (m * n); + + // Collapse contiguous dims for index arrays. + auto [idx_shape, idx_strides] = collapse_contiguous_dims( + lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + + dispatch_element_types(out.dtype(), tag, [&]() { + dispatch_quant_types( + bits, + group_size, + mode, + tag, + [&]() { + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + if (biases) { + encoder.set_input_array(*biases); + } + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + + constexpr bool has_bias = !cutlass::has_negative_zero_v; + constexpr int rows_per_block = 8; + constexpr int elems_per_thread = + (cute::sizeof_bits_v <= 16 && cute::sizeof_bits_v <= 4) ? 16 + : 8; + + dim3 num_blocks{ + uint32_t(cuda::ceil_div(n, rows_per_block)), + uint32_t(m), + uint32_t(B)}; + dim3 block_dims{WARP_SIZE, rows_per_block}; + int output_stride = m * n; + + auto x_ptr = gpu_ptr(x); + auto w_ptr = gpu_ptr(w); + auto s_ptr = gpu_ptr(scales); + auto b_ptr = biases ? gpu_ptr(*biases) : (const T*)nullptr; + auto o_ptr = gpu_ptr(out); + auto li_ptr = gpu_ptr(lhs_indices); + auto ri_ptr = gpu_ptr(rhs_indices); + + int x_batch_ndims = x.ndim() - 2; + auto x_shape_p = const_param(x.shape()); + auto x_strides_p = const_param(x.strides()); + int w_batch_ndims = w.ndim() - 2; + auto w_shape_p = const_param(w.shape()); + auto w_strides_p = const_param(w.strides()); + auto s_strides_p = const_param(scales.strides()); + auto b_strides_p = biases + ? const_param(biases->strides()) + : decltype(s_strides_p){}; + int index_ndims = idx_shape.size(); + auto idx_shape_p = const_param(idx_shape); + auto lhs_idx_strides_p = + const_param(idx_strides[0]); + auto rhs_idx_strides_p = + const_param(idx_strides[1]); + + void* args[] = { + &x_ptr, + &w_ptr, + &s_ptr, + &b_ptr, + &o_ptr, + &li_ptr, + &ri_ptr, + &n, + &k, + &x_batch_ndims, + &x_shape_p, + &x_strides_p, + &w_batch_ndims, + &w_shape_p, + &w_strides_p, + &s_strides_p, + &b_strides_p, + &index_ndims, + &idx_shape_p, + &lhs_idx_strides_p, + &rhs_idx_strides_p, + &output_stride}; + + dispatch_bool( + k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) { + auto* kernel = &cu::gather_qmv_kernel< + rows_per_block, + elems_per_thread, + group_size, + has_bias, + has_residue_k.value, + T, + Q, + S>; + encoder.add_kernel_node_raw( + reinterpret_cast(kernel), + num_blocks, + block_dims, + {}, + 0, + args); + }); + }); + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index d7252ec196..84e36e421c 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -104,6 +104,95 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { quantization_mode_to_string(mode_))); } +void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("GatherQMM::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const array& x = inputs[0]; + const array& w = inputs[1]; + const array& scales = inputs[2]; + std::optional biases; + if (inputs.size() == 6) { + biases = inputs[3]; + } + const array& lhs_indices = inputs[inputs.size() - 2]; + const array& rhs_indices = inputs[inputs.size() - 1]; + + int M = out.shape(-2); + int N = out.shape(-1); + int K = x.shape(-1); + int B = out.size() / (M * N); + + auto supports = [&](auto&& f) { + return f( + x, + w, + scales, + biases, + out, + transpose_, + bits_, + group_size_, + mode_, + encoder.device()); + }; + bool can_use_fp_qmv = supports(supports_fp_qmv); + bool can_use_qmv = supports(supports_qmv) || can_use_fp_qmv; + + auto call_qmv = [&]() { + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + if (can_use_fp_qmv) { + // TODO: Add gather_fp_qmv for FP-mode-specific optimizations. + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits_, + group_size_, + mode_, + encoder); + } else { + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits_, + group_size_, + mode_, + encoder); + } + }; + + if (can_use_qmv) { + call_qmv(); + return; + } + + throw std::runtime_error( + fmt::format( + "[gather_qmm] No implementation for " + "problem shape: {}x{}x{}x{}, transpose: {}, " + "activation: {}, bits: {}, group size: {}, mode: \"{}\".", + M, + N, + K, + B, + transpose_, + dtype_to_string(x.dtype()), + bits_, + group_size_, + quantization_mode_to_string(mode_))); +} + void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { From 898ed0209fb1383202d61dcdb6b2bf43d2e94437 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 26 Mar 2026 10:45:33 +0800 Subject: [PATCH 2/4] apply pr comments --- mlx/backend/cuda/quantized/qmm/qmv.cu | 120 +++++------------------ mlx/backend/cuda/quantized/quantized.cpp | 43 +++----- 2 files changed, 35 insertions(+), 128 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index 8bce00a141..b0ac2ddf64 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -1,10 +1,7 @@ // Copyright © 2026 Apple Inc. -#include "mlx/backend/common/utils.h" -#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" -#include "mlx/backend/cuda/quantized/quantized_utils.h" #include "mlx/dtype_utils.h" #include @@ -330,18 +327,10 @@ __global__ void gather_qmv_kernel( const uint32_t* rhs_indices, int n, int k, - int x_batch_ndims, - const __grid_constant__ Shape x_batch_shape, - const __grid_constant__ Strides x_batch_strides, - int w_batch_ndims, - const __grid_constant__ Shape w_batch_shape, - const __grid_constant__ Strides w_batch_strides, - const __grid_constant__ Strides s_batch_strides, - const __grid_constant__ Strides b_batch_strides, - int index_ndims, - const __grid_constant__ Shape index_shape, - const __grid_constant__ Strides lhs_index_strides, - const __grid_constant__ Strides rhs_index_strides, + int64_t x_batch_stride, + int64_t w_batch_stride, + int64_t s_batch_stride, + int64_t b_batch_stride, int output_stride) { auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); @@ -354,58 +343,15 @@ __global__ void gather_qmv_kernel( // Gather: look up batch indices. uint32_t batch_idx = block.group_index().z; - uint32_t x_idx, w_idx; - if (index_ndims == 1) { - x_idx = lhs_indices[batch_idx * lhs_index_strides[0]]; - w_idx = rhs_indices[batch_idx * rhs_index_strides[0]]; - } else { - auto [lhs_off, rhs_off] = elem_to_loc( - batch_idx, - index_shape.data(), - lhs_index_strides.data(), - rhs_index_strides.data(), - index_ndims); - x_idx = lhs_indices[lhs_off]; - w_idx = rhs_indices[rhs_off]; - } + uint32_t x_idx = lhs_indices[batch_idx]; + uint32_t w_idx = rhs_indices[batch_idx]; - // Offset x using gathered index. - if (x_batch_ndims == 1) { - x += x_idx * x_batch_strides[0]; - } else { - x += elem_to_loc( - x_idx, x_batch_shape.data(), x_batch_strides.data(), x_batch_ndims); - } - - // Offset w/scales/biases using gathered index. - if (w_batch_ndims == 1) { - w += w_idx * w_batch_strides[0]; - scales += w_idx * s_batch_strides[0]; - if constexpr (has_bias) { - biases += w_idx * b_batch_strides[0]; - } - } else { - if constexpr (has_bias) { - auto [w_off, s_off, b_off] = elem_to_loc( - w_idx, - w_batch_shape.data(), - w_batch_strides.data(), - s_batch_strides.data(), - b_batch_strides.data(), - w_batch_ndims); - w += w_off; - scales += s_off; - biases += b_off; - } else { - auto [w_off, s_off] = elem_to_loc( - w_idx, - w_batch_shape.data(), - w_batch_strides.data(), - s_batch_strides.data(), - w_batch_ndims); - w += w_off; - scales += s_off; - } + // Offset pointers using gathered indices. + x += x_idx * x_batch_stride; + w += w_idx * w_batch_stride; + scales += w_idx * s_batch_stride; + if constexpr (has_bias) { + biases += w_idx * b_batch_stride; } // Offset output for this batch element. @@ -597,9 +543,12 @@ void gather_qmv( int k = x.shape(-1); int B = out.size() / (m * n); - // Collapse contiguous dims for index arrays. - auto [idx_shape, idx_strides] = collapse_contiguous_dims( - lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + // Batch strides for contiguous inputs. + int64_t x_batch_stride = x.strides()[0]; + int64_t w_batch_stride = w.strides()[0]; + int64_t s_batch_stride = scales.strides()[0]; + int64_t b_batch_stride = + biases ? biases->strides()[0] : static_cast(0); dispatch_element_types(out.dtype(), tag, [&]() { dispatch_quant_types( @@ -639,23 +588,6 @@ void gather_qmv( auto li_ptr = gpu_ptr(lhs_indices); auto ri_ptr = gpu_ptr(rhs_indices); - int x_batch_ndims = x.ndim() - 2; - auto x_shape_p = const_param(x.shape()); - auto x_strides_p = const_param(x.strides()); - int w_batch_ndims = w.ndim() - 2; - auto w_shape_p = const_param(w.shape()); - auto w_strides_p = const_param(w.strides()); - auto s_strides_p = const_param(scales.strides()); - auto b_strides_p = biases - ? const_param(biases->strides()) - : decltype(s_strides_p){}; - int index_ndims = idx_shape.size(); - auto idx_shape_p = const_param(idx_shape); - auto lhs_idx_strides_p = - const_param(idx_strides[0]); - auto rhs_idx_strides_p = - const_param(idx_strides[1]); - void* args[] = { &x_ptr, &w_ptr, @@ -666,18 +598,10 @@ void gather_qmv( &ri_ptr, &n, &k, - &x_batch_ndims, - &x_shape_p, - &x_strides_p, - &w_batch_ndims, - &w_shape_p, - &w_strides_p, - &s_strides_p, - &b_strides_p, - &index_ndims, - &idx_shape_p, - &lhs_idx_strides_p, - &rhs_idx_strides_p, + &x_batch_stride, + &w_batch_stride, + &s_batch_stride, + &b_batch_stride, &output_stride}; dispatch_bool( diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 84e36e421c..f952d3421b 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -137,39 +137,22 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { mode_, encoder.device()); }; - bool can_use_fp_qmv = supports(supports_fp_qmv); - bool can_use_qmv = supports(supports_qmv) || can_use_fp_qmv; + bool can_use_qmv = supports(supports_qmv) || supports(supports_fp_qmv); auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); - if (can_use_fp_qmv) { - // TODO: Add gather_fp_qmv for FP-mode-specific optimizations. - gather_qmv( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - out, - bits_, - group_size_, - mode_, - encoder); - } else { - gather_qmv( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - out, - bits_, - group_size_, - mode_, - encoder); - } + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits_, + group_size_, + mode_, + encoder); }; if (can_use_qmv) { From 5426153807242b08a2af183a9840e590b50a90e3 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 26 Mar 2026 12:23:48 +0800 Subject: [PATCH 3/4] fix: ensure gather indices are contiguous before kernel dispatch Broadcast index arrays (e.g., from default lhs_indices) may have stride-0 dimensions backed by fewer elements than the logical shape. The gather_qmv kernel reads indices linearly after collapse, causing out-of-bounds access. Fix by copying to contiguous layout first. --- mlx/backend/cuda/quantized/quantized.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index f952d3421b..d2f0e95b33 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -116,8 +116,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { if (inputs.size() == 6) { biases = inputs[3]; } - const array& lhs_indices = inputs[inputs.size() - 2]; - const array& rhs_indices = inputs[inputs.size() - 1]; + array lhs_indices = ensure_contiguous(inputs[inputs.size() - 2], encoder, s); + array rhs_indices = ensure_contiguous(inputs[inputs.size() - 1], encoder, s); int M = out.shape(-2); int N = out.shape(-1); From f83932d60ff63bc03c96447d7cd738b43483f4f5 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 26 Mar 2026 19:16:01 +0800 Subject: [PATCH 4/4] apply pr comments --- mlx/backend/cuda/quantized/qmm/qmv.cu | 186 ++++++++++------------- mlx/backend/cuda/quantized/quantized.cpp | 2 +- 2 files changed, 79 insertions(+), 109 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index b0ac2ddf64..c00deaef18 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -189,13 +189,14 @@ template < typename T, typename Q, typename S> -__device__ __forceinline__ void qmv_impl( +__device__ __forceinline__ void qmv_kernel_impl( const T* x, const Q* w, const S* scales, const T* biases, T* out, int row, + int w_batch, int n, int k) { auto warp = cg::tiled_partition(cg::this_thread_block()); @@ -209,10 +210,10 @@ __device__ __forceinline__ void qmv_impl( int groups_per_row = k / group_size; // Advance w/scales/biases to current row. - w += static_cast(row) * w_step(k); - scales += static_cast(row) * groups_per_row; + w += (static_cast(row) + n * w_batch) * w_step(k); + scales += (static_cast(row) + n * w_batch) * groups_per_row; if constexpr (has_bias) { - biases += static_cast(row) * groups_per_row; + biases += (static_cast(row) + n * w_batch) * groups_per_row; } // Accumulations of current row. @@ -291,21 +292,10 @@ __global__ void qmv_kernel( int l = block.group_index().z; x += block.group_index().y * k + m * k * l; out += block.group_index().y * n + m * n * l; - - // Advance w/scales/biases for batch dimension. - constexpr int bits = cute::sizeof_bits_v; - auto w_step = [&](int idx) { return idx * cuda::std::min(8, bits) / 8; }; - int groups_per_row = k / group_size; int w_batch = broadcast_w ? 0 : l; - w += static_cast(n) * w_batch * w_step(k); - scales += static_cast(n) * w_batch * groups_per_row; - if constexpr (has_bias) { - biases += static_cast(n) * w_batch * groups_per_row; - } - // Row-level compute: dequantize, FMA, reduce, write. - qmv_impl( - x, w, scales, biases, out, row, n, k); + qmv_kernel_impl( + x, w, scales, biases, out, row, w_batch, n, k); } template < @@ -319,51 +309,33 @@ template < typename S> __global__ void gather_qmv_kernel( const T* x, - const uint32_t* w, + const Q* w, const S* scales, const T* biases, T* out, const uint32_t* lhs_indices, const uint32_t* rhs_indices, int n, - int k, - int64_t x_batch_stride, - int64_t w_batch_stride, - int64_t s_batch_stride, - int64_t b_batch_stride, - int output_stride) { + int k) { + auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // The row that this warp handles. int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); if (row >= n) { return; } - // Gather: look up batch indices. - uint32_t batch_idx = block.group_index().z; - uint32_t x_idx = lhs_indices[batch_idx]; - uint32_t w_idx = rhs_indices[batch_idx]; - - // Offset pointers using gathered indices. - x += x_idx * x_batch_stride; - w += w_idx * w_batch_stride; - scales += w_idx * s_batch_stride; - if constexpr (has_bias) { - biases += w_idx * b_batch_stride; - } - - // Offset output for this batch element. - out += batch_idx * output_stride; + int m = grid.dim_blocks().y; + int l = block.group_index().z; + uint32_t x_idx = lhs_indices[l]; + uint32_t w_idx = rhs_indices[l]; - // Advance pointers for M dimension (block.group_index().y). - x += block.group_index().y * k; - out += block.group_index().y * n; + x += block.group_index().y * k + m * k * x_idx; + out += block.group_index().y * n + m * n * l; - // Reinterpret w as Q* for sub-byte access, then run shared compute. - qmv_impl( - x, reinterpret_cast(w), scales, biases, out, row, n, k); + qmv_kernel_impl( + x, w, scales, biases, out, row, w_idx, n, k); } template < @@ -409,6 +381,51 @@ void qmv( }); } +template < + int group_size, + bool has_bias, + typename T, + typename Q, + typename S, + typename F> +void gather_qmv( + const T* x, + const Q* w, + const S* scales, + const T* biases, + T* out, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + int m, + int n, + int k, + int l, + F&& launch_kernel) { + constexpr int rows_per_block = 8; + constexpr int elems_per_thread = + (cute::sizeof_bits_v <= 16 && cute::sizeof_bits_v <= 4) ? 16 : 8; + + dim3 num_blocks{ + uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)}; + dim3 block_dims{WARP_SIZE, rows_per_block}; + void* args[] = { + &x, &w, &scales, &biases, &out, &lhs_indices, &rhs_indices, &n, &k}; + + dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) { + auto* kernel = &gather_qmv_kernel< + rows_per_block, + elems_per_thread, + group_size, + has_bias, + has_residue_k.value, + T, + Q, + S>; + launch_kernel( + reinterpret_cast(kernel), num_blocks, block_dims, args); + }); +} + } // namespace cu template @@ -541,14 +558,7 @@ void gather_qmv( int m = out.shape(-2); int n = out.shape(-1); int k = x.shape(-1); - int B = out.size() / (m * n); - - // Batch strides for contiguous inputs. - int64_t x_batch_stride = x.strides()[0]; - int64_t w_batch_stride = w.strides()[0]; - int64_t s_batch_stride = scales.strides()[0]; - int64_t b_batch_stride = - biases ? biases->strides()[0] : static_cast(0); + int l = out.size() / (m * n); dispatch_element_types(out.dtype(), tag, [&]() { dispatch_quant_types( @@ -566,62 +576,22 @@ void gather_qmv( encoder.set_input_array(lhs_indices); encoder.set_input_array(rhs_indices); encoder.set_output_array(out); - constexpr bool has_bias = !cutlass::has_negative_zero_v; - constexpr int rows_per_block = 8; - constexpr int elems_per_thread = - (cute::sizeof_bits_v <= 16 && cute::sizeof_bits_v <= 4) ? 16 - : 8; - - dim3 num_blocks{ - uint32_t(cuda::ceil_div(n, rows_per_block)), - uint32_t(m), - uint32_t(B)}; - dim3 block_dims{WARP_SIZE, rows_per_block}; - int output_stride = m * n; - - auto x_ptr = gpu_ptr(x); - auto w_ptr = gpu_ptr(w); - auto s_ptr = gpu_ptr(scales); - auto b_ptr = biases ? gpu_ptr(*biases) : (const T*)nullptr; - auto o_ptr = gpu_ptr(out); - auto li_ptr = gpu_ptr(lhs_indices); - auto ri_ptr = gpu_ptr(rhs_indices); - - void* args[] = { - &x_ptr, - &w_ptr, - &s_ptr, - &b_ptr, - &o_ptr, - &li_ptr, - &ri_ptr, - &n, - &k, - &x_batch_stride, - &w_batch_stride, - &s_batch_stride, - &b_batch_stride, - &output_stride}; - - dispatch_bool( - k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) { - auto* kernel = &cu::gather_qmv_kernel< - rows_per_block, - elems_per_thread, - group_size, - has_bias, - has_residue_k.value, - T, - Q, - S>; + cu::gather_qmv( + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + biases ? gpu_ptr(*biases) : nullptr, + gpu_ptr(out), + gpu_ptr(lhs_indices), + gpu_ptr(rhs_indices), + m, + n, + k, + l, + [&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) { encoder.add_kernel_node_raw( - reinterpret_cast(kernel), - num_blocks, - block_dims, - {}, - 0, - args); + kernel, num_blocks, block_dims, {}, 0, args); }); }); }); diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index d2f0e95b33..01f5ba7643 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -137,7 +137,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { mode_, encoder.device()); }; - bool can_use_qmv = supports(supports_qmv) || supports(supports_fp_qmv); + bool can_use_qmv = supports(supports_qmv); auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder));