diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu
index 6d5fa188f4..c8cb6f5800 100644
--- a/mlx/backend/cuda/quantized/qmm/qmm.cu
+++ b/mlx/backend/cuda/quantized/qmm/qmm.cu
@@ -103,10 +103,11 @@ void qmm_impl_sm80(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     array& out,
     int bits,
     int group_size,
+    QuantizationMode mode,
     cu::CommandEncoder& encoder);
 
 bool supports_qmm_sm80(
@@ -128,11 +129,11 @@ bool supports_qmm_sm80(
   if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) {
     return false;
   }
-  if (!biases) {
+  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
+      !scales.flags().row_contiguous) {
     return false;
   }
-  if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
-      !scales.flags().row_contiguous || !biases->flags().row_contiguous) {
+  if (biases && !biases->flags().row_contiguous) {
     return false;
   }
   if (x.dtype() != float16 && x.dtype() != bfloat16) {
@@ -141,10 +142,7 @@ bool supports_qmm_sm80(
   if (!transpose) {
     return false;
   }
-  if (bits != 8) {
-    return false;
-  }
-  if (mode != QuantizationMode::Affine) {
+  if (bits != 4 && bits != 8) {
     return false;
   }
   return true;
@@ -154,13 +152,15 @@ void qmm_sm80(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     array& out,
     int bits,
     int group_size,
+    QuantizationMode mode,
     cu::CommandEncoder& encoder) {
   auto dispatch = [&]() {
-    qmm_impl_sm80(x, w, scales, biases, out, bits, group_size, encoder);
+    qmm_impl_sm80(
+        x, w, scales, biases, out, bits, group_size, mode, encoder);
   };
   int m = out.shape(-2);
   if (m <= 16) {
diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h
index f729abb86f..c96e4f28bc 100644
--- a/mlx/backend/cuda/quantized/qmm/qmm.h
+++ b/mlx/backend/cuda/quantized/qmm/qmm.h
@@ -48,10 +48,11 @@ void qmm_sm80(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     array& out,
     int bits,
     int group_size,
+    QuantizationMode mode,
     cu::CommandEncoder& encoder);
 
 bool supports_fp_qmv(
diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
index 46bfe80667..6442110bc8 100644
--- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
+++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
@@ -13,6 +13,9 @@ namespace cutlass_gemm {
 
 using namespace cute;
 
+template [?? template parameter list stripped during extraction ??]
+constexpr bool has_zero_point_v = !cutlass::has_negative_zero_v;
+
 template [?? a stripped angle-bracketed span swallowed the remainder of this hunk and the header/signature of the dequantize hunk; surviving fragment: ??]*>(raw_pointer_cast(w.data())));
-  Element scale = s[0];
-  Element zero_point = z[0];
+  Element scale{s[0]};
   cutlass::NumericArrayConverter converter;
-  auto w_dq = converter(w_vec) * scale + zero_point;
+  auto w_dq = converter(w_vec) * scale;
+  if constexpr (has_zero_point_v) {
+    Element zero_point{z[0]};
+    w_dq = w_dq + zero_point;
+  }
   copy(make_tensor(make_rmem_ptr(&w_dq), out.layout()), out);
 }
 
 template [?? a stripped angle-bracketed span swallowed the qmm_sm80_kernel hunk header, template parameter list, and signature; surviving fragment: ??](blockIdx);
 
   // Represent the full tensors.
-  Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
-  Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
-  Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)
+  Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
+  Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
+  Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)
   Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)
   Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)
@@ -189,7 +195,9 @@ __global__ void qmm_sm80_kernel(
   // Copy S/Z: GMEM => RMEM.
   auto fetch_scales = [&](int tile) {
     copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS);
-    copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ);
+    if constexpr (has_zero_point_v) {
+      copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ);
+    }
   };
   // Copy A/B: SMEM => RMEM.
   auto fetch_smem = [&](auto block) {
@@ -284,11 +292,11 @@ inline auto make_tiled_copy(NumThreads num_threads) {
       make_layout(make_shape(Int<1>{}, Int>{})));
 }
 
-template
+template
 void qmm_sm80(
     const Element* A,
     const Quant* B,
-    const Element* S,
+    const Scale* S,
     const Element* Z,
     Element* C,
     int m, int n, int k, int l,
@@ -345,11 +353,11 @@ void qmm_sm80(
   Copy_Atom s2r_atom_a;
   Copy_Atom>, Quant> s2r_atom_b;
   Copy_Atom>, Element> r2s_atom_c;
-  Copy_Atom, Element> g2r_atom_s;
+  Copy_Atom, Scale> g2r_atom_s;
   auto* kernel = &qmm_sm80_kernel<
       decltype(prob_shape), decltype(cta_tiler),
-      Element, Quant,
+      Element, Quant, Scale,
       decltype(dA), decltype(sA_layout), decltype(g2s_copy_a), decltype(s2r_atom_a),
       decltype(dB), decltype(sB_layout), decltype(g2s_copy_b), decltype(s2r_atom_b),
       decltype(dC), decltype(sC_layout), decltype(s2g_copy_c), decltype(r2s_atom_c),
@@ -392,18 +400,6 @@ inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
   }
 }
 
