Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlx/backend/cuda/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ namespace mlx::core {
throw std::runtime_error(#func " has no CUDA implementation."); \
}

NO_GPU(GatherQMM)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU_MULTI(SVD)
Expand Down
13 changes: 13 additions & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,17 @@ void qmv(
QuantizationMode mode,
cu::CommandEncoder& encoder);

// Batched quantized matrix-vector multiply with gathered operands.
// For each batch entry b, the activation batch is selected by lhs_indices[b]
// and the quantized weight matrix by rhs_indices[b]; the result is written to
// the b-th batch of `out`. `bits`/`group_size`/`mode` describe the weight
// quantization; `biases` is absent for quantization modes without a bias
// tensor. The launch is recorded on `encoder` rather than executed eagerly.
void gather_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);

} // namespace mlx::core
204 changes: 184 additions & 20 deletions mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -182,38 +182,24 @@ dequant_fma(const T* x, const Q* w, S scale, T bias, float* out) {
}

template <
int rows_per_block,
int elems_per_thread,
int group_size,
bool has_bias,
bool has_residue_k,
typename T,
typename Q,
typename S>
__global__ void qmv_kernel(
__device__ __forceinline__ void qmv_kernel_impl(
const T* x,
const Q* w,
const S* scales,
const T* biases,
T* out,
int row,
int w_batch,
int n,
int k,
bool broadcast_w) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

// The row that this warp handles.
int row = block.group_index().x * rows_per_block + warp.meta_group_rank();
if (row >= n) {
return;
}

// Advance pointers of x/out.
int m = grid.dim_blocks().y;
int l = block.group_index().z;
x += block.group_index().y * k + m * k * l;
out += block.group_index().y * n + m * n * l;
int k) {
auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());

// For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
// move past 2 elements for 4-bit Q.
Expand All @@ -224,7 +210,6 @@ __global__ void qmv_kernel(
int groups_per_row = k / group_size;

// Advance w/scales/biases to current row.
int w_batch = broadcast_w ? 0 : l;
w += (static_cast<int64_t>(row) + n * w_batch) * w_step(k);
scales += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
if constexpr (has_bias) {
Expand Down Expand Up @@ -274,6 +259,85 @@ __global__ void qmv_kernel(
}
}

template <
    int rows_per_block,
    int elems_per_thread,
    int group_size,
    bool has_bias,
    bool has_residue_k,
    typename T,
    typename Q,
    typename S>
__global__ void qmv_kernel(
    const T* x,
    const Q* w,
    const S* scales,
    const T* biases,
    T* out,
    int n,
    int k,
    bool broadcast_w) {
  // One warp computes one output row. grid.x tiles the N dimension in chunks
  // of rows_per_block, grid.y tiles the M dimension, grid.z is the batch.
  auto cta = cg::this_thread_block();
  auto lane_group = cg::tiled_partition<WARP_SIZE>(cta);

  // Row of the output handled by this warp; tail blocks may run past N.
  int row = cta.group_index().x * rows_per_block + lane_group.meta_group_rank();
  if (row >= n) {
    return;
  }

  // Offset x/out to this block's (m, batch) slice.
  int num_m = cg::this_grid().dim_blocks().y;
  int batch = cta.group_index().z;
  int m_idx = cta.group_index().y;
  x += m_idx * k + num_m * k * batch;
  out += m_idx * n + num_m * n * batch;
  // When the weights are broadcast over the batch, every batch reads batch 0.
  int weight_batch = broadcast_w ? 0 : batch;

  qmv_kernel_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
      x, w, scales, biases, out, row, weight_batch, n, k);
}

template <
    int rows_per_block,
    int elems_per_thread,
    int group_size,
    bool has_bias,
    bool has_residue_k,
    typename T,
    typename Q,
    typename S>
__global__ void gather_qmv_kernel(
    const T* x,
    const Q* w,
    const S* scales,
    const T* biases,
    T* out,
    const uint32_t* lhs_indices,
    const uint32_t* rhs_indices,
    int n,
    int k) {
  // Gathered variant of qmv_kernel: grid.z walks the gather batch and each
  // batch entry picks its activation batch and weight matrix via indices.
  auto cta = cg::this_thread_block();
  auto lane_group = cg::tiled_partition<WARP_SIZE>(cta);

  // Row of the output handled by this warp; guard the N tail.
  int row = cta.group_index().x * rows_per_block + lane_group.meta_group_rank();
  if (row >= n) {
    return;
  }

  int num_m = cg::this_grid().dim_blocks().y;
  int batch = cta.group_index().z;
  int m_idx = cta.group_index().y;

  // Indirection: which activation batch and which weight matrix to use.
  uint32_t x_batch = lhs_indices[batch];
  uint32_t w_batch = rhs_indices[batch];

  // x is read from the gathered batch; out is written at the dense batch.
  x += m_idx * k + num_m * k * x_batch;
  out += m_idx * n + num_m * n * batch;

  qmv_kernel_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
      x, w, scales, biases, out, row, w_batch, n, k);
}

template <
int group_size,
bool has_bias,
Expand Down Expand Up @@ -317,6 +381,51 @@ void qmv(
});
}

