diff --git a/benchmarks/python/block_masked_mm_bench.py b/benchmarks/python/block_masked_mm_bench.py
new file mode 100644
index 0000000000..bcfeb1c4f8
--- /dev/null
+++ b/benchmarks/python/block_masked_mm_bench.py
@@ -0,0 +1,193 @@
+# Copyright © 2025 Apple Inc.
+
+import argparse
+import time
+
+import mlx.core as mx
+import numpy as np
+
+MLX_DTYPES = {
+    "float16": mx.float16,
+    "bfloat16": mx.bfloat16,
+    "float32": mx.float32,
+}
+
+
+def parse_cases(cases):
+    parsed = []
+    for spec in cases.split(","):
+        parts = spec.split("x")
+        m, n, k, bs = int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3])
+        sparsity = float(parts[4]) if len(parts) > 4 else 0.5
+        parsed.append((m, n, k, bs, sparsity))
+    return parsed
+
+
+def make_masks(m, n, k, block_size, sparsity, rng):
+    """Create block masks with given sparsity (fraction of blocks zeroed)."""
+    tm = (m + block_size - 1) // block_size
+    tn = (n + block_size - 1) // block_size
+    tk = (k + block_size - 1) // block_size
+
+    lhs_mask = (rng.random((tm, tk)) >= sparsity).astype(np.bool_)
+    rhs_mask = (rng.random((tk, tn)) >= sparsity).astype(np.bool_)
+    out_mask = (rng.random((tm, tn)) >= sparsity).astype(np.bool_)
+    return lhs_mask, rhs_mask, out_mask
+
+
+def mlx_naive_block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask):
+    """MLX naive: expand masks and use regular matmul."""
+    M, K = a.shape[-2], a.shape[-1]
+    N = b.shape[-1]
+
+    def expand(mask, rows, cols):
+        e = mx.repeat(mx.repeat(mask, block_size, axis=-2), block_size, axis=-1)
+        return e[..., :rows, :cols]
+
+    a_masked = a * expand(lhs_mask, M, K)
+    b_masked = b * expand(rhs_mask, K, N)
+    c = a_masked @ b_masked
+    c = c * expand(out_mask, M, N)
+    return c
+
+
+def bench_mlx(fn, warmup, iters):
+    for _ in range(warmup):
+        y = fn()
+        mx.eval(y)
+    mx.synchronize()
+
+    start = time.perf_counter()
+    for _ in range(iters):
+        y = fn()
+        mx.eval(y)
+    mx.synchronize()
+    return (time.perf_counter() - start) * 1e3 / iters
+
+
+def print_table(headers, rows):
+    widths = [len(h) for h in headers]
+    for row in rows:
+        for i, cell in enumerate(row):
+            widths[i] = max(widths[i], len(cell))
+
+    def fmt_row(row):
+        return (
+            "| "
+            + " | ".join(f"{cell:<{widths[i]}}" for i, cell in enumerate(row))
+            + " |"
+        )
+
+    sep = "|-" + "-|-".join("-" * w for w in widths) + "-|"
+    print(fmt_row(headers))
+    print(sep)
+    for row in rows:
+        print(fmt_row(row))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Benchmark block_masked_mm vs naive expand+matmul"
+    )
+    parser.add_argument(
+        "--cases",
+        default=(
+            "256x256x256x32x0.5,"
+            "512x512x512x32x0.5,"
+            "1024x1024x1024x32x0.5,"
+            "1024x1024x1024x64x0.5,"
+            "2048x2048x2048x64x0.5,"
+            "256x256x256x32x0.0,"
+            "1024x1024x1024x32x0.0,"
+            "1024x1024x1024x32x0.9"
+        ),
+        help="Comma-separated MxNxKxBSxSparsity list. Sparsity=fraction of blocks zeroed.",
+    )
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "bfloat16", "float32"],
+    )
+    parser.add_argument("--warmup", type=int, default=10)
+    parser.add_argument("--iters", type=int, default=50)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--no-check", action="store_true")
+    args = parser.parse_args()
+
+    mlx_dtype = MLX_DTYPES[args.dtype]
+
+    print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}")
+
+    headers = [
+        "Case (MxNxKxBS)",
+        "Sparsity",
+        "MLX ms",
+        "Naive ms",
+        "Speedup",
+    ]
+    if not args.no_check:
+        headers.append("Max err")
+    rows = []
+
+    cases = parse_cases(args.cases)
+    for idx, (m, n, k, bs, sparsity) in enumerate(cases):
+        rng = np.random.default_rng(args.seed + idx)
+        a_np = rng.standard_normal((m, k)).astype(np.float32)
+        b_np = rng.standard_normal((k, n)).astype(np.float32)
+        lhs_mask_np, rhs_mask_np, out_mask_np = make_masks(m, n, k, bs, sparsity, rng)
+
+        a_mx = mx.array(a_np, dtype=mlx_dtype)
+        b_mx = mx.array(b_np, dtype=mlx_dtype)
+        lhs_mask_mx = mx.array(lhs_mask_np)
+        rhs_mask_mx = mx.array(rhs_mask_np)
+        out_mask_mx = mx.array(out_mask_np)
+        mx.eval(a_mx, b_mx, lhs_mask_mx, rhs_mask_mx, out_mask_mx)
+
+        # Correctness check: block_masked_mm vs naive expand+matmul
+        err_str = ""
+        if not args.no_check:
+            y_op = mx.block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            )
+            y_naive = mlx_naive_block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            )
+            mx.eval(y_op, y_naive)
+            err = float(mx.max(mx.abs(y_op - y_naive)).item())
+            err_str = f"{err:.2e}"
+
+        # Benchmark
+        t_mlx = bench_mlx(
+            lambda: mx.block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            ),
+            args.warmup,
+            args.iters,
+        )
+        t_naive = bench_mlx(
+            lambda: mlx_naive_block_masked_mm(
+                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
+            ),
+            args.warmup,
+            args.iters,
+        )
+        speedup = f"{t_naive / t_mlx:.2f}x" if t_mlx > 0 else "-"
+
+        row = [
+            f"{m}x{n}x{k}x{bs}",
+            f"{sparsity:.0%}",
+            f"{t_mlx:.3f}",
+            f"{t_naive:.3f}",
+            speedup,
+        ]
+        if not args.no_check:
+            row.append(err_str)
+        rows.append(row)
+
+    print_table(headers, rows)
+    if not args.no_check:
+        print("err: max|block_masked_mm - naive_expand_matmul|")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 1cee777bbe..1d08971050 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -28,6 +28,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
diff --git a/mlx/backend/cuda/gemms/block_mask.cu b/mlx/backend/cuda/gemms/block_mask.cu
new file mode 100644
index 0000000000..82160a1beb
--- /dev/null
+++ b/mlx/backend/cuda/gemms/block_mask.cu
@@ -0,0 +1,176 @@
+// Copyright © 2026 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/gemms/block_mask.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cg = cooperative_groups;
+
+namespace cu {
+
+// NOTE(review): template parameter list reconstructed; confirm against upstream.
+template <typename T, typename MaskT, bool SrcContiguous>
+__global__ void block_mask_copy_kernel(
+    const T* src,
+    T* dst,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    const __grid_constant__ Shape src_shape,
+    const __grid_constant__ Strides src_strides,
+    int src_ndim,
+    MaskT* mask,
+    const __grid_constant__ Shape mask_shape,
+    const __grid_constant__ Strides mask_strides,
+    int mask_ndim,
+    int64_t mask_row_stride,
+    int64_t mask_col_stride,
+    int64_t mask_mat_size,
+    int64_t batch_count) {
+  int64_t mat_size = rows * cols;
+  int64_t idx = cg::this_grid().thread_rank();
+  if (idx >= batch_count * mat_size)
+    return;
+
+  int64_t batch = idx / mat_size;
+  int64_t within = idx % mat_size;
+  int64_t mask_batch_offset = elem_to_loc(
+      batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim);
+  MaskT mask_val = mask
+      [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
+       (within % cols) / block_size * mask_col_stride];
+
+  int64_t src_offset;
+  if constexpr (SrcContiguous) {
+    src_offset = idx;
+  } else {
+    src_offset = elem_to_loc(
+        batch * mat_size + within,
+        src_shape.data(),
+        src_strides.data(),
+        src_ndim);
+  }
+
+  if constexpr (std::is_same_v<MaskT, bool>) {
+    dst[idx] = mask_val ? src[src_offset] : T(0);
+  } else {
+    dst[idx] = src[src_offset] * T(mask_val);
+  }
+}
+
+} // namespace cu
+
+namespace {
+
+template <typename F>
+void dispatch_mask_type(Dtype mask_dtype, F&& f) {
+  if (mask_dtype == bool_) {
+    f.template operator()<bool>();
+  } else {
+    f.template operator()<float>();
+  }
+}
+
+void block_mask_copy(
+    cu::CommandEncoder& encoder,
+    const array& src,
+    array& dst,
+    const array& mask,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    bool src_contiguous,
+    int64_t batch_count) {
+  int mask_ndim = mask.ndim();
+  int64_t mask_row_str = mask.strides()[mask_ndim - 2];
+  int64_t mask_col_str = mask.strides()[mask_ndim - 1];
+  int64_t mask_mat_size =
+      int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];
+
+  auto [num_blocks, block_dims] = get_launch_args(src, src.size() > INT32_MAX);
+
+  dispatch_float_types(src.dtype(), "block_mask_copy", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+
+    dispatch_mask_type(mask.dtype(), [&]<typename MaskT>() {
+      dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
+        constexpr bool Contiguous = decltype(contiguous_tag)::value;
+        encoder.add_kernel_node(
+            cu::block_mask_copy_kernel<T, MaskT, Contiguous>,
+            num_blocks,
+            block_dims,
+            gpu_ptr<T>(src),
+            gpu_ptr<T>(dst),
+            block_size,
+            rows,
+            cols,
+            const_param(src.shape()),
+            const_param(src.strides()),
+            src.ndim(),
+            gpu_ptr<MaskT>(mask),
+            const_param(mask.shape()),
+            const_param(mask.strides()),
+            mask_ndim,
+            mask_row_str,
+            mask_col_str,
+            mask_mat_size,
+            batch_count);
+      });
+    });
+  });
+}
+
+} // namespace
+
+void apply_block_mask(
+    cu::CommandEncoder& encoder,
+    array& data,
+    const array& mask,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count) {
+  encoder.set_input_array(mask);
+  encoder.set_output_array(data);
+
+  // Use block_mask_copy in-place (src == dst) with SrcContiguous=true.
+  block_mask_copy(
+      encoder, data, data, mask, block_size, rows, cols, true, batch_count);
+}
+
+array copy_with_block_mask(
+    cu::CommandEncoder& encoder,
+    const array& src,
+    const array& mask,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count) {
+  array dst(src.shape(), src.dtype(), nullptr, {});
+  dst.set_data(cu::malloc_async(dst.nbytes(), encoder));
+  encoder.add_temporary(dst);
+
+  encoder.set_input_array(src);
+  encoder.set_input_array(mask);
+  encoder.set_output_array(dst);
+
+  block_mask_copy(
+      encoder,
+      src,
+      dst,
+      mask,
+      block_size,
+      rows,
+      cols,
+      src.flags().row_contiguous,
+      batch_count);
+
+  return dst;
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/block_mask.h b/mlx/backend/cuda/gemms/block_mask.h
new file mode 100644
index 0000000000..54972336de
--- /dev/null
+++ b/mlx/backend/cuda/gemms/block_mask.h
@@ -0,0 +1,28 @@
+// Copyright © 2026 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+void apply_block_mask(
+    cu::CommandEncoder& encoder,
+    array& data,
+    const array& mask,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count);
+
+array copy_with_block_mask(
+    cu::CommandEncoder& encoder,
+    const array& src,
+    const array& mask,
+    int block_size,
+    int64_t rows,
+    int64_t cols,
+    int64_t batch_count);
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 8336590562..f06fed3909 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -2,6 +2,7 @@
 
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/block_mask.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/cuda/gemms/grouped_gemm.h"
@@ -203,6 +204,82 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
 }
 
