Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions benchmarks/python/block_masked_mm_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright © 2025 Apple Inc.

import argparse
import time

import mlx.core as mx
import numpy as np

# Map the CLI --dtype choice to the corresponding MLX dtype.
MLX_DTYPES = {
    "float16": mx.float16,
    "bfloat16": mx.bfloat16,
    "float32": mx.float32,
}


def parse_cases(cases):
    """Parse a comma-separated case list into (M, N, K, BS, sparsity) tuples.

    Each case is "MxNxKxBS" or "MxNxKxBSxSparsity"; sparsity defaults to
    0.5 when omitted. Surrounding whitespace in a spec is tolerated.

    Raises:
        ValueError: on a spec with fewer than four fields (instead of the
            opaque IndexError the naive indexing would raise).
    """
    parsed = []
    for spec in cases.split(","):
        parts = spec.strip().split("x")
        if len(parts) < 4:
            raise ValueError(
                f"Invalid case spec {spec!r}: expected MxNxKxBS[xSparsity]"
            )
        m, n, k, bs = int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3])
        # Sparsity = fraction of mask blocks zeroed; optional fifth field.
        sparsity = float(parts[4]) if len(parts) > 4 else 0.5
        parsed.append((m, n, k, bs, sparsity))
    return parsed


def make_masks(m, n, k, block_size, sparsity, rng):
    """Create block masks with given sparsity (fraction of blocks zeroed)."""

    def blocks(dim):
        # Number of block_size tiles needed to cover `dim` (ceiling division).
        return -(-dim // block_size)

    tm, tn, tk = blocks(m), blocks(n), blocks(k)

    def draw(shape):
        # A block survives when its uniform draw lands at or above `sparsity`.
        return (rng.random(shape) >= sparsity).astype(np.bool_)

    # Draw order matters for RNG reproducibility: lhs, then rhs, then out.
    return draw((tm, tk)), draw((tk, tn)), draw((tm, tn))


def mlx_naive_block_masked_mm(a, b, block_size, out_mask, lhs_mask, rhs_mask):
    """Reference implementation: blow each block mask up to element
    resolution, zero the masked operands, run a dense matmul, and mask
    the output."""
    rows_a, inner = a.shape[-2], a.shape[-1]
    cols_b = b.shape[-1]

    def dense(mask, r, c):
        # Tile every mask entry into a block_size x block_size patch, then
        # crop to the (possibly non-multiple) matrix dimensions.
        grown = mx.repeat(mask, block_size, axis=-2)
        grown = mx.repeat(grown, block_size, axis=-1)
        return grown[..., :r, :c]

    lhs = a * dense(lhs_mask, rows_a, inner)
    rhs = b * dense(rhs_mask, inner, cols_b)
    return (lhs @ rhs) * dense(out_mask, rows_a, cols_b)


def bench_mlx(fn, warmup, iters):
    # Time `fn` (a zero-arg callable returning an MLX array) and return the
    # mean wall-clock latency per call in milliseconds.
    # NOTE(review): indentation was reconstructed from a mangled paste —
    # synchronize() is placed once after each loop (eval per call forces the
    # computation; a single sync drains the device before/after timing).
    # Confirm against the original file.
    for _ in range(warmup):
        y = fn()
        mx.eval(y)
    mx.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        y = fn()
        mx.eval(y)
    mx.synchronize()
    return (time.perf_counter() - start) * 1e3 / iters


def print_table(headers, rows):
    """Print `rows` under `headers` as a markdown-style pipe table."""
    # Each column is as wide as its widest cell, header included.
    widths = [len(h) for h in headers]
    for row in rows:
        widths = [max(w, len(cell)) for w, cell in zip(widths, row)]

    def render(cells):
        padded = (f"{cell:<{w}}" for cell, w in zip(cells, widths))
        return "| " + " | ".join(padded) + " |"

    print(render(headers))
    print("|-" + "-|-".join("-" * w for w in widths) + "-|")
    for row in rows:
        print(render(row))


def main():
    """Parse CLI args, then for each case optionally verify that
    mx.block_masked_mm matches the naive expand+matmul reference and time
    both implementations, printing a markdown table of results."""
    parser = argparse.ArgumentParser(
        description="Benchmark block_masked_mm vs naive expand+matmul"
    )
    parser.add_argument(
        "--cases",
        default=(
            "256x256x256x32x0.5,"
            "512x512x512x32x0.5,"
            "1024x1024x1024x32x0.5,"
            "1024x1024x1024x64x0.5,"
            "2048x2048x2048x64x0.5,"
            "256x256x256x32x0.0,"
            "1024x1024x1024x32x0.0,"
            "1024x1024x1024x32x0.9"
        ),
        help="Comma-separated MxNxKxBSxSparsity list. Sparsity=fraction of blocks zeroed.",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "bfloat16", "float32"],
    )
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-check", action="store_true")
    args = parser.parse_args()

    mlx_dtype = MLX_DTYPES[args.dtype]

    print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}")

    headers = [
        "Case (MxNxKxBS)",
        "Sparsity",
        "MLX ms",
        "Naive ms",
        "Speedup",
    ]
    # The error column only exists when the correctness check runs.
    if not args.no_check:
        headers.append("Max err")
    rows = []

    cases = parse_cases(args.cases)
    for idx, (m, n, k, bs, sparsity) in enumerate(cases):
        # Per-case RNG (seed + case index) keeps each case reproducible
        # regardless of which other cases run.
        rng = np.random.default_rng(args.seed + idx)
        a_np = rng.standard_normal((m, k)).astype(np.float32)
        b_np = rng.standard_normal((k, n)).astype(np.float32)
        lhs_mask_np, rhs_mask_np, out_mask_np = make_masks(m, n, k, bs, sparsity, rng)

        # Upload operands once and eval so transfer/conversion cost stays
        # outside the timed region.
        a_mx = mx.array(a_np, dtype=mlx_dtype)
        b_mx = mx.array(b_np, dtype=mlx_dtype)
        lhs_mask_mx = mx.array(lhs_mask_np)
        rhs_mask_mx = mx.array(rhs_mask_np)
        out_mask_mx = mx.array(out_mask_np)
        mx.eval(a_mx, b_mx, lhs_mask_mx, rhs_mask_mx, out_mask_mx)

        # Correctness check: block_masked_mm vs naive expand+matmul
        err_str = ""
        if not args.no_check:
            y_op = mx.block_masked_mm(
                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
            )
            y_naive = mlx_naive_block_masked_mm(
                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
            )
            mx.eval(y_op, y_naive)
            err = float(mx.max(mx.abs(y_op - y_naive)).item())
            err_str = f"{err:.2e}"

        # Benchmark
        t_mlx = bench_mlx(
            lambda: mx.block_masked_mm(
                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
            ),
            args.warmup,
            args.iters,
        )
        t_naive = bench_mlx(
            lambda: mlx_naive_block_masked_mm(
                a_mx, b_mx, bs, out_mask_mx, lhs_mask_mx, rhs_mask_mx
            ),
            args.warmup,
            args.iters,
        )
        # Guard against a zero-time denominator.
        speedup = f"{t_naive / t_mlx:.2f}x" if t_mlx > 0 else "-"

        row = [
            f"{m}x{n}x{k}x{bs}",
            f"{sparsity:.0%}",
            f"{t_mlx:.3f}",
            f"{t_naive:.3f}",
            speedup,
        ]
        if not args.no_check:
            row.append(err_str)
        rows.append(row)

    print_table(headers, rows)
    if not args.no_check:
        print("err: max|block_masked_mm - naive_expand_matmul|")