template <
    int group_size,
    bool has_bias,
    typename T,
    typename Q,
    typename S,
    typename F>
void gather_qmv(
    const T* x,
    const Q* w,
    const S* scales,
    const T* biases,
    T* out,
    const uint32_t* lhs_indices,
    const uint32_t* rhs_indices,
    int m,
    int n,
    int k,
    int l,
    F&& launch_kernel) {
  // One warp per output row, 8 rows per block. Narrow activation + sub-4-bit
  // weights allow a wider per-thread tile along K.
  constexpr int rows_per_block = 8;
  constexpr int elems_per_thread =
      (cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 16 : 8;

  // grid.x tiles N, grid.y tiles M, grid.z walks the gather batch.
  dim3 grid_dim{
      uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)};
  dim3 cta_dim{WARP_SIZE, rows_per_block};
  void* kernel_args[] = {
      &x, &w, &scales, &biases, &out, &lhs_indices, &rhs_indices, &n, &k};

  // Specialize on whether K is an exact multiple of the per-iteration tile.
  dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto residue) {
    auto* kernel = &gather_qmv_kernel<
        rows_per_block,
        elems_per_thread,
        group_size,
        has_bias,
        residue.value,
        T,
        Q,
        S>;
    launch_kernel(
        reinterpret_cast<void*>(kernel), grid_dim, cta_dim, kernel_args);
  });
}

} // namespace cu

template <typename F>
Expand Down Expand Up @@ -433,4 +542,59 @@ void qmv(
});
}

// Entry point for the gathered quantized matrix-vector multiply.
// `out` must already have its shape set: m/n come from its trailing dims and
// l = out.size() / (m * n) is the flattened gather batch driven by
// lhs_indices/rhs_indices. The kernel launch is recorded on `encoder`.
void gather_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder) {
  // Tag used by the dispatch helpers when reporting unsupported configs.
  const char* tag = "[gather_qmm]";
  int m = out.shape(-2);
  int n = out.shape(-1);
  int k = x.shape(-1);
  // Number of gathered matmuls encoded in the leading dims of out.
  int l = out.size() / (m * n);

  // Lower the runtime dtype and quantization config to concrete template
  // instantiations (T = activation, Q = packed weight, S = scale types).
  dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
    dispatch_quant_types<T>(
        bits,
        group_size,
        mode,
        tag,
        // Note: this `group_size` template parameter intentionally shadows
        // the runtime argument with its compile-time value.
        [&]<typename Q, typename S, int group_size>() {
          // Record data dependencies so the encoder orders this launch
          // correctly relative to producers/consumers of these arrays.
          encoder.set_input_array(x);
          encoder.set_input_array(w);
          encoder.set_input_array(scales);
          if (biases) {
            encoder.set_input_array(*biases);
          }
          encoder.set_input_array(lhs_indices);
          encoder.set_input_array(rhs_indices);
          encoder.set_output_array(out);
          // NOTE(review): a bias tensor is expected exactly when Q has no
          // negative-zero encoding — confirm this matches dispatch_quant_types.
          constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
          cu::gather_qmv<group_size, has_bias>(
              gpu_ptr<T>(x),
              gpu_ptr<Q>(w),
              gpu_ptr<S>(scales),
              biases ? gpu_ptr<T>(*biases) : nullptr,
              gpu_ptr<T>(out),
              gpu_ptr<uint32_t>(lhs_indices),
              gpu_ptr<uint32_t>(rhs_indices),
              m,
              n,
              k,
              l,
              // Record the launch as a raw kernel node on the encoder.
              [&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) {
                encoder.add_kernel_node_raw(
                    kernel, num_blocks, block_dims, {}, 0, args);
              });
        });
  });
}

} // namespace mlx::core
72 changes: 72 additions & 0 deletions mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,78 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
quantization_mode_to_string(mode_)));
}

void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("GatherQMM::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);

const array& x = inputs[0];
const array& w = inputs[1];
const array& scales = inputs[2];
std::optional<array> biases;
if (inputs.size() == 6) {
biases = inputs[3];
}
array lhs_indices = ensure_contiguous(inputs[inputs.size() - 2], encoder, s);
array rhs_indices = ensure_contiguous(inputs[inputs.size() - 1], encoder, s);

int M = out.shape(-2);
int N = out.shape(-1);
int K = x.shape(-1);
int B = out.size() / (M * N);

auto supports = [&](auto&& f) {
return f(
x,
w,
scales,
biases,
out,
transpose_,
bits_,
group_size_,
mode_,
encoder.device());
};
bool can_use_qmv = supports(supports_qmv);

auto call_qmv = [&]() {
out.set_data(cu::malloc_async(out.nbytes(), encoder));
gather_qmv(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
out,
bits_,
group_size_,
mode_,
encoder);
};

if (can_use_qmv) {
call_qmv();
return;
}

throw std::runtime_error(
fmt::format(
"[gather_qmm] No implementation for "
"problem shape: {}x{}x{}x{}, transpose: {}, "
"activation: {}, bits: {}, group size: {}, mode: \"{}\".",
M,
N,
K,
B,
transpose_,
dtype_to_string(x.dtype()),
bits_,
group_size_,
quantization_mode_to_string(mode_)));
}

void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
Expand Down
Loading