From 261e0b9990a8712bd75b8d663179ac22c9ab4927 Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Mon, 23 Mar 2026 19:46:05 +0800
Subject: [PATCH 1/9] feat: implement BlockMaskedMM

Add CUDA implementation for block-masked matrix multiplication. The
approach pre-masks input matrices with a simple CUDA kernel, calls
cuBLAS GEMM, then applies the output mask.
---
 mlx/backend/cuda/CMakeLists.txt      |   1 +
 mlx/backend/cuda/gemms/block_mask.cu | 125 +++++++++++++++++++++++++++
 mlx/backend/cuda/gemms/block_mask.h  |  20 +++++
 mlx/backend/cuda/matmul.cpp          |  84 ++++++++++++++++++
 mlx/backend/cuda/primitives.cpp      |   1 -
 5 files changed, 230 insertions(+), 1 deletion(-)
 create mode 100644 mlx/backend/cuda/gemms/block_mask.cu
 create mode 100644 mlx/backend/cuda/gemms/block_mask.h

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 1cee777bbe..83343b82e1 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -35,6 +35,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
new file mode 100644
index 0000000000..e382a45e7e
--- /dev/null
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -0,0 +1,125 @@
+// Copyright © 2026 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/gemms/block_mask.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+
+namespace mlx::core::cu {
+
+template <typename T, typename MaskT>
+__global__ void block_mask_matrix(
+    T* data,
+    const MaskT* mask,
+    int block_size,
+    int rows,
+    int cols,
+    int64_t data_batch_stride,
+    const __grid_constant__ Shape mask_shape,
+    const __grid_constant__ Strides mask_strides,
+    int mask_ndim,
+    int64_t mask_row_stride,
+    int64_t mask_col_stride,
+    int mask_mat_size,
+    int batch_count) {
+  int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
+  int64_t total = int64_t(batch_count) * rows * cols;
+  if (idx >= total)
+    return;
+
+  int mat_size = rows * cols;
+  int batch = idx / mat_size;
+  int within = idx % mat_size;
+  int row = within / cols;
+  int col = within % cols;
+  int mask_row = row / block_size;
+  int mask_col = col / block_size;
+
+  // Compute mask batch offset (handles broadcasting via stride=0).
+  int64_t mask_batch_offset = elem_to_loc(
+      int64_t(batch) * mask_mat_size,
+      mask_shape.data(),
+      mask_strides.data(),
+      mask_ndim);
+  MaskT mask_val = mask
+      [mask_batch_offset + mask_row * mask_row_stride +
+       mask_col * mask_col_stride];
+
+  int64_t data_offset = int64_t(batch) * data_batch_stride + within;
+  if constexpr (std::is_same_v<MaskT, bool>) {
+    if (!mask_val) {
+      data[data_offset] = T(0);
+    }
+  } else {
+    data[data_offset] *= T(mask_val);
+  }
+}
+
+} // namespace mlx::core::cu
+
+namespace mlx::core {
+
+void apply_block_mask(
+    cu::CommandEncoder& encoder,
+    array& data,
+    const array& mask,
+    int block_size,
+    int rows,
+    int cols,
+    int64_t data_batch_stride,
+    int batch_count) {
+  encoder.set_input_array(mask);
+  encoder.set_output_array(data);
+
+  int mask_ndim = mask.ndim();
+  int64_t mask_row_stride = mask.strides()[mask_ndim - 2];
+  int64_t mask_col_stride = mask.strides()[mask_ndim - 1];
+  int mask_mat_size = mask.shape()[mask_ndim - 2] * mask.shape()[mask_ndim - 1];
+
+  int64_t total = int64_t(batch_count) * rows * cols;
+  constexpr int BLOCK = 256;
+  int grid = (total + BLOCK - 1) / BLOCK;
+
+  auto launch = [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    auto data_ptr = gpu_ptr<T>(data);
+
+    auto do_mask = [&](auto mask_tag) {
+      using MaskT = decltype(mask_tag);
+      const MaskT* mask_ptr;
+      if constexpr (std::is_same_v<MaskT, bool>) {
+        mask_ptr = gpu_ptr<bool>(mask);
+      } else {
+        mask_ptr = gpu_ptr<T>(mask);
+      }
+      encoder.add_kernel_node(
+          cu::block_mask_matrix<T, MaskT>,
+          grid,
+          BLOCK,
+          data_ptr,
+          mask_ptr,
+          block_size,
+          rows,
+          cols,
+          data_batch_stride,
+          const_param(mask.shape()),
+          const_param(mask.strides()),
+          mask_ndim,
+          mask_row_stride,
+          mask_col_stride,
+          mask_mat_size,
+          batch_count);
+    };
+
+    if (mask.dtype() == bool_) {
+      do_mask(bool{});
+    } else {
+      do_mask(T{});
+    }
+  };
+
+  dispatch_float_types(data.dtype(), "block_mask_matrix", launch);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/block_mask.h b/mlx/backend/cuda/gemms/block_mask.h
new file mode 100644
index 0000000000..ef2b30ab83
--- /dev/null
+++ b/mlx/backend/cuda/gemms/block_mask.h
@@ -0,0 +1,20 @@
+// Copyright © 2026 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+void apply_block_mask(
+    cu::CommandEncoder& encoder,
+    array& data,
+    const array& mask,
+    int block_size,
+    int rows,
+    int cols,
+    int64_t data_batch_stride,
+    int batch_count);
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 8336590562..41a33de51a 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -2,6 +2,7 @@
 
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/block_mask.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/cuda/gemms/grouped_gemm.h"
@@ -203,6 +204,89 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
 }
 
+void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("BlockMaskedMM::eval_gpu");
+  if (!issubdtype(out.dtype(), floating)) {
+    throw std::runtime_error(
+        "[BlockMaskedMM] Does not yet support non-floating point types.");
+  }
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+
+  // Return 0s if either input is empty.
+  if (a_pre.size() == 0 || b_pre.size() == 0) {
+    array zero(0, a_pre.dtype());
+    encoder.add_temporary(zero);
+    fill_gpu(zero, out, s);
+    return;
+  }
+
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
+
+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
+  if (M == 0 || N == 0) {
+    return;
+  }
+  if (K == 0) {
+    array zero(0, a_pre.dtype());
+    encoder.add_temporary(zero);
+    fill_gpu(zero, out, s);
+    return;
+  }
+
+  bool has_op_mask = inputs.size() > 3;
+  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
+
+  int batch_count = out.size() / (M * N);
+
+  bool a_transposed;
+  int64_t lda;
+  array a = a_pre;
+  bool b_transposed;
+  int64_t ldb;
+  array b = b_pre;
+
+  if (has_op_mask) {
+    // Make contiguous copies of A and B so we can mask them in-place.
+    a = contiguous_copy_gpu(a_pre, s);
+    encoder.add_temporary(a);
+    b = contiguous_copy_gpu(b_pre, s);
+    encoder.add_temporary(b);
+    a_transposed = false;
+    lda = K;
+    b_transposed = false;
+    ldb = N;
+
+    // Apply operand masks.
+    auto& lhs_mask = inputs[inputs.size() - 2];
+    auto& rhs_mask = inputs[inputs.size() - 1];
+    apply_block_mask(
+        encoder, a, lhs_mask, block_size_, M, K, int64_t(M) * K, batch_count);
+    apply_block_mask(
+        encoder, b, rhs_mask, block_size_, K, N, int64_t(K) * N, batch_count);
+  } else {
+    std::tie(a_transposed, lda, a) = check_transpose(encoder, s, a_pre);
+    std::tie(b_transposed, ldb, b) = check_transpose(encoder, s, b_pre);
+  }
+
+  // Run GEMM.
+  gemm_and_bias(
+      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
+
+  // Apply output mask.
+  if (has_out_mask) {
+    auto& out_mask = inputs[2];
+    apply_block_mask(
+        encoder, out, out_mask, block_size_, M, N, int64_t(M) * N, batch_count);
+  }
+}
+
 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("AddMM::eval_gpu");
   auto& s = stream();
diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp
index 98dca5708f..79b455cf57 100644
--- a/mlx/backend/cuda/primitives.cpp
+++ b/mlx/backend/cuda/primitives.cpp
@@ -24,7 +24,6 @@ namespace mlx::core {
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
-NO_GPU(BlockMaskedMM)
 NO_GPU(GatherQMM)
 NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)

From 2e0dcf9ceb90c5c223d674eb4c49f8ac9618e813 Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Mon, 23 Mar 2026 20:01:14 +0800
Subject: [PATCH 2/9] test: enable block-masked matmul test and add bench
 script

---
 benchmarks/python/block_masked_mm_bench.py | 193 +++++++++++++++++++++
 python/tests/cuda_skip.py                  |   2 -
 2 files changed, 193 insertions(+), 2 deletions(-)
 create mode 100644 benchmarks/python/block_masked_mm_bench.py

diff --git a/benchmarks/python/block_masked_mm_bench.py b/benchmarks/python/block_masked_mm_bench.py
new file mode 100644
index 0000000000..bcfeb1c4f8
--- /dev/null
+++ b/benchmarks/python/block_masked_mm_bench.py
@@ -0,0 +1,193 @@
+# Copyright © 2025 Apple Inc.
+
+import argparse
+import time
+
+import mlx.core as mx
+import numpy as np
+
+MLX_DTYPES = {
+    "float16": mx.float16,
+    "bfloat16": mx.bfloat16,
+    "float32": mx.float32,
+}
+
+
+def parse_cases(cases):
+    parsed = []
+    for spec in cases.split(","):
+        parts = spec.split("x")
+        m, n, k, bs = int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3])
+        sparsity = float(parts[4]) if len(parts) > 4 else 0.5
+        parsed.append((m, n, k, bs, sparsity))
+    return parsed
+
+
+def make_masks(m, n, k, block_size, sparsity, rng):
+    """Create block masks with given sparsity (fraction of blocks zeroed)."""
+    tm = (m + block_size - 1) // block_size
+    tn = (n + block_size - 1) // block_size
+    tk = (k + block_size - 1) // block_size
+
+    lhs_mask = (rng.random((tm, tk)) >= sparsity).astype(np.bool_)
+    rhs_mask = (rng.random((tk, tn)) >= sparsity).astype(np.bool_)
+    out_mask = (rng.random((tm, tn)) >= sparsity).astype(np.bool_)
+    return lhs_mask, rhs_mask, out_mask
+
+
+def mlx_naive_block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask):
+    """MLX naive: expand masks and use regular matmul."""
+    M, K = a.shape[-2], a.shape[-1]
+    N = b.shape[-1]
+
+    def expand(mask, rows, cols):
+        e = mx.repeat(mx.repeat(mask, block_size, axis=-2), block_size, axis=-1)
+        return e[..., :rows, :cols]
+
+    a_masked = a * expand(lhs_mask, M, K)
+    b_masked = b * expand(rhs_mask, K, N)
+    c = a_masked @ b_masked
+    c = c * expand(out_mask, M, N)
+    return c
+
+
+def bench_mlx(fn, warmup, iters):
+    for _ in range(warmup):
+        y = fn()
+        mx.eval(y)
+    mx.synchronize()
+
+    start = time.perf_counter()
+    for _ in range(iters):
+        y = fn()
+        mx.eval(y)
+    mx.synchronize()
+    return (time.perf_counter() - start) * 1e3 / iters
+
+
+def print_table(headers, rows):
+    widths = [len(h) for h in headers]
+    for row in rows:
+        for i, cell in enumerate(row):
+            widths[i] = max(widths[i], len(cell))
+
+    def fmt_row(row):
+        return (
+            "| "
+            + " | ".join(f"{cell:<{widths[i]}}" for i, cell in enumerate(row))
+            + " |"
+        )
+
+    sep = "|-" + "-|-".join("-" * w for w in widths) + "-|"
+    print(fmt_row(headers))
+    print(sep)
+    for row in rows:
+        print(fmt_row(row))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Benchmark block_masked_mm vs naive expand+matmul"
+    )
+    parser.add_argument(
+        "--cases",
+        default=(
+            "256x256x256x32x0.5,"
+            "512x512x512x32x0.5,"
+            "1024x1024x1024x32x0.5,"
+            "1024x1024x1024x64x0.5,"
+            "2048x2048x2048x64x0.5,"
+            "256x256x256x32x0.0,"
+            "1024x1024x1024x32x0.0,"
+            "1024x1024x1024x32x0.9"
+        ),
+        help="Comma-separated MxNxKxBSxSparsity list. "
+        "Sparsity=fraction of blocks zeroed.",
+    )
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "bfloat16", "float32"],
+    )
+    parser.add_argument("--warmup", type=int, default=10)
+    parser.add_argument("--iters", type=int, default=50)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--no-check", action="store_true")
+    args = parser.parse_args()
+
+    mlx_dtype = MLX_DTYPES[args.dtype]
+
+    print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}")
+
+    headers = [
+        "Case (MxNxKxBS)",
+        "Sparsity",
+        "MLX ms",
+        "Naive ms",
+        "Speedup",
+    ]
+    if not args.no_check:
+        headers.append("Max err")
+    rows = []
+
+    cases = parse_cases(args.cases)
+    for idx, (m, n, k, bs, sparsity) in enumerate(cases):
+        rng = np.random.default_rng(args.seed + idx)
+        a_np = rng.standard_normal((m, k)).astype(np.float32)
+        b_np = rng.standard_normal((k, n)).astype(np.float32)
+        lhs_mask_np, rhs_mask_np, out_mask_np = make_masks(m, n, k, bs, sparsity, rng)
+
+        a_mx = mx.array(a_np, dtype=mlx_dtype)
+        b_mx = mx.array(b_np, dtype=mlx_dtype)
+        lhs_mask_mx = mx.array(lhs_mask_np)
+        rhs_mask_mx = mx.array(rhs_mask_np)
+        out_mask_mx = mx.array(out_mask_np)
+        mx.eval(a_mx, b_mx, lhs_mask_mx, rhs_mask_mx, out_mask_mx)
+
+        # Correctness check: block_masked_mm vs naive expand+matmul
+        err_str = ""
+        if not args.no_check:
+            y_op = mx.block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            )
+            y_naive = mlx_naive_block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            )
+            mx.eval(y_op, y_naive)
+            err = float(mx.max(mx.abs(y_op - y_naive)).item())
+            err_str = f"{err:.2e}"
+
+        # Benchmark
+        t_mlx = bench_mlx(
+            lambda: mx.block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            ),
+            args.warmup,
+            args.iters,
+        )
+        t_naive = bench_mlx(
+            lambda: mlx_naive_block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            ),
+            args.warmup,
+            args.iters,
+        )
+        speedup = f"{t_naive / t_mlx:.2f}x" if t_mlx > 0 else "-"
+
+        row = [
+            f"{m}x{n}x{k}x{bs}",
+            f"{sparsity:.0%}",
+            f"{t_mlx:.3f}",
+            f"{t_naive:.3f}",
+            speedup,
+        ]
+        if not args.no_check:
+            row.append(err_str)
+        rows.append(row)
+
+    print_table(headers, rows)
+    if not args.no_check:
+        print("err: max|block_masked_mm - naive_expand_matmul|")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index fe042da898..2fb66b8d84 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,7 +1,5 @@
 cuda_skip = {
     "TestLayers.test_quantized_embedding",
-    # Block masked matmul NYI
-    "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",

From 75ffb6b5349c9f8c0384cbe3a85a598bbb3ca510 Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Mon, 23 Mar 2026 20:17:43 +0800
Subject: [PATCH 3/9] perf: fuse copy and block mask into single-pass kernel

Replace the two-pass approach (contiguous_copy_gpu + apply_block_mask)
with a single copy_with_block_mask kernel that reads source data and
applies the mask in one pass.
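
For reference, the fused pass is semantically equivalent to the
following NumPy sketch (illustration only, with hypothetical names and
a single 2-D matrix; this is not the CUDA kernel's actual signature):

    import numpy as np

    def copy_with_block_mask_ref(src, mask, block_size):
        # One pass over src: keep elements whose covering block is
        # selected in the block mask, write zero elsewhere.
        rows, cols = src.shape
        dst = np.zeros_like(src)
        for i in range(rows):
            for j in range(cols):
                if mask[i // block_size, j // block_size]:
                    dst[i, j] = src[i, j]
        return dst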
---
 mlx/backend/cuda/gemms/block_mask.cu | 207 ++++++++++++++++++++++-----
 mlx/backend/cuda/gemms/block_mask.h  |   9 ++
 mlx/backend/cuda/matmul.cpp          |  20 +--
 3 files changed, 189 insertions(+), 47 deletions(-)

diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index e382a45e7e..b97011ec35 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -6,10 +6,12 @@
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 
-namespace mlx::core::cu {
+namespace mlx::core {
+
+namespace cu {
 
 template <typename T, typename MaskT>
-__global__ void block_mask_matrix(
+__global__ void block_mask_inplace(
     T* data,
     const MaskT* mask,
     int block_size,
@@ -31,20 +33,15 @@ __global__ void block_mask_inplace(
   int mat_size = rows * cols;
   int batch = idx / mat_size;
   int within = idx % mat_size;
-  int row = within / cols;
-  int col = within % cols;
-  int mask_row = row / block_size;
-  int mask_col = col / block_size;
 
-  // Compute mask batch offset (handles broadcasting via stride=0).
   int64_t mask_batch_offset = elem_to_loc(
       int64_t(batch) * mask_mat_size,
       mask_shape.data(),
       mask_strides.data(),
       mask_ndim);
   MaskT mask_val = mask
-      [mask_batch_offset + mask_row * mask_row_stride +
-       mask_col * mask_col_stride];
+      [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
+       (within % cols) / block_size * mask_col_stride];
 
   int64_t data_offset = int64_t(batch) * data_batch_stride + within;
   if constexpr (std::is_same_v<MaskT, bool>) {
@@ -56,9 +53,67 @@ __global__ void block_mask_inplace(
   }
 }
 
-} // namespace mlx::core::cu
+template <typename T, typename MaskT, bool SrcContiguous>
+__global__ void block_mask_copy(
+    const T* src,
+    T* dst,
+    int block_size,
+    int rows,
+    int cols,
+    const __grid_constant__ Shape src_shape,
+    const __grid_constant__ Strides src_strides,
+    int src_ndim,
+    const MaskT* mask,
+    const __grid_constant__ Shape mask_shape,
+    const __grid_constant__ Strides mask_strides,
+    int mask_ndim,
+    int64_t mask_row_stride,
+    int64_t mask_col_stride,
+    int mask_mat_size,
+    int batch_count) {
+  int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
+  int mat_size = rows * cols;
+  int64_t total = int64_t(batch_count) * mat_size;
+  if (idx >= total)
+    return;
 
-namespace mlx::core {
+  int batch = idx / mat_size;
+  int within = idx % mat_size;
+
+  int64_t mask_batch_offset = elem_to_loc(
+      int64_t(batch) * mask_mat_size,
+      mask_shape.data(),
+      mask_strides.data(),
+      mask_ndim);
+  MaskT mask_val = mask
+      [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
+       (within % cols) / block_size * mask_col_stride];
+
+  int64_t src_offset;
+  if constexpr (SrcContiguous) {
+    src_offset = idx;
+  } else {
+    src_offset = elem_to_loc(
+        int64_t(batch) * mat_size + within,
+        src_shape.data(),
+        src_strides.data(),
+        src_ndim);
+  }
+
+  if constexpr (std::is_same_v<MaskT, bool>) {
+    dst[idx] = mask_val ? src[src_offset] : T(0);
+  } else {
+    dst[idx] = src[src_offset] * T(mask_val);
+  }
+}
+
+} // namespace cu
+
+namespace {
+
+constexpr int BLOCK_DIM = 256;
+
+} // namespace
 
 void apply_block_mask(
     cu::CommandEncoder& encoder,
@@ -72,54 +127,138 @@ void apply_block_mask(
   encoder.set_input_array(mask);
   encoder.set_output_array(data);
 
+  int64_t total = int64_t(batch_count) * rows * cols;
+  int grid = (total + BLOCK_DIM - 1) / BLOCK_DIM;
   int mask_ndim = mask.ndim();
-  int64_t mask_row_stride = mask.strides()[mask_ndim - 2];
-  int64_t mask_col_stride = mask.strides()[mask_ndim - 1];
+  int64_t mask_row_str = mask.strides()[mask_ndim - 2];
+  int64_t mask_col_str = mask.strides()[mask_ndim - 1];
   int mask_mat_size = mask.shape()[mask_ndim - 2] * mask.shape()[mask_ndim - 1];
+  auto mask_shape = const_param(mask.shape());
+  auto mask_strides = const_param(mask.strides());
+  auto& mask_nc = const_cast<array&>(mask);
 
-  int64_t total = int64_t(batch_count) * rows * cols;
-  constexpr int BLOCK = 256;
-  int grid = (total + BLOCK - 1) / BLOCK;
-
-  auto launch = [&](auto type_tag) {
+  dispatch_float_types(data.dtype(), "apply_block_mask", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    auto data_ptr = gpu_ptr<T>(data);
 
-    auto do_mask = [&](auto mask_tag) {
+    auto launch = [&](auto mask_tag) {
       using MaskT = decltype(mask_tag);
-      const MaskT* mask_ptr;
+      MaskT* mask_ptr;
       if constexpr (std::is_same_v<MaskT, bool>) {
-        mask_ptr = gpu_ptr<bool>(mask);
+        mask_ptr = gpu_ptr<bool>(mask_nc);
       } else {
-        mask_ptr = gpu_ptr<T>(mask);
+        mask_ptr = gpu_ptr<T>(mask_nc);
       }
       encoder.add_kernel_node(
-          cu::block_mask_matrix<T, MaskT>,
+          cu::block_mask_inplace<T, MaskT>,
           grid,
-          BLOCK,
-          data_ptr,
+          BLOCK_DIM,
+          gpu_ptr<T>(data),
           mask_ptr,
           block_size,
           rows,
           cols,
           data_batch_stride,
-          const_param(mask.shape()),
-          const_param(mask.strides()),
+          mask_shape,
+          mask_strides,
           mask_ndim,
-          mask_row_stride,
-          mask_col_stride,
+          mask_row_str,
+          mask_col_str,
           mask_mat_size,
           batch_count);
     };
 
     if (mask.dtype() == bool_) {
-      do_mask(bool{});
+      launch(bool{});
     } else {
-      do_mask(T{});
+      launch(T{});
     }
-  };
+  });
+}
+
+array copy_with_block_mask(
+    cu::CommandEncoder& encoder,
+    const array& src,
+    const array& mask,
+    int block_size,
+    int rows,
+    int cols,
+    int batch_count) {
+  array dst(src.shape(), src.dtype(), nullptr, {});
+  dst.set_data(cu::malloc_async(dst.nbytes(), encoder));
+  encoder.add_temporary(dst);
+
+  encoder.set_input_array(src);
+  encoder.set_input_array(mask);
+  encoder.set_output_array(dst);
+
+  int64_t total = int64_t(batch_count) * rows * cols;
+  int grid = (total + BLOCK_DIM - 1) / BLOCK_DIM;
+  int mask_ndim = mask.ndim();
+  int64_t mask_row_str = mask.strides()[mask_ndim - 2];
+  int64_t mask_col_str = mask.strides()[mask_ndim - 1];
+  int mask_mat_size = mask.shape()[mask_ndim - 2] * mask.shape()[mask_ndim - 1];
+  auto src_shape = const_param(src.shape());
+  auto src_strides = const_param(src.strides());
+  int src_ndim = src.ndim();
+  auto mask_shape = const_param(mask.shape());
+  auto mask_strides_p = const_param(mask.strides());
+  bool src_contiguous = src.flags().row_contiguous;
+
+  auto& src_nc = const_cast<array&>(src);
+  auto& mask_nc = const_cast<array&>(mask);
+
+  dispatch_float_types(src.dtype(), "copy_with_block_mask", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    auto src_ptr = gpu_ptr<T>(src_nc);
+    auto dst_ptr = gpu_ptr<T>(dst);
+
+    auto launch = [&](auto mask_tag, auto contiguous_tag) {
+      using MaskT = decltype(mask_tag);
+      constexpr bool Contiguous = decltype(contiguous_tag)::value;
+      MaskT* mask_ptr;
+      if constexpr (std::is_same_v<MaskT, bool>) {
+        mask_ptr = gpu_ptr<bool>(mask_nc);
+      } else {
+        mask_ptr = gpu_ptr<T>(mask_nc);
+      }
+      encoder.add_kernel_node(
+          cu::block_mask_copy<T, MaskT, Contiguous>,
+          grid,
+          BLOCK_DIM,
+          src_ptr,
+          dst_ptr,
+          block_size,
+          rows,
+          cols,
+          src_shape,
+          src_strides,
+          src_ndim,
+          mask_ptr,
+          mask_shape,
+          mask_strides_p,
+          mask_ndim,
+          mask_row_str,
+          mask_col_str,
+          mask_mat_size,
+          batch_count);
+    };
+
+    auto dispatch_contiguous = [&](auto mask_tag) {
+      if (src_contiguous) {
+        launch(mask_tag, std::true_type{});
+      } else {
+        launch(mask_tag, std::false_type{});
+      }
+    };
+
+    if (mask.dtype() == bool_) {
+      dispatch_contiguous(bool{});
+    } else {
+      dispatch_contiguous(T{});
+    }
+  });
 
-  dispatch_float_types(data.dtype(), "block_mask_matrix", launch);
+  return dst;
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/block_mask.h b/mlx/backend/cuda/gemms/block_mask.h
index ef2b30ab83..01e3803d56 100644
--- a/mlx/backend/cuda/gemms/block_mask.h
+++ b/mlx/backend/cuda/gemms/block_mask.h
@@ -17,4 +17,13 @@ void apply_block_mask(
     int64_t data_batch_stride,
     int batch_count);
 
+array copy_with_block_mask(
+    cu::CommandEncoder& encoder,
+    const array& src,
+    const array& mask,
+    int block_size,
+    int rows,
+    int cols,
+    int batch_count);
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 41a33de51a..c1118cf931 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -253,23 +253,17 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   array b = b_pre;
 
   if (has_op_mask) {
-    // Make contiguous copies of A and B so we can mask them in-place.
-    a = contiguous_copy_gpu(a_pre, s);
-    encoder.add_temporary(a);
-    b = contiguous_copy_gpu(b_pre, s);
-    encoder.add_temporary(b);
+    // Fused copy + mask in a single pass per matrix.
+    auto& lhs_mask = inputs[inputs.size() - 2];
+    auto& rhs_mask = inputs[inputs.size() - 1];
+    a = copy_with_block_mask(
+        encoder, a_pre, lhs_mask, block_size_, M, K, batch_count);
+    b = copy_with_block_mask(
+        encoder, b_pre, rhs_mask, block_size_, K, N, batch_count);
     a_transposed = false;
     lda = K;
     b_transposed = false;
     ldb = N;
-
-    // Apply operand masks.
-    auto& lhs_mask = inputs[inputs.size() - 2];
-    auto& rhs_mask = inputs[inputs.size() - 1];
-    apply_block_mask(
-        encoder, a, lhs_mask, block_size_, M, K, int64_t(M) * K, batch_count);
-    apply_block_mask(
-        encoder, b, rhs_mask, block_size_, K, N, int64_t(K) * N, batch_count);
   } else {
     std::tie(a_transposed, lda, a) = check_transpose(encoder, s, a_pre);
     std::tie(b_transposed, ldb, b) = check_transpose(encoder, s, b_pre);

From b258254240d695055def0156016523e9448d6be9 Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Mon, 23 Mar 2026 22:08:35 +0800
Subject: [PATCH 4/9] fix: use int64 for block mask index arithmetic

---
 mlx/backend/cuda/gemms/block_mask.cu | 75 ++++++++++++----------------
 mlx/backend/cuda/gemms/block_mask.h  | 12 ++---
 mlx/backend/cuda/matmul.cpp          |  2 +-
 3 files changed, 40 insertions(+), 49 deletions(-)

diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index b97011ec35..b64b8a0d33 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -15,41 +15,35 @@ __global__ void block_mask_inplace(
     T* data,
     const MaskT* mask,
     int block_size,
-    int rows,
-    int cols,
+    int64_t rows,
+    int64_t cols,
     int64_t data_batch_stride,
     const __grid_constant__ Shape mask_shape,
     const __grid_constant__ Strides mask_strides,
     int mask_ndim,
     int64_t mask_row_stride,
     int64_t mask_col_stride,
-    int mask_mat_size,
-    int batch_count) {
+    int64_t mask_mat_size,
+    int64_t batch_count) {
+  int64_t mat_size = rows * cols;
   int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
-  int64_t total = int64_t(batch_count) * rows * cols;
-  if (idx >= total)
+  if (idx >= batch_count * mat_size)
     return;
 
-  int mat_size = rows * cols;
-  int batch = idx / mat_size;
-  int within = idx % mat_size;
+  int64_t batch = idx / mat_size;
+  int64_t within = idx % mat_size;
   int64_t mask_batch_offset = elem_to_loc(
-      int64_t(batch) * mask_mat_size,
-      mask_shape.data(),
-      mask_strides.data(),
-      mask_ndim);
+      batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim);
   MaskT mask_val = mask
       [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
        (within % cols) / block_size * mask_col_stride];
 
-  int64_t data_offset = int64_t(batch) * data_batch_stride + within;
   if constexpr (std::is_same_v<MaskT, bool>) {
     if (!mask_val) {
-      data[data_offset] = T(0);
+      data[batch * data_batch_stride + within] = T(0);
     }
   } else {
-    data[data_offset] *= T(mask_val);
+    data[batch * data_batch_stride + within] *= T(mask_val);
   }
 }
 
@@ -58,8 +52,8 @@ __global__ void block_mask_copy(
     const T* src,
     T* dst,
     int block_size,
-    int rows,
-    int cols,
+    int64_t rows,
+    int64_t cols,
     const __grid_constant__ Shape src_shape,
     const __grid_constant__ Strides src_strides,
     int src_ndim,
@@ -69,22 +63,17 @@ __global__ void block_mask_copy(
     int mask_ndim,
     int64_t mask_row_stride,
     int64_t mask_col_stride,
-    int mask_mat_size,
-    int batch_count) {
+    int64_t mask_mat_size,
+    int64_t batch_count) {
+  int64_t mat_size = rows * cols;
   int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
-  int mat_size = rows * cols;
-  int64_t total = int64_t(batch_count) * mat_size;
-  if (idx >= total)
+  if (idx >= batch_count * mat_size)
     return;
 
-  int batch = idx / mat_size;
-  int within = idx % mat_size;
-
+  int64_t batch = idx / mat_size;
+  int64_t within = idx % mat_size;
   int64_t mask_batch_offset = elem_to_loc(
-      int64_t(batch) * mask_mat_size,
-      mask_shape.data(),
-      mask_strides.data(),
-      mask_ndim);
+      batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim);
   MaskT mask_val = mask
       [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
        (within % cols) / block_size * mask_col_stride];
@@ -94,7 +83,7 @@ __global__ void block_mask_copy(
     src_offset = idx;
   } else {
     src_offset = elem_to_loc(
-        int64_t(batch) * mat_size + within,
+        batch * mat_size + within,
         src_shape.data(),
         src_strides.data(),
         src_ndim);
@@ -120,19 +109,20 @@ void apply_block_mask(
     array& data,
     const array& mask,
     int block_size,
-    int rows,
-    int cols,
+    int64_t rows,
+    int64_t cols,
     int64_t data_batch_stride,
-    int batch_count) {
+    int64_t batch_count) {
   encoder.set_input_array(mask);
   encoder.set_output_array(data);
 
-  int64_t total = int64_t(batch_count) * rows * cols;
+  int64_t total = batch_count * rows * cols;
   int grid = (total + BLOCK_DIM - 1) / BLOCK_DIM;
   int mask_ndim = mask.ndim();
   int64_t mask_row_str = mask.strides()[mask_ndim - 2];
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
-  int mask_mat_size = mask.shape()[mask_ndim - 2] * mask.shape()[mask_ndim - 1];
+  int64_t mask_mat_size =
+      int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
   auto mask_shape = const_param(mask.shape());
   auto mask_strides = const_param(mask.strides());
   auto& mask_nc = const_cast<array&>(mask);
@@ -180,9 +170,9 @@ array copy_with_block_mask(
     const array& src,
     const array& mask,
     int block_size,
-    int rows,
-    int cols,
-    int batch_count) {
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count) {
   array dst(src.shape(), src.dtype(), nullptr, {});
   dst.set_data(cu::malloc_async(dst.nbytes(), encoder));
   encoder.add_temporary(dst);
@@ -191,12 +181,13 @@ array copy_with_block_mask(
   encoder.set_input_array(mask);
   encoder.set_output_array(dst);
 
-  int64_t total = int64_t(batch_count) * rows * cols;
+  int64_t total = batch_count * rows * cols;
   int grid = (total + BLOCK_DIM - 1) / BLOCK_DIM;
   int mask_ndim = mask.ndim();
   int64_t mask_row_str = mask.strides()[mask_ndim - 2];
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
-  int mask_mat_size = mask.shape()[mask_ndim - 2] * mask.shape()[mask_ndim - 1];
+  int64_t mask_mat_size =
+      int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
   auto src_shape = const_param(src.shape());
   auto src_strides = const_param(src.strides());
   int src_ndim = src.ndim();
diff --git a/mlx/backend/cuda/gemms/block_mask.h b/mlx/backend/cuda/gemms/block_mask.h
index 01e3803d56..1ab1736c22 100644
--- a/mlx/backend/cuda/gemms/block_mask.h
+++ b/mlx/backend/cuda/gemms/block_mask.h
@@ -12,18 +12,18 @@ void apply_block_mask(
     array& data,
     const array& mask,
     int block_size,
-    int rows,
-    int cols,
+    int64_t rows,
+    int64_t cols,
     int64_t data_batch_stride,
-    int batch_count);
+    int64_t batch_count);
 
 array copy_with_block_mask(
     cu::CommandEncoder& encoder,
     const array& src,
     const array& mask,
     int block_size,
-    int rows,
-    int cols,
-    int batch_count);
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count);
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index c1118cf931..61b84ed7c1 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -243,7 +243,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   bool has_op_mask = inputs.size() > 3;
   bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
 
-  int batch_count = out.size() / (M * N);
+  int64_t batch_count = out.size() / (int64_t(M) * N);
 
   bool a_transposed;
   int64_t lda;

From d41c5aa68105e65a624765c7a0d21797413b83ed Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Wed, 25 Mar 2026 20:26:53 +0800
Subject: [PATCH 5/9] apply pr comments
---
 mlx/backend/cuda/gemms/block_mask.cu | 125 ++++++++++++---------------
 mlx/backend/cuda/matmul.cpp          |  25 ++++--
 2 files changed, 74 insertions(+), 76 deletions(-)

diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index b64b8a0d33..8fb63642c8 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -6,14 +6,18 @@
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 
+#include <cooperative_groups.h>
+
 namespace mlx::core {
 
+namespace cg = cooperative_groups;
+
 namespace cu {
 
 template <typename T, typename MaskT>
 __global__ void block_mask_inplace(
     T* data,
-    const MaskT* mask,
+    MaskT* mask,
     int block_size,
     int64_t rows,
     int64_t cols,
@@ -26,7 +30,7 @@ __global__ void block_mask_inplace(
     int64_t mask_mat_size,
     int64_t batch_count) {
   int64_t mat_size = rows * cols;
-  int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
+  int64_t idx = cg::this_grid().thread_rank();
   if (idx >= batch_count * mat_size)
     return;
 
@@ -49,7 +53,7 @@ __global__ void block_mask_inplace(
 
 template <typename T, typename MaskT, bool SrcContiguous>
 __global__ void block_mask_copy(
-    const T* src,
+    T* src,
     T* dst,
     int block_size,
     int64_t rows,
@@ -57,7 +61,7 @@ __global__ void block_mask_copy(
     const __grid_constant__ Shape src_shape,
     const __grid_constant__ Strides src_strides,
     int src_ndim,
-    const MaskT* mask,
+    MaskT* mask,
     const __grid_constant__ Shape mask_shape,
     const __grid_constant__ Strides mask_strides,
     int mask_ndim,
@@ -66,7 +70,7 @@ __global__ void block_mask_copy(
     int64_t mask_mat_size,
     int64_t batch_count) {
   int64_t mat_size = rows * cols;
-  int64_t idx = int64_t(blockIdx.x) * blockDim.x + threadIdx.x;
+  int64_t idx = cg::this_grid().thread_rank();
   if (idx >= batch_count * mat_size)
     return;
 
@@ -102,6 +106,15 @@ namespace {
 
 constexpr int BLOCK_DIM = 256;
 
+template <typename T, typename F>
+void dispatch_mask_type(Dtype mask_dtype, F&& f) {
+  if (mask_dtype == bool_) {
+    f.template operator()<bool>();
+  } else {
+    f.template operator()<T>();
+  }
+}
+
 } // namespace
 
 void apply_block_mask(
@@ -116,8 +129,8 @@ void apply_block_mask(
   encoder.set_input_array(mask);
   encoder.set_output_array(data);
 
-  int64_t total = batch_count * rows * cols;
-  int grid = (total + BLOCK_DIM - 1) / BLOCK_DIM;
+  auto [num_blocks, block_dims] = get_launch_args(
+      data, data.size() > INT32_MAX, /*work_per_thread=*/1, BLOCK_DIM);
   int mask_ndim = mask.ndim();
   int64_t mask_row_str = mask.strides()[mask_ndim - 2];
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
@@ -130,20 +143,13 @@ void apply_block_mask(
   dispatch_float_types(data.dtype(), "apply_block_mask", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
-    auto launch = [&](auto mask_tag) {
-      using MaskT = decltype(mask_tag);
-      MaskT* mask_ptr;
-      if constexpr (std::is_same_v<MaskT, bool>) {
-        mask_ptr = gpu_ptr<bool>(mask_nc);
-      } else {
-        mask_ptr = gpu_ptr<T>(mask_nc);
-      }
+    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
       encoder.add_kernel_node(
           cu::block_mask_inplace<T, MaskT>,
-          grid,
-          BLOCK_DIM,
+          num_blocks,
+          block_dims,
           gpu_ptr<T>(data),
-          mask_ptr,
+          gpu_ptr<MaskT>(mask_nc),
           block_size,
           rows,
           cols,
@@ -155,13 +161,7 @@ void apply_block_mask(
           mask_col_str,
           mask_mat_size,
           batch_count);
-    };
-
-    if (mask.dtype() == bool_) {
-      launch(bool{});
-    } else {
-      launch(T{});
-    }
+    });
   });
 }
 
@@ -181,8 +181,8 @@ array copy_with_block_mask(
   encoder.set_input_array(mask);
   encoder.set_output_array(dst);
 
-  int64_t total = batch_count * rows * cols;
-  int grid = (total + BLOCK_DIM - 1) / BLOCK_DIM;
+  auto [num_blocks, block_dims] =
+      get_launch_args(dst, dst.size() > INT32_MAX, /*work_per_thread=*/1, 256);
   int mask_ndim = mask.ndim();
   int64_t mask_row_str = mask.strides()[mask_ndim - 2];
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
@@ -203,50 +203,33 @@ array copy_with_block_mask(
     auto src_ptr = gpu_ptr<T>(src_nc);
     auto dst_ptr = gpu_ptr<T>(dst);
 
-    auto launch = [&](auto mask_tag, auto contiguous_tag) {
-      using MaskT = decltype(mask_tag);
-      constexpr bool Contiguous = decltype(contiguous_tag)::value;
-      MaskT* mask_ptr;
-      if constexpr (std::is_same_v<MaskT, bool>) {
-        mask_ptr = gpu_ptr<bool>(mask_nc);
-      } else {
-        mask_ptr = gpu_ptr<T>(mask_nc);
-      }
-      encoder.add_kernel_node(
-          cu::block_mask_copy<T, MaskT, Contiguous>,
-          grid,
-          BLOCK_DIM,
-          src_ptr,
-          dst_ptr,
-          block_size,
-          rows,
-          cols,
-          src_shape,
-          src_strides,
-          src_ndim,
-          mask_ptr,
-          mask_shape,
-          mask_strides_p,
-          mask_ndim,
-          mask_row_str,
-          mask_col_str,
-          mask_mat_size,
-          batch_count);
-    };
-
-    auto dispatch_contiguous = [&](auto mask_tag) {
-      if (src_contiguous) {
-        launch(mask_tag, std::true_type{});
-      } else {
-        launch(mask_tag, std::false_type{});
-      }
-    };
-
-    if (mask.dtype() == bool_) {
-      dispatch_contiguous(bool{});
-    } else {
-      dispatch_contiguous(T{});
-    }
+    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
+      auto mask_ptr = gpu_ptr<MaskT>(mask_nc);
+
+      dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
+        constexpr bool Contiguous = decltype(contiguous_tag)::value;
+        encoder.add_kernel_node(
+            cu::block_mask_copy<T, MaskT, Contiguous>,
+            num_blocks,
+            block_dims,
+            src_ptr,
+            dst_ptr,
+            block_size,
+            rows,
+            cols,
+            src_shape,
+            src_strides,
+            src_ndim,
+            mask_ptr,
+            mask_shape,
+            mask_strides_p,
+            mask_ndim,
+            mask_row_str,
+            mask_col_str,
+            mask_mat_size,
+            batch_count);
+      });
+    });
   });
 
   return dst;
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 61b84ed7c1..71c1cccd25 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/common/matmul.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/block_mask.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
@@ -253,13 +254,27 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   array b = b_pre;
 
   if (has_op_mask) {
-    // Fused copy + mask in a single pass per matrix.
     auto& lhs_mask = inputs[inputs.size() - 2];
     auto& rhs_mask = inputs[inputs.size() - 1];
-    a = copy_with_block_mask(
-        encoder, a_pre, lhs_mask, block_size_, M, K, batch_count);
-    b = copy_with_block_mask(
-        encoder, b_pre, rhs_mask, block_size_, K, N, batch_count);
+
+    // When the input is donatable and row-contiguous, mask in-place to avoid
+    // a copy. Otherwise, fuse the copy and mask into a single pass.
+    auto mask_input = [&](const array& src,
+                          const array& mask,
+                          int64_t r,
+                          int64_t c) -> array {
+      if (is_donatable(src, out) && src.flags().row_contiguous) {
+        array donated = src;
+        apply_block_mask(
+            encoder, donated, mask, block_size_, r, c, r * c, batch_count);
+        return donated;
+      }
+      return copy_with_block_mask(
+          encoder, src, mask, block_size_, r, c, batch_count);
+    };
+
+    a = mask_input(a_pre, lhs_mask, M, K);
+    b = mask_input(b_pre, rhs_mask, K, N);
     a_transposed = false;
     lda = K;
     b_transposed = false;

From 414d5efd771972b27a07729d4a7f54d7ad121f39 Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Wed, 25 Mar 2026 20:49:04 +0800
Subject: [PATCH 6/9] reuse block_mask_copy with src==dst

---
 mlx/backend/cuda/gemms/block_mask.cu | 176 +++++++++++----------------
 mlx/backend/cuda/gemms/block_mask.h  |   1 -
 mlx/backend/cuda/matmul.cpp          |  28 +----
 3 files changed, 80 insertions(+), 125 deletions(-)

diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index 8fb63642c8..96800b5f17 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -14,43 +14,6 @@ namespace cg = cooperative_groups;
 
 namespace cu {
 
-template <typename T, typename MaskT>
-__global__ void block_mask_inplace(
-    T* data,
-    MaskT* mask,
-    int block_size,
-    int64_t rows,
-    int64_t cols,
-    int64_t data_batch_stride,
-    const __grid_constant__ Shape mask_shape,
-    const __grid_constant__ Strides mask_strides,
-    int mask_ndim,
-    int64_t mask_row_stride,
-    int64_t mask_col_stride,
-    int64_t mask_mat_size,
-    int64_t batch_count) {
-  int64_t mat_size = rows * cols;
-  int64_t idx = cg::this_grid().thread_rank();
-  if (idx >= batch_count * mat_size)
-    return;
-
-  int64_t batch = idx / mat_size;
-  int64_t within = idx % mat_size;
-  int64_t mask_batch_offset = elem_to_loc(
-      batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim);
-  MaskT mask_val = mask
-      [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
-       (within % cols) / block_size * mask_col_stride];
-
-  if constexpr (std::is_same_v<MaskT, bool>) {
-    if (!mask_val) {
-      data[batch * data_batch_stride + within] = T(0);
-    }
-  } else {
-    data[batch * data_batch_stride + within] *= T(mask_val);
-  }
-}
-
 template <typename T, typename MaskT, bool SrcContiguous>
 __global__ void block_mask_copy(
     T* src,
     T* dst,
@@ -115,47 +78,55 @@ void dispatch_mask_type(Dtype mask_dtype, F&& f) {
   }
 }
 
-} // namespace
-
-void apply_block_mask(
+template <typename T>
+void block_mask_kernel(
     cu::CommandEncoder& encoder,
-    array& data,
+    T* src_ptr,
+    T* dst_ptr,
     const array& mask,
     int block_size,
    int64_t rows,
     int64_t cols,
-    int64_t data_batch_stride,
+    const Shape& src_shape_v,
+    const Strides& src_strides_v,
+    int src_ndim,
+    bool src_contiguous,
     int64_t batch_count) {
-  encoder.set_input_array(mask);
-  encoder.set_output_array(data);
-
-  auto [num_blocks, block_dims] = get_launch_args(
-      data, data.size() > INT32_MAX, /*work_per_thread=*/1, BLOCK_DIM);
+  auto& mask_nc = const_cast<array&>(mask);
   int mask_ndim = mask.ndim();
   int64_t mask_row_str = mask.strides()[mask_ndim - 2];
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
   int64_t mask_mat_size =
       int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
+  auto src_shape = const_param(src_shape_v);
+  auto src_strides = const_param(src_strides_v);
   auto mask_shape = const_param(mask.shape());
-  auto mask_strides = const_param(mask.strides());
-  auto& mask_nc = const_cast<array&>(mask);
+  auto mask_strides_p = const_param(mask.strides());
 
-  dispatch_float_types(data.dtype(), "apply_block_mask", [&](auto type_tag) {
-    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+  int64_t total = batch_count * rows * cols;
+  auto [num_blocks, block_dims] = get_launch_args(
+      total, src_shape_v, src_strides_v, total > INT32_MAX, 1, BLOCK_DIM);
 
-    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
+  dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
+    auto mask_ptr = gpu_ptr<MaskT>(mask_nc);
+
+    dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
+      constexpr bool Contiguous = decltype(contiguous_tag)::value;
       encoder.add_kernel_node(
-          cu::block_mask_inplace<T, MaskT>,
+          cu::block_mask_copy<T, MaskT, Contiguous>,
           num_blocks,
           block_dims,
-          gpu_ptr<T>(data),
-          gpu_ptr<MaskT>(mask_nc),
+          src_ptr,
+          dst_ptr,
          block_size,
           rows,
           cols,
-          data_batch_stride,
+          src_shape,
+          src_strides,
+          src_ndim,
+          mask_ptr,
           mask_shape,
-          mask_strides,
+          mask_strides_p,
           mask_ndim,
           mask_row_str,
           mask_col_str,
           mask_mat_size,
           batch_count);
     });
-  });
+  });
+}
+
+} // namespace
+
+void apply_block_mask(
+    cu::CommandEncoder& encoder,
+    array& data,
+    const array& mask,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count) {
+  encoder.set_input_array(mask);
+  encoder.set_output_array(data);
+
+  // Use block_mask_copy in-place (src == dst) with SrcContiguous=true.
+  dispatch_float_types(data.dtype(), "apply_block_mask", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    auto data_ptr = gpu_ptr<T>(data);
+    block_mask_kernel(
+        encoder,
+        data_ptr,
+        data_ptr,
+        mask,
+        block_size,
+        rows,
+        cols,
+        data.shape(),
+        data.strides(),
+        data.ndim(),
+        /*src_contiguous=*/true,
+        batch_count);
+  });
 }
 
 array copy_with_block_mask(
@@ -181,55 +185,23 @@ array copy_with_block_mask(
   encoder.set_input_array(mask);
   encoder.set_output_array(dst);
 
-  auto [num_blocks, block_dims] =
-      get_launch_args(dst, dst.size() > INT32_MAX, /*work_per_thread=*/1, 256);
-  int mask_ndim = mask.ndim();
-  int64_t mask_row_str = mask.strides()[mask_ndim - 2];
-  int64_t mask_col_str = mask.strides()[mask_ndim - 1];
-  int64_t mask_mat_size =
-      int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
-  auto src_shape = const_param(src.shape());
-  auto src_strides = const_param(src.strides());
-  int src_ndim = src.ndim();
-  auto mask_shape = const_param(mask.shape());
-  auto mask_strides_p = const_param(mask.strides());
-  bool src_contiguous = src.flags().row_contiguous;
-
   auto& src_nc = const_cast<array&>(src);
-  auto& mask_nc = const_cast<array&>(mask);
 
   dispatch_float_types(src.dtype(), "copy_with_block_mask", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    auto src_ptr = gpu_ptr<T>(src_nc);
-    auto dst_ptr = gpu_ptr<T>(dst);
-
-    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
-      auto mask_ptr = gpu_ptr<MaskT>(mask_nc);
-
-      dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
-        constexpr bool Contiguous = decltype(contiguous_tag)::value;
-        encoder.add_kernel_node(
-            cu::block_mask_copy<T, MaskT, Contiguous>,
-            num_blocks,
-            block_dims,
-            src_ptr,
-            dst_ptr,
-            block_size,
-            rows,
-            cols,
-            src_shape,
-            src_strides,
-            src_ndim,
-            mask_ptr,
-            mask_shape,
-            mask_strides_p,
-            mask_ndim,
-            mask_row_str,
-            mask_col_str,
-            mask_mat_size,
-            batch_count);
-      });
-    });
+    block_mask_kernel(
+        encoder,
+        gpu_ptr<T>(src_nc),
+        gpu_ptr<T>(dst),
+        mask,
+        block_size,
+        rows,
+        cols,
+        src.shape(),
+        src.strides(),
+        src.ndim(),
+        src.flags().row_contiguous,
+        batch_count);
   });
 
   return dst;
diff --git a/mlx/backend/cuda/gemms/block_mask.h b/mlx/backend/cuda/gemms/block_mask.h
index 1ab1736c22..54972336de 100644
--- a/mlx/backend/cuda/gemms/block_mask.h
+++ b/mlx/backend/cuda/gemms/block_mask.h
@@ -14,7 +14,6 @@ void apply_block_mask(
     int block_size,
     int64_t rows,
     int64_t cols,
-    int64_t data_batch_stride,
     int64_t batch_count);
 
 array copy_with_block_mask(
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 71c1cccd25..353df4721a 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/common/matmul.h"
-#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/block_mask.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
@@ -254,27 +253,13 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   array b = b_pre;
 
   if (has_op_mask) {
+    // Fused copy + mask in a single pass per matrix.
     auto& lhs_mask = inputs[inputs.size() - 2];
     auto& rhs_mask = inputs[inputs.size() - 1];
-
-    // When the input is donatable and row-contiguous, mask in-place to avoid
-    // a copy. Otherwise, fuse the copy and mask into a single pass.
-    auto mask_input = [&](const array& src,
-                          const array& mask,
-                          int64_t r,
-                          int64_t c) -> array {
-      if (is_donatable(src, out) && src.flags().row_contiguous) {
-        array donated = src;
-        apply_block_mask(
-            encoder, donated, mask, block_size_, r, c, r * c, batch_count);
-        return donated;
-      }
-      return copy_with_block_mask(
-          encoder, src, mask, block_size_, r, c, batch_count);
-    };
-
-    a = mask_input(a_pre, lhs_mask, M, K);
-    b = mask_input(b_pre, rhs_mask, K, N);
+    a = copy_with_block_mask(
+        encoder, a_pre, lhs_mask, block_size_, M, K, batch_count);
+    b = copy_with_block_mask(
+        encoder, b_pre, rhs_mask, block_size_, K, N, batch_count);
     a_transposed = false;
     lda = K;
     b_transposed = false;
     ldb = N;
@@ -291,8 +276,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Apply output mask.
   if (has_out_mask) {
     auto& out_mask = inputs[2];
-    apply_block_mask(
-        encoder, out, out_mask, block_size_, M, N, int64_t(M) * N, batch_count);
+    apply_block_mask(encoder, out, out_mask, block_size_, M, N, batch_count);
   }
 }
 

From e5faca33ac414e3532693317bc50ee93b34236a3 Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Thu, 26 Mar 2026 10:13:58 +0800
Subject: [PATCH 7/9] apply pr comments

---
 mlx/backend/cuda/gemms/block_mask.cu | 34 ++++++++++------------------
 mlx/backend/cuda/matmul.cpp          |  4 ++--
 2 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index 96800b5f17..9e7f6906f0 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -15,7 +15,7 @@ namespace cg = cooperative_groups;
 namespace cu {
 
 template <typename T, typename MaskT, bool SrcContiguous>
-__global__ void block_mask_copy(
+__global__ void block_mask_copy_kernel(
     T* src,
     T* dst,
     int block_size,
@@ -67,8 +67,6 @@ __global__ void block_mask_copy_kernel(
 
 namespace {
 
-constexpr int BLOCK_DIM = 256;
-
 template <typename T, typename F>
 void dispatch_mask_type(Dtype mask_dtype, F&& f) {
   if (mask_dtype == bool_) {
@@ -79,17 +77,15 @@ void dispatch_mask_type(Dtype mask_dtype, F&& f) {
 }
 
 template <typename T>
-void block_mask_kernel(
+void block_mask_copy(
     cu::CommandEncoder& encoder,
     T* src_ptr,
     T* dst_ptr,
+    const array& src,
     const array& mask,
     int block_size,
     int64_t rows,
     int64_t cols,
-    const Shape& src_shape_v,
-    const Strides& src_strides_v,
-    int src_ndim,
     bool src_contiguous,
     int64_t batch_count) {
   auto& mask_nc = const_cast<array&>(mask);
@@ -98,14 +94,12 @@ void block_mask_copy(
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
   int64_t mask_mat_size =
       int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
-  auto src_shape = const_param(src_shape_v);
-  auto src_strides = const_param(src_strides_v);
+  auto src_shape = const_param(src.shape());
+  auto src_strides = const_param(src.strides());
   auto mask_shape = const_param(mask.shape());
   auto mask_strides_p = const_param(mask.strides());
 
-  int64_t total = batch_count * rows * cols;
-  auto [num_blocks, block_dims] = get_launch_args(
-      total, src_shape_v, src_strides_v, total > INT32_MAX, 1, BLOCK_DIM);
+  auto [num_blocks, block_dims] = get_launch_args(src, src.size() > INT32_MAX);
 
   dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
     auto mask_ptr = gpu_ptr<MaskT>(mask_nc);
@@ -113,7 +107,7 @@ void block_mask_copy(
     dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
       constexpr bool Contiguous = decltype(contiguous_tag)::value;
       encoder.add_kernel_node(
-          cu::block_mask_copy<T, MaskT, Contiguous>,
+          cu::block_mask_copy_kernel<T, MaskT, Contiguous>,
           num_blocks,
           block_dims,
           src_ptr,
@@ -123,7 +117,7 @@ void block_mask_copy(
           cols,
           src_shape,
           src_strides,
-          src_ndim,
+          src.ndim(),
           mask_ptr,
           mask_shape,
           mask_strides_p,
@@ -153,17 +147,15 @@ void apply_block_mask(
   dispatch_float_types(data.dtype(), "apply_block_mask", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     auto data_ptr = gpu_ptr<T>(data);
-    block_mask_kernel(
+    block_mask_copy(
         encoder,
         data_ptr,
         data_ptr,
+        data,
         mask,
         block_size,
         rows,
         cols,
-        data.shape(),
-        data.strides(),
-        data.ndim(),
         /*src_contiguous=*/true,
         batch_count);
   });
@@ -189,17 +181,15 @@ array copy_with_block_mask(
 
   dispatch_float_types(src.dtype(), "copy_with_block_mask", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    block_mask_kernel(
+    block_mask_copy(
         encoder,
         gpu_ptr<T>(src_nc),
         gpu_ptr<T>(dst),
+        src,
         mask,
         block_size,
         rows,
        cols,
-        src.shape(),
-        src.strides(),
-        src.ndim(),
         src.flags().row_contiguous,
         batch_count);
   });
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 353df4721a..f06fed3909 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -224,8 +224,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(cu::malloc_async(out.nbytes(), encoder));
-
   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
   int K = a_pre.shape(-1);
@@ -240,6 +238,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
+
   bool has_op_mask = inputs.size() > 3;
   bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
 

From 4ea2b23fbc29a3df796072e989bdbb2098e6c8bb Mon Sep 17 00:00:00 2001
From: Lyxot
Date: Thu, 26 Mar 2026 13:24:30 +0800
Subject: [PATCH 8/9] apply pr comments

---
 mlx/backend/cuda/gemms/block_mask.cu | 104 +++++++++++----------------
 1 file changed, 41 insertions(+), 63 deletions(-)

diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index 9e7f6906f0..128f1a7816 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -76,12 +76,10 @@ void dispatch_mask_type(Dtype mask_dtype, F&& f) {
   }
 }
 
-template <typename T>
 void block_mask_copy(
     cu::CommandEncoder& encoder,
-    T* src_ptr,
-    T* dst_ptr,
-    const array& src,
+    array& src,
+    array& dst,
     const array& mask,
     int block_size,
@@ -94,38 +92,36 @@ void block_mask_copy(
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
   int64_t mask_mat_size =
       int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
-  auto src_shape = const_param(src.shape());
-  auto src_strides = const_param(src.strides());
-  auto mask_shape = const_param(mask.shape());
-  auto mask_strides_p = const_param(mask.strides());
 
   auto [num_blocks, block_dims] = get_launch_args(src, src.size() > INT32_MAX);
 
-  dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
-    auto mask_ptr = gpu_ptr<MaskT>(mask_nc);
-
-    dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
-      constexpr bool Contiguous = decltype(contiguous_tag)::value;
-      encoder.add_kernel_node(
-          cu::block_mask_copy_kernel<T, MaskT, Contiguous>,
-          num_blocks,
-          block_dims,
-          src_ptr,
-          dst_ptr,
-          block_size,
-          rows,
-          cols,
-          src_shape,
-          src_strides,
-          src.ndim(),
-          mask_ptr,
-          mask_shape,
-          mask_strides_p,
-          mask_ndim,
-          mask_row_str,
-          mask_col_str,
-          mask_mat_size,
-          batch_count);
+  dispatch_float_types(src.dtype(), "block_mask_copy", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+
+    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
+      dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
+        constexpr bool Contiguous = decltype(contiguous_tag)::value;
+        encoder.add_kernel_node(
+            cu::block_mask_copy_kernel<T, MaskT, Contiguous>,
+            num_blocks,
+            block_dims,
+            gpu_ptr<T>(src),
+            gpu_ptr<T>(dst),
+            block_size,
+            rows,
+            cols,
+            const_param(src.shape()),
+            const_param(src.strides()),
+            src.ndim(),
+            gpu_ptr<MaskT>(mask_nc),
+            const_param(mask.shape()),
+            const_param(mask.strides()),
+            mask_ndim,
+            mask_row_str,
+            mask_col_str,
+            mask_mat_size,
+            batch_count);
+      });
     });
   });
 }
@@ -144,21 +140,8 @@ void apply_block_mask(
   encoder.set_output_array(data);
 
   // Use block_mask_copy in-place (src == dst) with SrcContiguous=true.
-  dispatch_float_types(data.dtype(), "apply_block_mask", [&](auto type_tag) {
-    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    auto data_ptr = gpu_ptr<T>(data);
-    block_mask_copy(
-        encoder,
-        data_ptr,
-        data_ptr,
-        data,
-        mask,
-        block_size,
-        rows,
-        cols,
-        /*src_contiguous=*/true,
-        batch_count);
-  });
+  block_mask_copy(
+      encoder, data, data, mask, block_size, rows, cols, true, batch_count);
 }
 
 array copy_with_block_mask(
@@ -178,21 +161,16 @@ array copy_with_block_mask(
   encoder.set_output_array(dst);
 
   auto& src_nc = const_cast<array&>(src);
-
-  dispatch_float_types(src.dtype(), "copy_with_block_mask", [&](auto type_tag) {
-    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-    block_mask_copy(
-        encoder,
-        gpu_ptr<T>(src_nc),
-        gpu_ptr<T>(dst),
-        src,
-        mask,
-        block_size,
-        rows,
-        cols,
-        src.flags().row_contiguous,
-        batch_count);
-  });
+  block_mask_copy(
+      encoder,
+      src_nc,
+      dst,
+      mask,
+      block_size,
+      rows,
+      cols,
+      src.flags().row_contiguous,
+      batch_count);
 
   return dst;
 }

From e584a6b8a4f64e5f98024610945e850e70a2ced0 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Thu, 26 Mar 2026 00:57:03 -0700
Subject: [PATCH 9/9] Const correctness

---
 mlx/backend/cuda/CMakeLists.txt      |  2 +-
 mlx/backend/cuda/gemms/block_mask.cu | 10 ++++------
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 83343b82e1..1d08971050 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -28,6 +28,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
@@ -35,7 +36,6 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
index 128f1a7816..82160a1beb 100644
--- a/mlx/backend/cuda/gemms/block_mask.cu
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -16,7 +16,7 @@ namespace cu {
 
 template <typename T, typename MaskT, bool SrcContiguous>
 __global__ void block_mask_copy_kernel(
-    T* src,
+    const T* src,
     T* dst,
     int block_size,
     int64_t rows,
@@ -78,7 +78,7 @@ void dispatch_mask_type(Dtype mask_dtype, F&& f) {
 
 void block_mask_copy(
     cu::CommandEncoder& encoder,
-    array& src,
+    const array& src,
     array& dst,
     const array& mask,
     int block_size,
@@ -86,7 +86,6 @@ void block_mask_copy(
     int64_t cols,
     bool src_contiguous,
     int64_t batch_count) {
-  auto& mask_nc = const_cast<array&>(mask);
   int mask_ndim = mask.ndim();
   int64_t mask_row_str = mask.strides()[mask_ndim - 2];
   int64_t mask_col_str = mask.strides()[mask_ndim - 1];
@@ -113,7 +112,7 @@ void block_mask_copy(
             const_param(src.shape()),
             const_param(src.strides()),
             src.ndim(),
-            gpu_ptr<MaskT>(mask_nc),
+            gpu_ptr<MaskT>(mask),
             const_param(mask.shape()),
             const_param(mask.strides()),
             mask_ndim,
@@ -160,10 +159,9 @@ array copy_with_block_mask(
   encoder.set_input_array(mask);
   encoder.set_output_array(dst);
 
-  auto& src_nc = const_cast<array&>(src);
   block_mask_copy(
       encoder,
-      src_nc,
+      src,
       dst,
       mask,
      block_size,