# Script entry point.
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fft.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
Expand Down
176 changes: 176 additions & 0 deletions mlx/backend/cuda/gemms/block_mask.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright © 2026 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/gemms/block_mask.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cg = cooperative_groups;

namespace cu {

// Copy `src` into row-contiguous `dst` while applying a block mask: each
// element's mask entry is the one covering its (row, col) block. A bool
// mask keeps/zeroes elements; a numeric mask scales them. One thread per
// output element.
template <typename T, typename MaskT, bool SrcContiguous>
__global__ void block_mask_copy_kernel(
    const T* src,
    T* dst,
    int block_size,
    int64_t rows,
    int64_t cols,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int src_ndim,
    MaskT* mask,
    const __grid_constant__ Shape mask_shape,
    const __grid_constant__ Strides mask_strides,
    int mask_ndim,
    int64_t mask_row_stride,
    int64_t mask_col_stride,
    int64_t mask_mat_size,
    int64_t batch_count) {
  int64_t mat_size = rows * cols;
  // Global linear index across all batches; excess threads exit early.
  int64_t idx = cg::this_grid().thread_rank();
  if (idx >= batch_count * mat_size)
    return;

  int64_t batch = idx / mat_size;
  int64_t within = idx % mat_size;
  // Resolve the batch's offset into the (possibly strided/broadcast) mask,
  // then index the block containing this element via block-row/col strides.
  int64_t mask_batch_offset = elem_to_loc(
      batch * mask_mat_size, mask_shape.data(), mask_strides.data(), mask_ndim);
  MaskT mask_val = mask
      [mask_batch_offset + (within / cols) / block_size * mask_row_stride +
       (within % cols) / block_size * mask_col_stride];

  int64_t src_offset;
  if constexpr (SrcContiguous) {
    // Row-contiguous source: the linear index addresses it directly.
    src_offset = idx;
  } else {
    // Strided source: translate the logical index through shape/strides.
    src_offset = elem_to_loc(
        batch * mat_size + within,
        src_shape.data(),
        src_strides.data(),
        src_ndim);
  }

  if constexpr (std::is_same_v<MaskT, bool>) {
    // Boolean mask: keep the element or write an explicit zero.
    dst[idx] = mask_val ? src[src_offset] : T(0);
  } else {
    // Numeric mask: scale the element by the mask value.
    dst[idx] = src[src_offset] * T(mask_val);
  }
}

} // namespace cu