-template
-inline void dispatch_quant_types(int bits, const char* tag, F&& f) {
-  if (bits == 4) {
-    f.template operator()();
-  } else if (bits == 8) {
-    f.template operator()();
-  } else {
-    throw std::invalid_argument(
-        fmt::format("{} {}-bit quantization is not supported.", tag, bits));
-  }
-}
-
 template
 inline void dispatch_groups(int group_size, const char* tag, F&& f) {
   if (group_size == 32) {
@@ -418,15 +414,43 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
[?? leading context lines of this hunk are not recoverable from the garbled text ??]
+template
+inline void dispatch_quant_types(
+    int bits,
+    int group_size,
+    QuantizationMode mode,
+    const char* tag,
+    F&& f) {
+  if (mode == QuantizationMode::Mxfp4) {
+    f.template operator()();
+  } else if (mode == QuantizationMode::Mxfp8) {
+    f.template operator()();
+  } else if (mode == QuantizationMode::Nvfp4) {
+    f.template operator()();
+  } else {
+    dispatch_groups(group_size, tag, [&]() {
+      if (bits == 4) {
+        f.template operator()();
+      } else if (bits == 8) {
+        f.template operator()();
+      } else {
+        throw std::invalid_argument(
+            fmt::format("{} {}-bit quantization is not supported.", tag, bits));
+      }
+    });
+  }
+}
+
 template
 void qmm_impl_sm80(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     array& out,
     int bits,
     int group_size,
+    QuantizationMode mode,
     cu::CommandEncoder& encoder) {
   const char* tag = "[quantized_matmul]";
   int m = out.shape(-2);
@@ -435,48 +459,54 @@ void qmm_impl_sm80(
   int l = out.size() / (m * n);
 
   dispatch_element_types(out.dtype(), tag, [&]() {
-    dispatch_quant_types(bits, tag, [&]() {
-      dispatch_groups(group_size, tag, [&]() {
-        encoder.set_input_array(x);
-        encoder.set_input_array(w);
-        encoder.set_input_array(scales);
-        encoder.set_input_array(biases);
-        encoder.set_output_array(out);
-        cutlass_gemm::qmm_sm80(
-            gpu_ptr(x),
-            gpu_ptr(w),
-            gpu_ptr(scales),
-            gpu_ptr(biases),
-            gpu_ptr(out),
-            m,
-            n,
-            k,
-            l,
-            cute::Int{},
-            [&](auto* kernel,
-                dim3 num_blocks,
-                dim3 block_dims,
-                uint32_t smem_bytes,
-                void** args) {
-              encoder.add_kernel_node_raw(
-                  kernel, num_blocks, block_dims, {}, smem_bytes, args);
-            });
-      });
-    });
+    dispatch_quant_types(
+        bits,
+        group_size,
+        mode,
+        tag,
+        [&]() {
+          encoder.set_input_array(x);
+          encoder.set_input_array(w);
+          encoder.set_input_array(scales);
+          if (biases) {
+            encoder.set_input_array(*biases);
+          }
+          encoder.set_output_array(out);
+          cutlass_gemm::qmm_sm80(
+              gpu_ptr(x),
+              gpu_ptr(w),
+              gpu_ptr(scales),
+              biases ? gpu_ptr(*biases) : nullptr,
+              gpu_ptr(out),
+              m,
+              n,
+              k,
+              l,
+              cute::Int{},
+              [&](auto* kernel,
+                  dim3 num_blocks,
+                  dim3 block_dims,
+                  uint32_t smem_bytes,
+                  void** args) {
+                encoder.add_kernel_node_raw(
+                    kernel, num_blocks, block_dims, {}, smem_bytes, args);
+              });
+        });
   });
 }
 
 } // namespace mlx::core
 
-#define QMM_SM80_GPU(TileM) \
-  namespace mlx::core { \
-  template void qmm_impl_sm80( \
-      const array& x, \
-      const array& w, \
-      const array& scales, \
-      const array& biases, \
-      array& out, \
-      int bits, \
-      int group_size, \
-      cu::CommandEncoder& encoder); \
+#define QMM_SM80_GPU(TileM) \
+  namespace mlx::core { \
+  template void qmm_impl_sm80( \
+      const array& x, \
+      const array& w, \
+      const array& scales, \
+      const std::optional<array>& biases, \
+      array& out, \
+      int bits, \
+      int group_size, \
+      QuantizationMode mode, \
+      cu::CommandEncoder& encoder); \
   }
diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp
index 03820511e4..d7252ec196 100644
--- a/mlx/backend/cuda/quantized/quantized.cpp
+++ b/mlx/backend/cuda/quantized/quantized.cpp
@@ -49,7 +49,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   };
   auto call_qmm_sm80 = [&]() {
     out.set_data(cu::malloc_async(out.nbytes(), encoder));
-    qmm_sm80(x, w, scales, *biases, out, bits_, group_size_, encoder);
+    qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder);
   };
   auto call_qmv = [&]() {
     out.set_data(cu::malloc_async(out.nbytes(), encoder));