diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 79b455cf57..94f260767f 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -24,7 +24,6 @@ namespace mlx::core { throw std::runtime_error(#func " has no CUDA implementation."); \ } -NO_GPU(GatherQMM) NO_GPU_MULTI(LUF) NO_GPU_MULTI(QRF) NO_GPU_MULTI(SVD) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index c96e4f28bc..82e159d8c3 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -100,4 +100,17 @@ void qmv( QuantizationMode mode, cu::CommandEncoder& encoder); +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder); + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index c43e783b4e..c00deaef18 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -182,7 +182,6 @@ dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) { } template < - int rows_per_block, int elems_per_thread, int group_size, bool has_bias, @@ -190,30 +189,17 @@ template < typename T, typename Q, typename S> -__global__ void qmv_kernel( +__device__ __forceinline__ void qmv_kernel_impl( const T* x, const Q* w, const S* scales, const T* biases, T* out, + int row, + int w_batch, int n, - int k, - bool broadcast_w) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - // The row that this warp handles. - int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); - if (row >= n) { - return; - } - - // Advance pointers of x/out. - int m = grid.dim_blocks().y; - int l = block.group_index().z; - x += block.group_index().y * k + m * k * l; - out += block.group_index().y * n + m * n * l; + int k) { + auto warp = cg::tiled_partition(cg::this_thread_block()); // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would // move past 2 elements for 4-bit Q. @@ -224,7 +210,6 @@ __global__ void qmv_kernel( int groups_per_row = k / group_size; // Advance w/scales/biases to current row. - int w_batch = broadcast_w ? 0 : l; w += (static_cast(row) + n * w_batch) * w_step(k); scales += (static_cast(row) + n * w_batch) * groups_per_row; if constexpr (has_bias) { @@ -274,6 +259,85 @@ __global__ void qmv_kernel( } } +template < + int rows_per_block, + int elems_per_thread, + int group_size, + bool has_bias, + bool has_residue_k, + typename T, + typename Q, + typename S> +__global__ void qmv_kernel( + const T* x, + const Q* w, + const S* scales, + const T* biases, + T* out, + int n, + int k, + bool broadcast_w) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // The row that this warp handles. + int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); + if (row >= n) { + return; + } + + // Advance pointers of x/out for M and batch dimensions. + int m = grid.dim_blocks().y; + int l = block.group_index().z; + x += block.group_index().y * k + m * k * l; + out += block.group_index().y * n + m * n * l; + int w_batch = broadcast_w ? 0 : l; + + qmv_kernel_impl( + x, w, scales, biases, out, row, w_batch, n, k); +} + +template < + int rows_per_block, + int elems_per_thread, + int group_size, + bool has_bias, + bool has_residue_k, + typename T, + typename Q, + typename S> +__global__ void gather_qmv_kernel( + const T* x, + const Q* w, + const S* scales, + const T* biases, + T* out, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + int n, + int k) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + int row = block.group_index().x * rows_per_block + warp.meta_group_rank(); + if (row >= n) { + return; + } + + int m = grid.dim_blocks().y; + int l = block.group_index().z; + uint32_t x_idx = lhs_indices[l]; + uint32_t w_idx = rhs_indices[l]; + + x += block.group_index().y * k + m * k * x_idx; + out += block.group_index().y * n + m * n * l; + + qmv_kernel_impl( + x, w, scales, biases, out, row, w_idx, n, k); +} + template < int group_size, bool has_bias, @@ -317,6 +381,51 @@ void qmv( }); } +template < + int group_size, + bool has_bias, + typename T, + typename Q, + typename S, + typename F> +void gather_qmv( + const T* x, + const Q* w, + const S* scales, + const T* biases, + T* out, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + int m, + int n, + int k, + int l, + F&& launch_kernel) { + constexpr int rows_per_block = 8; + constexpr int elems_per_thread = + (cute::sizeof_bits_v <= 16 && cute::sizeof_bits_v <= 4) ? 16 : 8; + + dim3 num_blocks{ + uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)}; + dim3 block_dims{WARP_SIZE, rows_per_block}; + void* args[] = { + &x, &w, &scales, &biases, &out, &lhs_indices, &rhs_indices, &n, &k}; + + dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) { + auto* kernel = &gather_qmv_kernel< + rows_per_block, + elems_per_thread, + group_size, + has_bias, + has_residue_k.value, + T, + Q, + S>; + launch_kernel( + reinterpret_cast(kernel), num_blocks, block_dims, args); + }); +} + } // namespace cu template @@ -433,4 +542,59 @@ void qmv( }); } +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder) { + const char* tag = "[gather_qmm]"; + int m = out.shape(-2); + int n = out.shape(-1); + int k = x.shape(-1); + int l = out.size() / (m * n); + + dispatch_element_types(out.dtype(), tag, [&]() { + dispatch_quant_types( + bits, + group_size, + mode, + tag, + [&]() { + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + if (biases) { + encoder.set_input_array(*biases); + } + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + constexpr bool has_bias = !cutlass::has_negative_zero_v; + cu::gather_qmv( + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + biases ? gpu_ptr(*biases) : nullptr, + gpu_ptr(out), + gpu_ptr(lhs_indices), + gpu_ptr(rhs_indices), + m, + n, + k, + l, + [&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) { + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, 0, args); + }); + }); + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index d7252ec196..01f5ba7643 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -104,6 +104,78 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { quantization_mode_to_string(mode_))); } +void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("GatherQMM::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const array& x = inputs[0]; + const array& w = inputs[1]; + const array& scales = inputs[2]; + std::optional biases; + if (inputs.size() == 6) { + biases = inputs[3]; + } + array lhs_indices = ensure_contiguous(inputs[inputs.size() - 2], encoder, s); + array rhs_indices = ensure_contiguous(inputs[inputs.size() - 1], encoder, s); + + int M = out.shape(-2); + int N = out.shape(-1); + int K = x.shape(-1); + int B = out.size() / (M * N); + + auto supports = [&](auto&& f) { + return f( + x, + w, + scales, + biases, + out, + transpose_, + bits_, + group_size_, + mode_, + encoder.device()); + }; + bool can_use_qmv = supports(supports_qmv); + + auto call_qmv = [&]() { + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits_, + group_size_, + mode_, + encoder); + }; + + if (can_use_qmv) { + call_qmv(); + return; + } + + throw std::runtime_error( + fmt::format( + "[gather_qmm] No implementation for " + "problem shape: {}x{}x{}x{}, transpose: {}, " + "activation: {}, bits: {}, group size: {}, mode: \"{}\".", + M, + N, + K, + B, + transpose_, + dtype_to_string(x.dtype()), + bits_, + group_size_, + quantization_mode_to_string(mode_))); +} + void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) {