namespace {

// Invoke `f` templated on the mask element type: a bool mask dispatches as
// bool; any other mask dtype is treated as the data type T.
template <typename T, typename F>
void dispatch_mask_type(Dtype mask_dtype, F&& f) {
  if (mask_dtype == bool_) {
    f.template operator()<bool>();
  } else {
    f.template operator()<T>();
  }
}

// Encode a block_mask_copy_kernel launch for src -> dst under `mask`,
// dispatching on data dtype, mask dtype (bool vs numeric), and whether
// `src` can use the contiguous fast path.
void block_mask_copy(
    cu::CommandEncoder& encoder,
    const array& src,
    array& dst,
    const array& mask,
    int block_size,
    int64_t rows,
    int64_t cols,
    bool src_contiguous,
    int64_t batch_count) {
  // Strides and extent of the mask's trailing (block-rows x block-cols)
  // matrix; leading dims are handled via elem_to_loc in the kernel.
  int mask_ndim = mask.ndim();
  int64_t mask_row_str = mask.strides()[mask_ndim - 2];
  int64_t mask_col_str = mask.strides()[mask_ndim - 1];
  int64_t mask_mat_size =
      int64_t(mask.shape()[mask_ndim - 2]) * mask.shape()[mask_ndim - 1];

  // One thread per element of src; the second argument presumably selects
  // 64-bit indexing when the size exceeds int32 range — see get_launch_args.
  auto [num_blocks, block_dims] = get_launch_args(src, src.size() > INT32_MAX);

  dispatch_float_types(src.dtype(), "block_mask_copy", [&](auto type_tag) {
    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;

    dispatch_mask_type<T>(mask.dtype(), [&]<typename MaskT>() {
      dispatch_bool(src_contiguous, [&](auto contiguous_tag) {
        constexpr bool Contiguous = decltype(contiguous_tag)::value;
        encoder.add_kernel_node(
            cu::block_mask_copy_kernel<T, MaskT, Contiguous>,
            num_blocks,
            block_dims,
            gpu_ptr<T>(src),
            gpu_ptr<T>(dst),
            block_size,
            rows,
            cols,
            const_param(src.shape()),
            const_param(src.strides()),
            src.ndim(),
            gpu_ptr<MaskT>(mask),
            const_param(mask.shape()),
            const_param(mask.strides()),
            mask_ndim,
            mask_row_str,
            mask_col_str,
            mask_mat_size,
            batch_count);
      });
    });
  });
}

} // namespace

// Apply `mask` to `data` in place. Safe because the kernel's contiguous
// path has each thread read src[idx] and write dst[idx] at the same
// offset, so src == dst cannot race across threads.
void apply_block_mask(
    cu::CommandEncoder& encoder,
    array& data,
    const array& mask,
    int block_size,
    int64_t rows,
    int64_t cols,
    int64_t batch_count) {
  encoder.set_input_array(mask);
  encoder.set_output_array(data);

  // Use block_mask_copy in-place (src == dst) with SrcContiguous=true.
  block_mask_copy(
      encoder, data, data, mask, block_size, rows, cols, true, batch_count);
}

// Materialize a masked copy of `src` into a freshly allocated array,
// registered with the encoder as a temporary. `src` may be strided; the
// kernel falls back to elem_to_loc addressing unless it is row-contiguous.
array copy_with_block_mask(
    cu::CommandEncoder& encoder,
    const array& src,
    const array& mask,
    int block_size,
    int64_t rows,
    int64_t cols,
    int64_t batch_count) {
  array dst(src.shape(), src.dtype(), nullptr, {});
  dst.set_data(cu::malloc_async(dst.nbytes(), encoder));
  encoder.add_temporary(dst);

  encoder.set_input_array(src);
  encoder.set_input_array(mask);
  encoder.set_output_array(dst);

  block_mask_copy(
      encoder,
      src,
      dst,
      mask,
      block_size,
      rows,
      cols,
      src.flags().row_contiguous,
      batch_count);

  return dst;
}

} // namespace mlx::core
Loading
Loading