Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions mlx/backend/cuda/quantized/qmm/qmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ void qmm_impl_sm80(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);

bool supports_qmm_sm80(
Expand All @@ -128,11 +129,11 @@ bool supports_qmm_sm80(
if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) {
return false;
}
if (!biases) {
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
!scales.flags().row_contiguous || !biases->flags().row_contiguous) {
if (biases && !biases->flags().row_contiguous) {
return false;
}
if (x.dtype() != float16 && x.dtype() != bfloat16) {
Expand All @@ -141,10 +142,7 @@ bool supports_qmm_sm80(
if (!transpose) {
return false;
}
if (bits != 8) {
return false;
}
if (mode != QuantizationMode::Affine) {
if (bits != 4 && bits != 8) {
return false;
}
return true;
Expand All @@ -154,13 +152,15 @@ void qmm_sm80(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
auto dispatch = [&]<int TileM>() {
qmm_impl_sm80<TileM>(x, w, scales, biases, out, bits, group_size, encoder);
qmm_impl_sm80<TileM>(
x, w, scales, biases, out, bits, group_size, mode, encoder);
};
int m = out.shape(-2);
if (m <= 16) {
Expand Down
3 changes: 2 additions & 1 deletion mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ void qmm_sm80(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder);

bool supports_fp_qmv(
Expand Down
160 changes: 95 additions & 65 deletions mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ namespace cutlass_gemm {

using namespace cute;

// Whether dequantizing `Quant` requires adding an explicit zero-point term.
// Derived from cutlass: quant types that cannot represent a negative zero
// (e.g. unsigned integer storage) are presumed affine-quantized and get a
// zero point; float quant formats are corrected by the scale alone.
// NOTE(review): relies on cutlass::has_negative_zero_v semantics — confirm
// against the cutlass version in use.
template <typename Quant>
constexpr bool has_zero_point_v = !cutlass::has_negative_zero_v<Quant>;

template <typename Element,
typename Quant,
typename SmemLayoutA,
Expand Down Expand Up @@ -42,15 +45,18 @@ dequant(const Q& w, const S& s, const Z& z, T out) {
using Element = typename T::value_type;
using Quant = typename Q::value_type;
auto& w_vec = *(reinterpret_cast<const cutlass::Array<Quant, N>*>(raw_pointer_cast(w.data())));
Element scale = s[0];
Element zero_point = z[0];
Element scale{s[0]};
cutlass::NumericArrayConverter<Element, Quant, N> converter;
auto w_dq = converter(w_vec) * scale + zero_point;
auto w_dq = converter(w_vec) * scale;
if constexpr (has_zero_point_v<Quant>) {
Element zero_point{z[0]};
w_dq = w_dq + zero_point;
}
copy(make_tensor(make_rmem_ptr<Element>(&w_dq), out.layout()), out);
}

template <typename ProblemShape, typename CtaTiler,
typename Element, typename Quant,
typename Element, typename Quant, typename Scale,
typename StrideA, typename SmemLayoutA, typename TiledCopyA, typename S2RAtomA,
typename StrideB, typename SmemLayoutB, typename TiledCopyB, typename S2RAtomB,
typename StrideC, typename SmemLayoutC, typename TiledCopyC, typename R2SAtomC,
Expand All @@ -60,7 +66,7 @@ __global__ void qmm_sm80_kernel(
const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a,
const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b,
Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c,
const Element* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma) {
const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma) {
CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma));
CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma));
CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma));
Expand All @@ -72,9 +78,9 @@ __global__ void qmm_sm80_kernel(
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);

// Represent the full tensors.
Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)
Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L)
Tensor mB_nkl = make_tensor(make_gmem_ptr<Quant>(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L)
Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L)

Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L)
Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L)
Expand Down Expand Up @@ -189,7 +195,9 @@ __global__ void qmm_sm80_kernel(
// Copy S/Z: GMEM => RMEM.
auto fetch_scales = [&](int tile) {
copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS);
copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ);
if constexpr (has_zero_point_v<Quant>) {
copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ);
}
};
// Copy A/B: SMEM => RMEM.
auto fetch_smem = [&](auto block) {
Expand Down Expand Up @@ -284,11 +292,11 @@ inline auto make_tiled_copy(NumThreads num_threads) {
make_layout(make_shape(Int<1>{}, Int<bits / sizeof_bits_v<T>>{})));
}

