Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions mlx/backend/metal/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,67 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}

// Bundle of launch parameters chosen for a regular NAX GEMM dispatch.
//
// The struct is an aggregate that is brace-initialized positionally, so the
// member order below must not change.
struct RegularNaxGemmTuning {
// Tile extent along M; the number of M tiles is computed as ceil(M / bm).
int bm;
// Tile extent along N; the number of N tiles is computed as ceil(N / bn).
int bn;
// Presumably the tile depth along K -- TODO confirm against the kernel.
int bk;
// NOTE(review): wm / wn look like the per-threadgroup simdgroup split along
// M and N -- confirm against the steel kernel template parameters.
int wm;
int wn;
// Swizzle setting forwarded to the GEMM params; presumably a log2 factor,
// judging by the name -- TODO confirm.
int swizzle_log;
};

// Decides whether the tuned regular-NAX GEMM configuration applies.
//
// The decision is keyed on the last character of the device architecture
// string:
//   - 's', 'c', 'd': always tuned, regardless of output dtype.
//   - 'g': tuned only when the architecture generation is 17 and the output
//     dtype is float16 or bfloat16.
//   - anything else, or an empty architecture string: not tuned.
inline bool use_regular_nax_tuning(
    const metal::Device& d,
    const Dtype& out_dtype) {
  const auto& arch_name = d.get_architecture();
  if (arch_name.empty()) {
    // Nothing to inspect; also avoids calling back() on an empty string.
    return false;
  }

  switch (arch_name.back()) {
    // Preserve the tuned regular-NAX path already used on larger devices.
    case 's':
    case 'c':
    case 'd':
      return true;

    // Extend the same tuning to gen-17 g devices for BF16/FP16 GEMMs only.
    case 'g': {
      if (d.get_architecture_gen() != 17) {
        return false;
      }
      const bool half_precision_out =
          out_dtype == float16 || out_dtype == bfloat16;
      return half_precision_out;
    }

    default:
      return false;
  }
}

// Picks the tile / swizzle configuration for a regular NAX GEMM of shape
// (M, N, K) on device `d` with output dtype `out_dtype`.
//
// When use_regular_nax_tuning() accepts the device/dtype pair, the tuned
// 64x128 configuration is returned, with bk = 64 for K-dominated problems
// (K >= 8192 and K > M + N) and bk = 256 otherwise, and a fixed swizzle of 2.
// Otherwise the default 128x128x512 configuration is returned, with the
// swizzle derived from the number of M tiles.
inline RegularNaxGemmTuning select_regular_nax_gemm_tuning(
    const metal::Device& d,
    const Dtype& out_dtype,
    int M,
    int N,
    int K) {
  if (use_regular_nax_tuning(d, out_dtype)) {
    const bool k_dominated = K >= 8192 && K > (M + N);

    RegularNaxGemmTuning tuned{};
    tuned.bm = 64;
    tuned.bn = 128;
    tuned.bk = k_dominated ? 64 : 256;
    tuned.wm = 2;
    tuned.wn = 4;
    tuned.swizzle_log = 2;
    return tuned;
  }

  RegularNaxGemmTuning fallback{};
  fallback.bm = 128;
  fallback.bn = 128;
  fallback.bk = 512;
  fallback.wm = 4;
  fallback.wn = 4;

  // Swizzle only once there are more than three tiles along M.
  const int tiles_m = (M + fallback.bm - 1) / fallback.bm;
  fallback.swizzle_log = tiles_m > 3 ? 1 : 0;
  return fallback;
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -201,17 +262,9 @@ void steel_matmul_regular_axpby_nax(
using namespace mlx::steel;

// Determine dispatch kernel
int bm = 128, bn = 128, bk = 512;
int wm = 4, wn = 4;

// Temp routing for larger devices
char devc = d.get_architecture().back();
if (devc == 's' || devc == 'c' || devc == 'd') {
bk = (K >= 8192 && K > (M + N)) ? 64 : 256;

bm = 64;
wm = 2;
}
auto tuning = select_regular_nax_gemm_tuning(d, out.dtype(), M, N, K);
int bm = tuning.bm, bn = tuning.bn, bk = tuning.bk;
int wm = tuning.wm, wn = tuning.wn;

// Prepare kernel name
std::ostringstream kname;
Expand Down Expand Up @@ -275,11 +328,7 @@ void steel_matmul_regular_axpby_nax(
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;

// TODO: Explore device-based tuning for swizzle
int swizzle_log = tm <= 3 ? 0 : 1;
if (devc == 's' || devc == 'c' || devc == 'd') {
swizzle_log = 2;
}
int swizzle_log = tuning.swizzle_log;

// Prepare steel matmul params
GEMMParams params{/* const int M = */ M,
Expand Down
Loading