+void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("BlockMaskedMM::eval_gpu");
+  if (!issubdtype(out.dtype(), floating)) {
+    throw std::runtime_error(
+        "[BlockMaskedMM] Does not yet support non-floating point types.");
+  }
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+
+  // Return 0s if either input is empty.
+  if (a_pre.size() == 0 || b_pre.size() == 0) {
+    array zero(0, a_pre.dtype());
+    encoder.add_temporary(zero);
+    fill_gpu(zero, out, s);
+    return;
+  }
+
+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
+  if (M == 0 || N == 0) {
+    return;
+  }
+  if (K == 0) {
+    array zero(0, a_pre.dtype());
+    encoder.add_temporary(zero);
+    fill_gpu(zero, out, s);
+    return;
+  }
+
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
+
+  bool has_op_mask = inputs.size() > 3;
+  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
+
+  int64_t batch_count = out.size() / (int64_t(M) * N);
+
+  bool a_transposed;
+  int64_t lda;
+  array a = a_pre;
+  bool b_transposed;
+  int64_t ldb;
+  array b = b_pre;
+
+  if (has_op_mask) {
+    // Fused copy + mask in a single pass per matrix.
+    auto& lhs_mask = inputs[inputs.size() - 2];
+    auto& rhs_mask = inputs[inputs.size() - 1];
+    a = copy_with_block_mask(
+        encoder, a_pre, lhs_mask, block_size_, M, K, batch_count);
+    b = copy_with_block_mask(
+        encoder, b_pre, rhs_mask, block_size_, K, N, batch_count);
+    a_transposed = false;
+    lda = K;
+    b_transposed = false;
+    ldb = N;
+  } else {
+    std::tie(a_transposed, lda, a) = check_transpose(encoder, s, a_pre);
+    std::tie(b_transposed, ldb, b) = check_transpose(encoder, s, b_pre);
+  }
+
+  // Run GEMM.
+  gemm_and_bias(
+      encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
+
+  // Apply output mask.
+  if (has_out_mask) {
+    auto& out_mask = inputs[2];
+    apply_block_mask(encoder, out, out_mask, block_size_, M, N, batch_count);
+  }
+}
+
 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("AddMM::eval_gpu");
   auto& s = stream();
diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp
index 98dca5708f..79b455cf57 100644
--- a/mlx/backend/cuda/primitives.cpp
+++ b/mlx/backend/cuda/primitives.cpp
@@ -24,7 +24,6 @@ namespace mlx::core {
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
-NO_GPU(BlockMaskedMM)
 NO_GPU(GatherQMM)
 NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index fe042da898..2fb66b8d84 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,7 +1,5 @@
 cuda_skip = {
     "TestLayers.test_quantized_embedding",
-    # Block masked matmul NYI
-    "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",