template <int TileM = 16, typename Element, typename Quant, typename GroupSize, typename F>
template <int TileM = 16, typename Element, typename Quant, typename Scale, typename GroupSize, typename F>
void qmm_sm80(
const Element* A,
const Quant* B,
const Element* S,
const Scale* S,
const Element* Z,
Element* C,
int m, int n, int k, int l,
Expand Down Expand Up @@ -345,11 +353,11 @@ void qmm_sm80(
Copy_Atom<SM75_U32x4_LDSM_N, Element> s2r_atom_a;
Copy_Atom<UniversalCopy<uint_bit_t<2 * quant_bits>>, Quant> s2r_atom_b;
Copy_Atom<UniversalCopy<uint_bit_t<2 * element_bits>>, Element> r2s_atom_c;
Copy_Atom<UniversalCopy<Element>, Element> g2r_atom_s;
Copy_Atom<UniversalCopy<Scale>, Scale> g2r_atom_s;

auto* kernel = &qmm_sm80_kernel<
decltype(prob_shape), decltype(cta_tiler),
Element, Quant,
Element, Quant, Scale,
decltype(dA), decltype(sA_layout), decltype(g2s_copy_a), decltype(s2r_atom_a),
decltype(dB), decltype(sB_layout), decltype(g2s_copy_b), decltype(s2r_atom_b),
decltype(dC), decltype(sC_layout), decltype(s2g_copy_c), decltype(r2s_atom_c),
Expand Down Expand Up @@ -392,18 +400,6 @@ inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
}
}

// Invokes `f` with the weight storage type matching the requested bit width:
// 4-bit -> cutlass::uint4b_t (packed sub-byte), 8-bit -> uint8_t.
// Any other width throws std::invalid_argument, prefixed with `tag`
// (e.g. "[quantized_matmul]") to identify the calling op.
template <typename F>
inline void dispatch_quant_types(int bits, const char* tag, F&& f) {
if (bits == 4) {
f.template operator()<cutlass::uint4b_t>();
} else if (bits == 8) {
f.template operator()<uint8_t>();
} else {
throw std::invalid_argument(
fmt::format("{} {}-bit quantization is not supported.", tag, bits));
}
}

template <typename F>
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
if (group_size == 32) {
Expand All @@ -418,15 +414,43 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
}
}

// Maps runtime (bits, group_size, mode) onto the compile-time template
// arguments <Quant, Scale, GroupSize> and invokes `f` with them.
//
// - Mxfp4 / Mxfp8 / Nvfp4 fix all three arguments from the mode alone
//   (runtime `bits` and `group_size` are ignored on those paths): the
//   value/scale types and group widths follow the cutlass format types
//   used below (e2m1 or e4m3 values; ue8m0 or e4m3 scales).
// - Otherwise (affine integer quantization) the scale type is `T` and the
//   runtime group size and bit width are dispatched dynamically; widths
//   other than 4 or 8 throw std::invalid_argument, prefixed with `tag`.
template <typename T, typename F>
inline void dispatch_quant_types(
int bits,
int group_size,
QuantizationMode mode,
const char* tag,
F&& f) {
if (mode == QuantizationMode::Mxfp4) {
// fp4 (e2m1) values, shared ue8m0 scale per 32-element group.
f.template operator()<cutlass::float_e2m1_t, cutlass::float_ue8m0_t, 32>();
} else if (mode == QuantizationMode::Mxfp8) {
// fp8 (e4m3) values, shared ue8m0 scale per 32-element group.
f.template operator()<cutlass::float_e4m3_t, cutlass::float_ue8m0_t, 32>();
} else if (mode == QuantizationMode::Nvfp4) {
// fp4 (e2m1) values, fp8 (e4m3) scale per 16-element group.
f.template operator()<cutlass::float_e2m1_t, cutlass::float_e4m3_t, 16>();
} else {
// Affine path: lift the runtime group size to a compile-time constant
// first (the lambda's `group_size` template parameter intentionally
// shadows the runtime argument), then pick the storage type.
dispatch_groups(group_size, tag, [&]<int group_size>() {
if (bits == 4) {
f.template operator()<cutlass::uint4b_t, T, group_size>();
} else if (bits == 8) {
f.template operator()<uint8_t, T, group_size>();
} else {
throw std::invalid_argument(
fmt::format("{} {}-bit quantization is not supported.", tag, bits));
}
});
}
}

