From ebc1ab410c8de03ea7e16ef8731d384ae78e6f7b Mon Sep 17 00:00:00 2001 From: lentil32 Date: Sun, 22 Mar 2026 20:50:37 +0900 Subject: [PATCH] Extend regular NAX tuning to gen-17 g devices --- mlx/backend/metal/matmul.cpp | 81 +++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 84b6ee06da..e52808e358 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -79,6 +79,67 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); } +struct RegularNaxGemmTuning { + int bm; + int bn; + int bk; + int wm; + int wn; + int swizzle_log; +}; + +inline bool use_regular_nax_tuning( + const metal::Device& d, + const Dtype& out_dtype) { + const auto& arch = d.get_architecture(); + if (arch.empty()) { + return false; + } + + const auto devc = arch.back(); + + // Preserve the tuned regular-NAX path already used on larger devices. + if (devc == 's' || devc == 'c' || devc == 'd') { + return true; + } + + // Extend the same tuning to gen-17 g devices for BF16/FP16 regular NAX GEMMs. + return d.get_architecture_gen() == 17 && devc == 'g' && + (out_dtype == float16 || out_dtype == bfloat16); +} + +inline RegularNaxGemmTuning select_regular_nax_gemm_tuning( + const metal::Device& d, + const Dtype& out_dtype, + int M, + int N, + int K) { + RegularNaxGemmTuning tuning{ + /* bm = */ 128, + /* bn = */ 128, + /* bk = */ 512, + /* wm = */ 4, + /* wn = */ 4, + /* swizzle_log = */ 0, + }; + + if (use_regular_nax_tuning(d, out_dtype)) { + tuning = { + /* bm = */ 64, + /* bn = */ 128, + /* bk = */ (K >= 8192 && K > (M + N)) ? 64 : 256, + /* wm = */ 2, + /* wn = */ 4, + /* swizzle_log = */ 2, + }; + } else { + int tm = (M + tuning.bm - 1) / tuning.bm; + tuning.swizzle_log = tm <= 3 ? 0 : 1; + } + + return tuning; +} + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -201,17 +262,9 @@ void steel_matmul_regular_axpby_nax( using namespace mlx::steel; // Determine dispatch kernel - int bm = 128, bn = 128, bk = 512; - int wm = 4, wn = 4; - - // Temp routing for larger devices - char devc = d.get_architecture().back(); - if (devc == 's' || devc == 'c' || devc == 'd') { - bk = (K >= 8192 && K > (M + N)) ? 64 : 256; - - bm = 64; - wm = 2; - } + auto tuning = select_regular_nax_gemm_tuning(d, out.dtype(), M, N, K); + int bm = tuning.bm, bn = tuning.bn, bk = tuning.bk; + int wm = tuning.wm, wn = tuning.wn; // Prepare kernel name std::ostringstream kname; @@ -275,11 +328,7 @@ void steel_matmul_regular_axpby_nax( int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; - // TODO: Explore device-based tuning for swizzle - int swizzle_log = tm <= 3 ? 0 : 1; - if (devc == 's' || devc == 'c' || devc == 'd') { - swizzle_log = 2; - } + int swizzle_log = tuning.swizzle_log; // Prepare steel matmul params GEMMParams params{/* const int M = */ M,