6 changes: 6 additions & 0 deletions .github/actions/setup-linux/action.yml
@@ -54,6 +54,12 @@ runs:
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
echo "::endgroup::"

- name: Set swap space
if: ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }}
uses: pierotofy/set-swap-space@fc79b3f67fa8a838184ce84a674ca12238d2c761
with:
swap-size-gb: 16

- name: Install CUDA toolkit
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
6 changes: 6 additions & 0 deletions mlx/backend/cuda/quantized/qmm/CMakeLists.txt
@@ -3,6 +3,12 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu
${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m16_k.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m16_n.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m32_k.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m32_n.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m64_k.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_naive_m64_n.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m16.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m32.cu
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm80_m64.cu
40 changes: 40 additions & 0 deletions mlx/backend/cuda/quantized/qmm/cute_dequant.cuh
@@ -0,0 +1,40 @@
// Copyright © 2026 Apple Inc.

#pragma once

#include <cute/tensor.hpp>
#include <cutlass/numeric_conversion.h>

namespace cutlass_gemm {

// Whether the quant type uses affine quantization (i.e., carries a zero-point bias).
template <typename Quant>
constexpr bool quant_has_bias_v = !cutlass::has_negative_zero_v<Quant>;

// Dequantize CuTe tensors with out = w * s + z.
__device__ __forceinline__ void
cute_vectorized_dequant(auto w, auto s, auto z, auto out) {
using namespace cute;
using Element = typename decltype(out)::value_type;
using Quant = typename decltype(w)::value_type;
// Scale and zero-point must each be a single element.
CUTE_STATIC_ASSERT_V(cosize(s.layout()) == Int<1>{});
CUTE_STATIC_ASSERT_V(cosize(z.layout()) == Int<1>{});
// Quant must be contiguous.
auto layout = coalesce(w.layout());
CUTE_STATIC_ASSERT_V(stride(layout) == Int<1>{});
// Use cutlass for conversions.
constexpr int N = size(layout);
auto& w_vec = *(reinterpret_cast<const cutlass::Array<Quant, N>*>(
raw_pointer_cast(w.data())));
Element scale{s[0]};
cutlass::NumericArrayConverter<Element, Quant, N> converter;
auto w_dq = converter(w_vec) * scale;
if constexpr (quant_has_bias_v<Quant>) {
Element zero_point{z[0]};
w_dq = w_dq + zero_point;
}
copy(make_tensor(make_rmem_ptr<Element>(&w_dq), out.layout()), out);
}

} // namespace cutlass_gemm
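
For reference, the vectorized routine above computes the affine mapping out = w * s + z, applying the zero-point term only for quant types that carry a bias. A minimal scalar sketch of the same mapping follows; the name and plain-array signature are hypothetical, for illustration only:

// Scalar equivalent of cute_vectorized_dequant (illustrative only); the real
// kernel converts all N values in one shot via cutlass::NumericArrayConverter.
template <typename Element, typename Quant, int N>
__device__ __forceinline__ void dequant_scalar_sketch(
    const Quant (&w)[N],
    Element scale,
    Element zero_point,
    Element (&out)[N]) {
  for (int i = 0; i < N; ++i) {
    out[i] = Element(w[i]) * scale; // w * s
    if constexpr (cutlass_gemm::quant_has_bias_v<Quant>) {
      out[i] = out[i] + zero_point; // ... + z for affine quant types
    }
  }
}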
69 changes: 69 additions & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.cu
@@ -1,5 +1,6 @@
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/qmm/qmm.h"

#include <cute/tensor.hpp>
@@ -172,6 +173,74 @@ void qmm_sm80(
}
}

// Defined in qmm_impl_naive_xxx.cu files.
template <int TileM, bool KMajor>
void qmm_impl_naive(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);

bool supports_qmm_naive(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
const array& out,
bool transpose,
int bits,
int group_size,
QuantizationMode mode,
cu::Device& device) {
int k = x.shape(-1);
if (k % std::max(64, group_size) != 0) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous) {
return false;
}
if (biases && !biases->flags().row_contiguous) {
return false;
}
if (bits != 2 && bits != 4 && bits != 8) {
return false;
}
return true;
}
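
// Worked example of the K-divisibility gate above (illustrative values): with
// group_size = 32 or 64 the divisor is max(64, group_size) = 64, so k = 4096
// passes (4096 % 64 == 0) while k = 4000 is rejected (4000 % 64 == 32); with
// group_size = 128 the divisor grows to 128.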

void qmm_naive(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
array& out,
bool transpose,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
auto dispatch = [&]<int TileM, bool KMajor>() {
qmm_impl_naive<TileM, KMajor>(
x, w, scales, biases, out, bits, group_size, mode, encoder);
};
dispatch_bool(transpose, [&](auto k_major) {
int m = out.shape(-2);
if (m <= 16) {
dispatch.template operator()<16, k_major.value>();
} else if (m <= 32) {
dispatch.template operator()<32, k_major.value>();
} else {
dispatch.template operator()<64, k_major.value>();
}
});
}
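
// A minimal sketch of the dispatch_bool helper assumed above (requires
// <type_traits>): it lifts the runtime `transpose` flag into a compile-time
// std::bool_constant so the lambda can use k_major.value as a template
// argument. The real MLX helper may differ; the name below is hypothetical.
template <typename F>
void dispatch_bool_sketch(bool b, F&& f) {
  if (b) {
    f(std::true_type{}); // inside the lambda, k_major.value == true
  } else {
    f(std::false_type{}); // k_major.value == false
  }
}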

bool supports_fp_qmv(
const array& x,
const array& w,
24 changes: 24 additions & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.h
@@ -55,6 +55,30 @@ void qmm_sm80(
QuantizationMode mode,
cu::CommandEncoder& encoder);

bool supports_qmm_naive(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
const array& out,
bool transpose,
int bits,
int group_size,
QuantizationMode mode,
cu::Device& device);

void qmm_naive(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
array& out,
bool transpose,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);

bool supports_fp_qmv(
const array& x,
const array& w,