template <int TileM>
void qmm_impl_sm80(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int bits,
int group_size,
QuantizationMode mode,
cu::CommandEncoder& encoder) {
const char* tag = "[quantized_matmul]";
int m = out.shape(-2);
Expand All @@ -435,48 +459,54 @@ void qmm_impl_sm80(
int l = out.size() / (m * n);

dispatch_element_types(out.dtype(), tag, [&]<typename Element>() {
dispatch_quant_types(bits, tag, [&]<typename Quant>() {
dispatch_groups(group_size, tag, [&]<int group_size>() {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
cutlass_gemm::qmm_sm80<TileM>(
gpu_ptr<Element>(x),
gpu_ptr<Quant>(w),
gpu_ptr<Element>(scales),
gpu_ptr<Element>(biases),
gpu_ptr<Element>(out),
m,
n,
k,
l,
cute::Int<group_size>{},
[&](auto* kernel,
dim3 num_blocks,
dim3 block_dims,
uint32_t smem_bytes,
void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, smem_bytes, args);
});
});
});
dispatch_quant_types<Element>(
bits,
group_size,
mode,
tag,
[&]<typename Quant, typename Scale, int group_size>() {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
if (biases) {
encoder.set_input_array(*biases);
}
encoder.set_output_array(out);
cutlass_gemm::qmm_sm80<TileM>(
gpu_ptr<Element>(x),
gpu_ptr<Quant>(w),
gpu_ptr<Scale>(scales),
biases ? gpu_ptr<Element>(*biases) : nullptr,
gpu_ptr<Element>(out),
m,
n,
k,
l,
cute::Int<group_size>{},
[&](auto* kernel,
dim3 num_blocks,
dim3 block_dims,
uint32_t smem_bytes,
void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, smem_bytes, args);
});
});
});
}

} // namespace mlx::core

#define QMM_SM80_GPU(TileM) \
namespace mlx::core { \
template void qmm_impl_sm80<TileM>( \
const array& x, \
const array& w, \
const array& scales, \
const array& biases, \
array& out, \
int bits, \
int group_size, \
cu::CommandEncoder& encoder); \
#define QMM_SM80_GPU(TileM) \
namespace mlx::core { \
template void qmm_impl_sm80<TileM>( \
const array& x, \
const array& w, \
const array& scales, \
const std::optional<array>& biases, \
array& out, \
int bits, \
int group_size, \
QuantizationMode mode, \
cu::CommandEncoder& encoder); \
}
2 changes: 1 addition & 1 deletion mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
};
auto call_qmm_sm80 = [&]() {
out.set_data(cu::malloc_async(out.nbytes(), encoder));
qmm_sm80(x, w, scales, *biases, out, bits_, group_size_, encoder);
qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder);
};
auto call_qmv = [&]() {
out.set_data(cu::malloc_async(out.nbytes(), encoder));
Expand Down
Loading