-
Notifications
You must be signed in to change notification settings - Fork 1.6k
[CUDA] Implement BlockMaskedMM #3299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+475
−3
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
261e0b9
feat: implement BlockMaskedMM
Lyxot 2e0dcf9
test: enable test script and add bench script
Lyxot 75ffb6b
perf: fuse copy and block mask into single-pass kernel
Lyxot b258254
fix: use int64 for block mask index arithmetic
Lyxot d41c5aa
apply pr comments
Lyxot 414d5ef
reuse block_mask_copy with src==dst
Lyxot e5faca3
apply pr comments
Lyxot 4ea2b23
apply pr comments
Lyxot e584a6b
Constness correct
zcbenz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,193 @@ | ||
| # Copyright © 2025 Apple Inc. | ||
|
|
||
| import argparse | ||
| import time | ||
|
|
||
| import mlx.core as mx | ||
| import numpy as np | ||
|
|
||
| MLX_DTYPES = { | ||
| "float16": mx.float16, | ||
| "bfloat16": mx.bfloat16, | ||
| "float32": mx.float32, | ||
| } | ||
|
|
||
|
|
||
| def parse_cases(cases): | ||
| parsed = [] | ||
| for spec in cases.split(","): | ||
| parts = spec.split("x") | ||
| m, n, k, bs = int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3]) | ||
| sparsity = float(parts[4]) if len(parts) > 4 else 0.5 | ||
| parsed.append((m, n, k, bs, sparsity)) | ||
| return parsed | ||
|
|
||
|
|
||
| def make_masks(m, n, k, block_size, sparsity, rng): | ||
| """Create block masks with given sparsity (fraction of blocks zeroed).""" | ||
| tm = (m + block_size - 1) // block_size | ||
| tn = (n + block_size - 1) // block_size | ||
| tk = (k + block_size - 1) // block_size | ||
|
|
||
| lhs_mask = (rng.random((tm, tk)) >= sparsity).astype(np.bool_) | ||
| rhs_mask = (rng.random((tk, tn)) >= sparsity).astype(np.bool_) | ||
| out_mask = (rng.random((tm, tn)) >= sparsity).astype(np.bool_) | ||
| return lhs_mask, rhs_mask, out_mask | ||
|
|
||
|
|
||
| def mlx_naive_block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask): | ||
| """MLX naive: expand masks and use regular matmul.""" | ||
| M, K = a.shape[-2], a.shape[-1] | ||
| N = b.shape[-1] | ||
|
|
||
| def expand(mask, rows, cols): | ||
| e = mx.repeat(mx.repeat(mask, block_size, axis=-2), block_size, axis=-1) | ||
| return e[..., :rows, :cols] | ||
|
|
||
| a_masked = a * expand(lhs_mask, M, K) | ||
| b_masked = b * expand(rhs_mask, K, N) | ||
| c = a_masked @ b_masked | ||
| c = c * expand(out_mask, M, N) | ||
| return c | ||
|
|
||
|
|
||
| def bench_mlx(fn, warmup, iters): | ||
| for _ in range(warmup): | ||
| y = fn() | ||
| mx.eval(y) | ||
| mx.synchronize() | ||
|
|
||
| start = time.perf_counter() | ||
| for _ in range(iters): | ||
| y = fn() | ||
| mx.eval(y) | ||
| mx.synchronize() | ||
| return (time.perf_counter() - start) * 1e3 / iters | ||
|
|
||
|
|
||
| def print_table(headers, rows): | ||
| widths = [len(h) for h in headers] | ||
| for row in rows: | ||
| for i, cell in enumerate(row): | ||
| widths[i] = max(widths[i], len(cell)) | ||
|
|
||
| def fmt_row(row): | ||
| return ( | ||
| "| " | ||
| + " | ".join(f"{cell:<{widths[i]}}" for i, cell in enumerate(row)) | ||
| + " |" | ||
| ) | ||
|
|
||
| sep = "|-" + "-|-".join("-" * w for w in widths) + "-|" | ||
| print(fmt_row(headers)) | ||
| print(sep) | ||
| for row in rows: | ||
| print(fmt_row(row)) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser( | ||
| description="Benchmark block_masked_mm vs naive expand+matmul" | ||
| ) | ||
| parser.add_argument( | ||
| "--cases", | ||
| default=( | ||
| "256x256x256x32x0.5," | ||
| "512x512x512x32x0.5," | ||
| "1024x1024x1024x32x0.5," | ||
| "1024x1024x1024x64x0.5," | ||
| "2048x2048x2048x64x0.5," | ||
| "256x256x256x32x0.0," | ||
| "1024x1024x1024x32x0.0," | ||
| "1024x1024x1024x32x0.9" | ||
| ), | ||
| help="Comma-separated MxNxKxBSxSparsity list. Sparsity=fraction of blocks zeroed.", | ||
| ) | ||
| parser.add_argument( | ||
| "--dtype", | ||
| default="float32", | ||
| choices=["float16", "bfloat16", "float32"], | ||
| ) | ||
| parser.add_argument("--warmup", type=int, default=10) | ||
| parser.add_argument("--iters", type=int, default=50) | ||
| parser.add_argument("--seed", type=int, default=42) | ||
| parser.add_argument("--no-check", action="store_true") | ||
| args = parser.parse_args() | ||
|
|
||
| mlx_dtype = MLX_DTYPES[args.dtype] | ||
|
|
||
| print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}") | ||
|
|
||
| headers = [ | ||
| "Case (MxNxKxBS)", | ||
| "Sparsity", | ||
| "MLX ms", | ||
| "Naive ms", | ||
| "Speedup", | ||
| ] | ||
| if not args.no_check: | ||
| headers.append("Max err") | ||
| rows = [] | ||
|
|
||
| cases = parse_cases(args.cases) | ||
| for idx, (m, n, k, bs, sparsity) in enumerate(cases): | ||
| rng = np.random.default_rng(args.seed + idx) | ||
| a_np = rng.standard_normal((m, k)).astype(np.float32) | ||
| b_np = rng.standard_normal((k, n)).astype(np.float32) | ||
| lhs_mask_np, rhs_mask_np, out_mask_np = make_masks(m, n, k, bs, sparsity, rng) | ||
|
|
||
| a_mx = mx.array(a_np, dtype=mlx_dtype) | ||
| b_mx = mx.array(b_np, dtype=mlx_dtype) | ||
| lhs_mask_mx = mx.array(lhs_mask_np) | ||
| rhs_mask_mx = mx.array(rhs_mask_np) | ||
| out_mask_mx = mx.array(out_mask_np) | ||
| mx.eval(a_mx, b_mx, lhs_mask_mx, rhs_mask_mx, out_mask_mx) | ||
|
|
||
| # Correctness check: block_masked_mm vs naive expand+matmul | ||
| err_str = "" | ||
| if not args.no_check: | ||
| y_op = mx.block_masked_mm( | ||
| a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx | ||
| ) | ||
| y_naive = mlx_naive_block_masked_mm( | ||
| a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx | ||
| ) | ||
| mx.eval(y_op, y_naive) | ||
| err = float(mx.max(mx.abs(y_op - y_naive)).item()) | ||
| err_str = f"{err:.2e}" | ||
|
|
||
| # Benchmark | ||
| t_mlx = bench_mlx( | ||
| lambda: mx.block_masked_mm( | ||
| a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx | ||
| ), | ||
| args.warmup, | ||
| args.iters, | ||
| ) | ||
| t_naive = bench_mlx( | ||
| lambda: mlx_naive_block_masked_mm( | ||
| a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx | ||
| ), | ||
| args.warmup, | ||
| args.iters, | ||
| ) | ||
| speedup = f"{t_naive / t_mlx:.2f}x" if t_mlx > 0 else "-" | ||
|
|
||
| row = [ | ||
| f"{m}x{n}x{k}x{bs}", | ||
| f"{sparsity:.0%}", | ||
| f"{t_mlx:.3f}", | ||
| f"{t_naive:.3f}", | ||
| speedup, | ||
| ] | ||
| if not args.no_check: | ||
| row.append(err_str) | ||
| rows.append(row) | ||
|
|
||
| print_table(headers, rows) | ||
| if not args.no_check: | ||
| print("err: max|block_masked_mm - naive_expand_matmul|") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| // Copyright © 2026 Apple Inc. | ||
|
|
||
| #include "mlx/backend/cuda/device.h" | ||
| #include "mlx/backend/cuda/device/utils.cuh" | ||
| #include "mlx/backend/cuda/gemms/block_mask.h" | ||
| #include "mlx/backend/cuda/kernel_utils.cuh" | ||
| #include "mlx/dtype_utils.h" | ||
|
|
||
| #include <cooperative_groups.h> | ||
|
|
||
| namespace mlx::core { | ||
|
|
||
| namespace cg = cooperative_groups; | ||
|
|
||
| namespace cu { | ||
|
|
||
| template <typename T, typename MaskT, bool SrcContiguous> | ||
| __global__ void block_mask_copy_kernel( | ||
| const T* src, | ||
| T* dst, | ||
| int block_size, | ||
| int64_t rows, | ||
| int64_t cols, | ||
| const __grid_constant__ Shape src_shape, | ||
| const __grid_constant__ Strides src_strides, | ||
| int src_ndim, | ||
| MaskT* mask, | ||
| const __grid_constant__ Shape mask_shape, | ||
| const __grid_constant__ Strides mask_strides, | ||
| int mask_ndim, | ||
| int64_t mask_row_stride, | ||
| int64_t mask_col_stride, | ||
| int64_t mask_mat_size, | ||
| int64_t batch_count) { | ||
| int64_t mat_size = rows * cols; | ||
| int64_t idx = cg::this_grid().thread_rank(); | ||
| if (idx >= batch_count * mat_size) | ||
| return; | ||
|
|
||
| int64_t batch = idx / mat_size; | ||
| int64_t within = idx % mat_size; | ||
| int64_t mask_batch_offset = elem_to_loc( | ||
| batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim); | ||
| MaskT mask_val = mask | ||
| [mask_batch_offset + (within / cols) / block_size * mask_row_stride + | ||
| (within % cols) / block_size * mask_col_stride]; | ||
|
|
||
| int64_t src_offset; | ||
| if constexpr (SrcContiguous) { | ||
| src_offset = idx; | ||
| } else { | ||
| src_offset = elem_to_loc( | ||
| batch * mat_size + within, | ||
| src_shape.data(), | ||
| src_strides.data(), | ||
| src_ndim); | ||
| } | ||
|
|
||
| if constexpr (std::is_same_v<MaskT, bool>) { | ||
| dst[idx] = mask_val ? src[src_offset] : T(0); | ||
| } else { | ||
| dst[idx] = src[src_offset] * T(mask_val); | ||
| } | ||
| } | ||
|
|
||
| } // namespace cu | ||
|
|
||
| namespace { | ||
|
|
||
| template <typename T, typename F> | ||
| void dispatch_mask_type(Dtype mask_dtype, F&& f) { | ||
| if (mask_dtype == bool_) { | ||
| f.template operator()<bool>(); | ||
| } else { | ||
| f.template operator()<T>(); | ||
| } | ||
| } | ||
|
|
||
| void block_mask_copy( | ||
| cu::CommandEncoder& encoder, | ||
| const array& src, | ||
| array& dst, | ||
| const array& mask, | ||
| int block_size, | ||
| int64_t rows, | ||
| int64_t cols, | ||
| bool src_contiguous, | ||
| int64_t batch_count) { | ||
| int mask_ndim = mask.ndim(); | ||
| int64_t mask_row_str = mask.strides()[mask_ndim - 2]; | ||
| int64_t mask_col_str = mask.strides()[mask_ndim - 1]; | ||
| int64_t mask_mat_size = | ||
| int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1]; | ||
|
|
||
| auto [num_blocks, block_dims] = get_launch_args(src, src.size() > INT32_MAX); | ||
|
|
||
| dispatch_float_types(src.dtype(), "block_mask_copy", [&](auto type_tag) { | ||
| using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||
|
|
||
| dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() { | ||
| dispatch_bool(src_contiguous, [&](auto contiguous_tag) { | ||
| constexpr bool Contiguous = decltype(contiguous_tag)::value; | ||
| encoder.add_kernel_node( | ||
| cu::block_mask_copy_kernel<T, MaskT, Contiguous>, | ||
| num_blocks, | ||
| block_dims, | ||
| gpu_ptr<T>(src), | ||
| gpu_ptr<T>(dst), | ||
| block_size, | ||
| rows, | ||
| cols, | ||
| const_param(src.shape()), | ||
| const_param(src.strides()), | ||
| src.ndim(), | ||
| gpu_ptr<MaskT>(mask), | ||
| const_param(mask.shape()), | ||
| const_param(mask.strides()), | ||
| mask_ndim, | ||
| mask_row_str, | ||
| mask_col_str, | ||
| mask_mat_size, | ||
| batch_count); | ||
| }); | ||
| }); | ||
| }); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| void apply_block_mask( | ||
| cu::CommandEncoder& encoder, | ||
| array& data, | ||
| const array& mask, | ||
| int block_size, | ||
| int64_t rows, | ||
| int64_t cols, | ||
| int64_t batch_count) { | ||
| encoder.set_input_array(mask); | ||
| encoder.set_output_array(data); | ||
|
|
||
| // Use block_mask_copy in-place (src == dst) with SrcContiguous=true. | ||
| block_mask_copy( | ||
| encoder, data, data, mask, block_size, rows, cols, true, batch_count); | ||
| } | ||
|
|
||
| array copy_with_block_mask( | ||
| cu::CommandEncoder& encoder, | ||
| const array& src, | ||
| const array& mask, | ||
| int block_size, | ||
| int64_t rows, | ||
| int64_t cols, | ||
| int64_t batch_count) { | ||
| array dst(src.shape(), src.dtype(), nullptr, {}); | ||
| dst.set_data(cu::malloc_async(dst.nbytes(), encoder)); | ||
| encoder.add_temporary(dst); | ||
|
|
||
| encoder.set_input_array(src); | ||
| encoder.set_input_array(mask); | ||
| encoder.set_output_array(dst); | ||
|
|
||
| block_mask_copy( | ||
| encoder, | ||
| src, | ||
| dst, | ||
| mask, | ||
| block_size, | ||
| rows, | ||
| cols, | ||
| src.flags().row_contiguous, | ||
| batch_count); | ||
|
|
||
| return dst; | ||
| } | ||
|
|
||
| } // namespace mlx::core |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.