diff --git a/.github/actions/setup-linux/action.yml b/.github/actions/setup-linux/action.yml index 08bb528f57..6892ada17e 100644 --- a/.github/actions/setup-linux/action.yml +++ b/.github/actions/setup-linux/action.yml @@ -54,6 +54,12 @@ runs: echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV echo "::endgroup::" + - name: Set swap space + if: ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} + uses: pierotofy/set-swap-space@fc79b3f67fa8a838184ce84a674ca12238d2c761 + with: + swap-size-gb: 16 + - name: Install CUDA toolkit if: ${{ startsWith(inputs.toolkit, 'cuda') }} shell: bash diff --git a/mlx/backend/cuda/quantized/qmm/CMakeLists.txt b/mlx/backend/cuda/quantized/qmm/CMakeLists.txt index 9e057866de..3b88403e84 100644 --- a/mlx/backend/cuda/quantized/qmm/CMakeLists.txt +++ b/mlx/backend/cuda/quantized/qmm/CMakeLists.txt @@ -3,6 +3,12 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu ${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m16_k.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m16_n.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m32_k.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m32_n.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m64_k.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m64_n.cu ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m16.cu ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m32.cu ${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m64.cu diff --git a/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh b/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh new file mode 100644 index 0000000000..e7f8dd30cf --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh @@ -0,0 +1,40 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include +#include + +namespace cutlass_gemm { + +// Whether the quant type is affine quantization. +template +constexpr bool quant_has_bias_v = !cutlass::has_negative_zero_v; + +// Dequantize CuTe tensors with out = w * s + z. +__device__ __forceinline__ void +cute_vectorized_dequant(auto w, auto s, auto z, auto out) { + using namespace cute; + using Element = typename decltype(out)::value_type; + using Quant = typename decltype(w)::value_type; + // Scale must be one element. + CUTE_STATIC_ASSERT_V(cosize(s.layout()) == Int<1>{}); + CUTE_STATIC_ASSERT_V(cosize(z.layout()) == Int<1>{}); + // Quant must be contiguous. + auto layout = coalesce(w.layout()); + CUTE_STATIC_ASSERT_V(stride(layout) == Int<1>{}); + // Use cutlass for conversions. + constexpr int N = size(layout); + auto& w_vec = *(reinterpret_cast*>( + raw_pointer_cast(w.data()))); + Element scale{s[0]}; + cutlass::NumericArrayConverter converter; + auto w_dq = converter(w_vec) * scale; + if constexpr (quant_has_bias_v) { + Element zero_point{z[0]}; + w_dq = w_dq + zero_point; + } + copy(make_tensor(make_rmem_ptr(&w_dq), out.layout()), out); +} + +} // namespace cutlass_gemm diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index c8cb6f5800..97a8f7f422 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -1,5 +1,6 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" #include @@ -172,6 +173,74 @@ void qmm_sm80( } } +// Defined in qmm_impl_naive_xxx.cu files. +template +void qmm_impl_naive( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder); + +bool supports_qmm_naive( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::Device& device) { + int k = x.shape(-1); + if (k % std::max(64, group_size) != 0) { + return false; + } + if (!x.flags().row_contiguous || !w.flags().row_contiguous || + !scales.flags().row_contiguous) { + return false; + } + if (biases && !biases->flags().row_contiguous) { + return false; + } + if (bits != 2 && bits != 4 && bits != 8) { + return false; + } + return true; +} + +void qmm_naive( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder) { + auto dispatch = [&]() { + qmm_impl_naive( + x, w, scales, biases, out, bits, group_size, mode, encoder); + }; + dispatch_bool(transpose, [&](auto k_major) { + int m = out.shape(-2); + if (m <= 16) { + dispatch.template operator()<16, k_major.value>(); + } else if (m <= 32) { + dispatch.template operator()<32, k_major.value>(); + } else { + dispatch.template operator()<64, k_major.value>(); + } + }); +} + bool supports_fp_qmv( const array& x, const array& w, diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index c96e4f28bc..a7464dee69 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -55,6 +55,30 @@ void qmm_sm80( QuantizationMode mode, cu::CommandEncoder& encoder); +bool supports_qmm_naive( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::Device& device); + +void qmm_naive( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + bool transpose, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder); + bool supports_fp_qmv( const array& x, const array& w, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh new file mode 100644 index 0000000000..fb187aa92d --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -0,0 +1,473 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" +#include "mlx/backend/cuda/quantized/qmm/qmm.h" +#include "mlx/dtype_utils.h" + +// clang-format off + +// We can't put kernel code in mlx::core due to name conflicts of "Shape". +namespace cutlass_gemm { + +using namespace cute; + +template +struct SharedStorage { + ArrayEngine> A; + ArrayEngine> B; +}; + +__device__ __forceinline__ void +cute_naive_dequant(auto w, auto s, auto z, auto out) { + using Element = typename decltype(out)::value_type; + using Quant = typename decltype(w)::value_type; + using Scale = typename decltype(s)::value_type; + transform(w, out, [](Quant q) { return Element(q); } ); + transform(out, s, out, [](Element e, Scale s) { return e * Element(s); }); + if constexpr (quant_has_bias_v) { + transform(out, z, out, plus{}); + } +} + +__device__ __forceinline__ void +cute_dequant(auto w, auto s, auto z, auto out) { + if constexpr (stride(coalesce(w.layout())) == Int<1>{} && + is_static_v) { + cute_vectorized_dequant(w, s, z, out); + } else { + cute_naive_dequant(w, s, z, out); + } +} + +template +__global__ void qmm_naive_kernel( + ProblemShape shape_MNKL, CtaTiler cta_tiler, + const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA copy_a, + const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB copy_b, + Element* C, StrideC dC, + const Scale* S, const Element* Z, LayoutS S_layout, + TiledMma mma) { + CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); + CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); + CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); + CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB)); + CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC)); + + int thread_idx = int(threadIdx.x); + auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); + + // Represent the full tensors. + Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) + Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) + Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L) + + Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) + Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) + + // Get batch slice. + Tensor mA = mA_mkl(_,_,l_coord); // (M,K) + Tensor mB = mB_nkl(_,_,l_coord); // (N,K) + Tensor mC = mC_mnl(_,_,l_coord); // (M,N) + + Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) + Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) + + // Get the appropriate blocks for this thread block. + auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord + auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * n_coord; // N - BLK_N * n_coord + + // Shared memory buffers. + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) + + // Partition the copying of A/B/C tiles across the threads. + ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); + Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + + ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); + Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) + Tensor tBrB_dq = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) + Tensor tBgS = thr_copy_b.partition_S(gS); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBrS = make_fragment_like(tBgS(_,_,_,0)); // (BCPY,BCPY_N,BCPY_K) + Tensor tBgZ = thr_copy_b.partition_S(gZ); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBrZ = make_fragment_like(tBgZ(_,_,_,0)); // (BCPY,BCPY_N,BCPY_K) + + // MMA. + ThrMMA thr_mma = mma.get_slice(thread_idx); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + // Predicates for m/n bounds. + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); // (CPY_N,CPY_K) + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) + Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); // (M,N) + Tensor tAcA = thr_copy_a.partition_S(cA); // (CPY,CPY_M,CPY_K) + Tensor tBcB = thr_copy_b.partition_S(cB); // (CPY,CPY_N,CPY_K) + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) + CUTE_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; + } + CUTE_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < n_max_coord; + } + + // GMEM => RMEM. + auto fetch_gmem = [&](int tile) { + copy_if(copy_a, tApA, tAgA(_,_,_,tile), tArA); + copy_if(copy_b, tBpB, tBgB(_,_,_,tile), tBrB); + copy(tBgS(_,_,_,tile), tBrS); + copy(tBgZ(_,_,_,tile), tBrZ); + }; + // RMEM => SMEM. + auto store_smem = [&]() { + __syncthreads(); + copy(tArA, tAsA); + CUTE_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + CUTE_UNROLL + for (int n = 0; n < size<1>(tBrB); ++n) { + cute_dequant(tBrB(_,n,k), tBrS(_,n,k), tBrZ(_,n,k), tBrB_dq(_,n,k)); + } + } + copy(tBrB_dq, tBsB); + __syncthreads(); + }; + + // Prefetch first tile. + fetch_gmem(0); + + // Clear accumulators. + clear(tCrC); + + // Loop over CTA tiles. + auto K_TILE_MAX = size<3>(tAgA); + for (int tile = 0; tile < K_TILE_MAX; ++tile) { + store_smem(); + fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + gemm(mma, tCsA, tCsB, tCrC); + } + + // Epilogue. + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) { + if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { + tCgC(i) = Element(tCrC(i)); + } + } +} + +template +inline constexpr auto make_matrix_stride(auto m, auto k) { + if constexpr (KMajor) { + return cute::make_stride(k, cute::Int<1>{}, m * k); + } else { + return cute::make_stride(cute::Int<1>{}, m, m * k); + } +} + +template +inline constexpr auto make_smem_layout(auto bM, auto bK) { + // TODO: Calculate swizzle based on tile shape. + if constexpr (KMajor) { + auto swizzle = composition(Swizzle<3,3,3>{}, + Layout>, + Stride<_8,Stride<_1,_64>>>{}); + return tile_to_shape(swizzle, make_shape(bM, bK)); + } else { + auto swizzle = composition(Swizzle<3,3,3>{}, + Layout, Stride<_1,_64>>{}); + return tile_to_shape(swizzle, make_shape(bM, bK)); + } +} + +template +inline constexpr auto make_tiled_mma() { + using Atom = std::conditional_t< + SM80, + std::conditional_t< + std::is_same_v, + SM80_16x8x16_F32F16F16F32_TN, + std::conditional_t< + std::is_same_v, + SM80_16x8x16_F32BF16BF16F32_TN, + UniversalFMA + > + >, + UniversalFMA>; + if constexpr (!SM80 || std::is_same_v) { + return make_tiled_mma(Atom{}, Layout>{}); + } else { + if constexpr (TileM >= 32) { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + } else { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + } + } +} + +template +inline auto make_tiled_copy(auto num_threads, auto bM, auto bK) { + auto n_read = Int<8>{}; + auto atom = Copy_Atom>>, T>{}; + if constexpr (KMajor) { + auto k_threads = bK / n_read; + return make_tiled_copy( + atom, + make_layout(make_shape(Int{}, k_threads), LayoutRight{}), + make_layout(make_shape(Int<1>{}, n_read))); + } else { + auto m_threads = bM / n_read; + return make_tiled_copy( + atom, + make_layout(make_shape(m_threads, Int{}), LayoutLeft{}), + make_layout(make_shape(n_read, Int<1>{}))); + } +} + +template +inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size) { + if constexpr (KMajor) { + return make_layout( + make_shape(n, make_shape(group_size, k / group_size), l), + make_stride(k / group_size, Stride<_0,_1>{}, n * k / group_size)); + } else { + return make_layout( + make_shape(make_shape(group_size, n / group_size), k, l), + make_stride(Stride<_0,_1>{}, n / group_size, n * k / group_size)); + } +} + +template +void qmm_naive( + const Element* A, + const Quant* B, + const Scale* S, + const Element* Z, + Element* C, + int m, int n, int k, int l, + bool broadcast_b, + auto group_size, + auto&& launch_kernel) { + // Define shapes (dynamic). + auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) + + // Define TN strides (mixed). + auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) + auto dB = make_matrix_stride(n, k); // (dN,dK,dL) + auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) + + // Define layout of scales/biases (mixed). + auto S_layout = make_scales_layout(n, k, l, group_size); + + // Handle broadcasting. + if (broadcast_b) { + get<2>(dB) = 0; + get<2>(stride(S_layout)) = 0; + } + + // Define CTA tile sizes (static). + auto bM = Int{}; + auto bN = Int<(!SM80 && group_size > 64) ? 64 : 128>{}; + auto bK = Int{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M,BLK_N,BLK_K) + + // Define MMA. + TiledMMA mma = make_tiled_mma(); + auto num_threads = size(mma); + + // Define the A/B smem layouts (static). + auto sA_layout = make_smem_layout(bM, bK); + auto sB_layout = make_smem_layout(bN, bK); + + // Atoms. + TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); + TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); + + auto* kernel = &qmm_naive_kernel< + decltype(prob_shape), decltype(cta_tiler), + Element, Quant, Scale, + decltype(dA), decltype(sA_layout), decltype(copy_a), + decltype(dB), decltype(sB_layout), decltype(copy_b), + decltype(dC), decltype(S_layout), decltype(mma)>; + + // Set L1 to be SMEM only. + size_t smem_bytes = sizeof(SharedStorage); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + dim3 num_blocks(size(ceil_div(m, bM)), size(ceil_div(n, bN)), l); + dim3 block_dims(num_threads); + void* args[] = { + &prob_shape, &cta_tiler, + &A, &dA, &sA_layout, ©_a, + &B, &dB, &sB_layout, ©_b, + &C, &dC, + &S, &Z, &S_layout, + &mma}; + launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); +} + +} // namespace cutlass_gemm + +// clang-format on + +namespace mlx::core { + +template +inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { + if (dtype == float32) { + f.template operator()(); + } else if (dtype == float16) { + f.template operator()(); + } else if (dtype == bfloat16) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); + } +} + +template +inline void dispatch_groups(int group_size, const char* tag, F&& f) { + if (group_size == 32) { + f.template operator()<32>(); + } else if (group_size == 64) { + f.template operator()<64>(); + } else if (group_size == 128) { + f.template operator()<128>(); + } else { + throw std::invalid_argument( + fmt::format("{} Group size {} is not supported.", tag, group_size)); + } +} + +template +inline void dispatch_quant_types( + int bits, + int group_size, + QuantizationMode mode, + const char* tag, + F&& f) { + if (mode == QuantizationMode::Mxfp4) { + f.template operator()(); + } else if (mode == QuantizationMode::Mxfp8) { + f.template operator()(); + } else if (mode == QuantizationMode::Nvfp4) { + f.template operator()(); + } else { + dispatch_groups(group_size, tag, [&]() { + if (bits == 2) { + f.template operator()(); + } else if (bits == 4) { + f.template operator()(); + } else if (bits == 8) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} {}-bit quantization is not supported.", tag, bits)); + } + }); + } +} + +template +void qmm_impl_naive( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + int bits, + int group_size, + QuantizationMode mode, + cu::CommandEncoder& encoder) { + const char* tag = "[quantized_matmul]"; + int m = out.shape(-2); + int n = out.shape(-1); + int k = x.shape(-1); + int l = out.size() / (m * n); + bool broadcast_b = w.ndim() == 2; + + bool is_sm80 = encoder.device().compute_capability_major() >= 8; + dispatch_bool(is_sm80, [&](auto sm80) { + dispatch_element_types(out.dtype(), tag, [&]() { + dispatch_quant_types( + bits, + group_size, + mode, + tag, + [&]() { + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + if (biases) { + encoder.set_input_array(*biases); + } + encoder.set_output_array(out); + cutlass_gemm::qmm_naive( + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + biases ? gpu_ptr(*biases) : nullptr, + gpu_ptr(out), + m, + n, + k, + l, + broadcast_b, + cute::Int{}, + [&](auto* kernel, + dim3 num_blocks, + dim3 block_dims, + uint32_t smem_bytes, + void** args) { + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, smem_bytes, args); + }); + }); + }); + }); +} + +} // namespace mlx::core + +#define QMM_NAIVE_GPU(TileM, KMajor) \ + namespace mlx::core { \ + template void qmm_impl_naive( \ + const array& x, \ + const array& w, \ + const array& scales, \ + const std::optional& biases, \ + array& out, \ + int bits, \ + int group_size, \ + QuantizationMode mode, \ + cu::CommandEncoder& encoder); \ + } diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_k.cu b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_k.cu new file mode 100644 index 0000000000..4bead82a03 --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_k.cu @@ -0,0 +1,5 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh" + +QMM_NAIVE_GPU(16, true) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_n.cu b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_n.cu new file mode 100644 index 0000000000..993243d9e8 --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m16_n.cu @@ -0,0 +1,5 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh" + +QMM_NAIVE_GPU(16, false) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_k.cu b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_k.cu new file mode 100644 index 0000000000..def1b4e7b1 --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_k.cu @@ -0,0 +1,5 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh" + +QMM_NAIVE_GPU(32, true) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_n.cu b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_n.cu new file mode 100644 index 0000000000..bf1a500c5d --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m32_n.cu @@ -0,0 +1,5 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh" + +QMM_NAIVE_GPU(32, false) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_k.cu b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_k.cu new file mode 100644 index 0000000000..92f03c788d --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_k.cu @@ -0,0 +1,5 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh" + +QMM_NAIVE_GPU(64, true) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_n.cu b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_n.cu new file mode 100644 index 0000000000..1d1f040043 --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive_m64_n.cu @@ -0,0 +1,5 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh" + +QMM_NAIVE_GPU(64, false) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index e62ee947a7..895cdfdb5c 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -1,11 +1,9 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" #include "mlx/dtype_utils.h" -#include -#include - // clang-format off // We can't put kernel code in mlx::core due to name conflicts of "Shape". @@ -13,9 +11,6 @@ namespace cutlass_gemm { using namespace cute; -template -constexpr bool has_zero_point_v = !cutlass::has_negative_zero_v; - template -__device__ __forceinline__ void -dequant(const Q& w, const S& s, const Z& z, T out) { - // Scale must be one element. - CUTE_STATIC_ASSERT_V(cosize(s.layout()) == Int<1>{}); - CUTE_STATIC_ASSERT_V(cosize(z.layout()) == Int<1>{}); - // Quant must be contiguous. - auto layout = coalesce(w.layout()); - CUTE_STATIC_ASSERT_V(stride(layout) == Int<1>{}); - // Use cutlass for conversions. - constexpr int N = size(layout); - using Element = typename T::value_type; - using Quant = typename Q::value_type; - auto& w_vec = *(reinterpret_cast*>(raw_pointer_cast(w.data()))); - Element scale{s[0]}; - cutlass::NumericArrayConverter converter; - auto w_dq = converter(w_vec) * scale; - if constexpr (has_zero_point_v) { - Element zero_point{z[0]}; - w_dq = w_dq + zero_point; - } - copy(make_tensor(make_rmem_ptr(&w_dq), out.layout()), out); -} - template RMEM. auto fetch_scales = [&](int tile) { copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS); - if constexpr (has_zero_point_v) { + if constexpr (quant_has_bias_v) { copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ); } }; @@ -205,7 +176,11 @@ __global__ void qmm_sm80_kernel( copy(s2r_atom_b, s2r_tCsB(_,_,block,smem_pipe_read), s2r_tCrB(_,_,block)); CUTE_UNROLL for (int n = 0; n < size<1>(tCrB); ++n) { - dequant(tCrB(_,n,block), tCrS(_,n,block), tCrZ(_,n,block), tCrB_dq(_,n,block)); + cute_vectorized_dequant( + tCrB(_,n,block), + tCrS(_,n,block), + tCrZ(_,n,block), + tCrB_dq(_,n,block)); } }; @@ -300,6 +275,7 @@ void qmm_sm80( const Element* Z, Element* C, int m, int n, int k, int l, + bool broadcast_b, GroupSize group_size, F&& launch_kernel) { // Define shapes (dynamic). @@ -310,6 +286,17 @@ void qmm_sm80( auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL) auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) + // Define layout of scales/biases (mixed). + auto S_layout = make_layout( + make_shape(n, make_shape(group_size, k / group_size), l), + make_stride(k / group_size, Stride<_0, _1>{}, n * k / group_size)); + + // Handle broadcasting. + if (broadcast_b) { + get<2>(dB) = 0; + get<2>(stride(S_layout)) = 0; + } + // Define CTA tile sizes (static). auto bM = Int{}; auto bN = Int<128>{}; @@ -337,11 +324,6 @@ void qmm_sm80( auto sS_layout = make_layout(make_shape(bN, make_shape(group_size, bS)), make_stride(bS, Stride<_0, _1>{})); - // Define layout of scales/biases (mixed). - auto S_layout = make_layout( - make_shape(n, make_shape(group_size, k / group_size), l), - make_stride(k / group_size, Stride<_0, _1>{}, n * k / group_size)); - // Atoms. constexpr int element_bits = sizeof_bits_v; constexpr int quant_bits = sizeof_bits_v; @@ -457,6 +439,7 @@ void qmm_impl_sm80( int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); + bool broadcast_b = w.ndim() == 2; dispatch_element_types(out.dtype(), tag, [&]() { dispatch_quant_types( @@ -482,6 +465,7 @@ void qmm_impl_sm80( n, k, l, + broadcast_b, cute::Int{}, [&](auto* kernel, dim3 num_blocks, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh index 0007b9db81..bb29cdafc5 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh @@ -36,6 +36,7 @@ void qmm_sm90( int64_t n, int64_t k, int64_t l, + bool broadcast_b, GroupSize group_size, F&& launch_kernel) { constexpr int kAlignmentA = 128 / sizeof_bits::value; @@ -93,6 +94,10 @@ void qmm_sm90( auto dB = make_stride(k, Int<1>{}, n * k); auto dS = make_stride(Int<1>{}, n, n * k / group_size); auto dD = make_stride(Int<1>{}, n, m * n); + if (broadcast_b) { + get<2>(dB) = 0; + get<2>(dS) = 0; + } Gemm gemm; typename Gemm::Arguments args{ @@ -188,6 +193,7 @@ void qmm_impl_sm90( int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); + bool broadcast_b = w.ndim() == 2; // FIXME: Copy happens for every call. array scales = transpose_last_2_dims(scales_, encoder, s); @@ -211,6 +217,7 @@ void qmm_impl_sm90( n, k, l, + broadcast_b, group_size, [&](auto* kernel, dim3 num_blocks, diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index d7252ec196..b0c5fc50be 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -40,6 +40,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { }; bool can_use_qmm_sm90 = supports(supports_qmm_sm90); bool can_use_qmm_sm80 = supports(supports_qmm_sm80); + bool can_use_qmm_naive = supports(supports_qmm_naive); bool can_use_fp_qmv = supports(supports_fp_qmv); bool can_use_qmv = supports(supports_qmv) || can_use_fp_qmv; @@ -51,6 +52,20 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder); }; + auto call_qmm_naive = [&]() { + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + qmm_naive( + x, + w, + scales, + biases, + out, + transpose_, + bits_, + group_size_, + mode_, + encoder); + }; auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); if (can_use_fp_qmv) { @@ -83,6 +98,15 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { return; } + if (can_use_qmm_naive) { + if (can_use_qmv && (M * B < 8)) { + call_qmv(); + } else { + call_qmm_naive(); + } + return; + } + if (can_use_qmv) { call_qmv(); return; diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 2fb66b8d84..e666ec3389 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,5 +1,4 @@ cuda_skip = { - "TestLayers.test_quantized_embedding", # Gather matmul NYI "TestBlas.test_gather_matmul", "TestBlas.test_gather_matmul_grad", @@ -25,16 +24,11 @@ "TestQuantized.test_gather_qmm_sorted", "TestQuantized.test_gather_qmm_grad", "TestQuantized.test_non_multiples", - "TestQuantized.test_qmm", - "TestQuantized.test_qmm_jvp", "TestQuantized.test_qmm_shapes", - "TestQuantized.test_qmm_vjp", "TestQuantized.test_fp_qvm", "TestQuantized.test_qvm", "TestQuantized.test_qvm_splitk", "TestQuantized.test_qmv_small_non_multiples", "TestQuantized.test_small_matrix", - "TestQuantized.test_throw", - "TestQuantized.test_vjp_scales_biases", "TestExportImport.test_export_quantized_model", }