From 7114540da07f92a2d63540b989c142adf503615a Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Mon, 16 Mar 2026 11:24:44 -0700 Subject: [PATCH 1/3] Init NAX refactor --- mlx/backend/metal/kernels/fp_quantized_nax.h | 64 +- mlx/backend/metal/kernels/quantized_nax.h | 64 +- .../steel/attn/kernels/steel_attention_nax.h | 163 ++-- mlx/backend/metal/kernels/steel/attn/nax.h | 918 +++++++----------- .../metal/kernels/steel/gemm/gemm_nax.h | 84 +- .../steel/gemm/kernels/steel_gemm_fused_nax.h | 48 +- .../gemm/kernels/steel_gemm_gather_nax.h | 15 +- .../gemm/kernels/steel_gemm_splitk_nax.h | 13 +- mlx/backend/metal/kernels/steel/gemm/nax.h | 894 +++++++---------- 9 files changed, 881 insertions(+), 1382 deletions(-) diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.h b/mlx/backend/metal/kernels/fp_quantized_nax.h index 80e1c4c2f7..381bc6c7d3 100644 --- a/mlx/backend/metal/kernels/fp_quantized_nax.h +++ b/mlx/backend/metal/kernels/fp_quantized_nax.h @@ -246,16 +246,13 @@ METAL_FUNC void fp_qmm_t_impl( // Make the weight loader loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; const short tm = SM * (simd_gid / WN); const short tn = SN * (simd_gid % WN); @@ -273,12 +270,7 @@ METAL_FUNC void fp_qmm_t_impl( using AccumType = float; - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - + NAXTile Dtile; Dtile.clear(); x += tm * K; @@ -297,8 +289,8 @@ METAL_FUNC void fp_qmm_t_impl( STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; @@ -400,16 +392,13 @@ METAL_FUNC void fp_qmm_n_impl( // const short num_outs = min(BN, N - y_col); loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; const short tm = SM * (simd_gid / WN); const short tn = SN * (simd_gid % WN); @@ -421,12 +410,7 @@ METAL_FUNC void fp_qmm_n_impl( using AccumType = float; - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - + NAXTile Dtile; Dtile.clear(); x += tm * K; @@ -438,8 +422,8 @@ METAL_FUNC void fp_qmm_n_impl( STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; @@ -869,16 +853,13 @@ template < wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; scales += transpose ? y_col_long * K_g : y_col / group_size; - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; const short tm = SM * (simd_group_id / WN); const short tn = SN * (simd_group_id % WN); @@ -896,10 +877,6 @@ template < using AccumType = float; - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - // Do as many matmuls as necessary uint32_t index; short offset; @@ -921,8 +898,7 @@ template < threadgroup_barrier(mem_flags::mem_none); // Prepare threadgroup mma operation - NAXTile Dtile; - + NAXTile Dtile; Dtile.clear(); const device T* xn = x + tm * K; @@ -951,8 +927,8 @@ template < STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; @@ -991,8 +967,8 @@ template < STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h index c26ff646bb..8814fafa33 100644 --- a/mlx/backend/metal/kernels/quantized_nax.h +++ b/mlx/backend/metal/kernels/quantized_nax.h @@ -986,16 +986,13 @@ METAL_FUNC void qmm_t_nax_tgp_impl( // Make the weight loader loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; const short tm = SM * (simd_gid / WN); const short tn = SN * (simd_gid % WN); @@ -1013,12 +1010,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl( using AccumType = float; - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - + NAXTile Dtile; Dtile.clear(); x += tm * K; @@ -1037,8 +1029,8 @@ METAL_FUNC void qmm_t_nax_tgp_impl( STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; @@ -1141,16 +1133,13 @@ METAL_FUNC void qmm_n_nax_tgp_impl( // const short num_outs = min(BN, N - y_col); loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; const short tm = SM * (simd_gid / WN); const short tn = SN * (simd_gid % WN); @@ -1162,12 +1151,7 @@ METAL_FUNC void qmm_n_nax_tgp_impl( using AccumType = float; - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - + NAXTile Dtile; Dtile.clear(); x += tm * K; @@ -1179,8 +1163,8 @@ METAL_FUNC void qmm_n_nax_tgp_impl( STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; @@ -1534,16 +1518,13 @@ template < scales += transpose ? y_col_long * K_g : y_col / group_size; biases += transpose ? y_col_long * K_g : y_col / group_size; - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; const short tm = SM * (simd_group_id / WN); const short tn = SN * (simd_group_id % WN); @@ -1561,10 +1542,6 @@ template < using AccumType = float; - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - // Do as many matmuls as necessary uint32_t index; short offset; @@ -1585,8 +1562,7 @@ template < } threadgroup_barrier(mem_flags::mem_none); - NAXTile Dtile; - + NAXTile Dtile; Dtile.clear(); const device T* xn = x + tm * K; @@ -1616,8 +1592,8 @@ template < STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; @@ -1654,8 +1630,8 @@ template < STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; volatile int compiler_barrier; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index 6abee21fb5..a696462707 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -125,38 +125,36 @@ template < make_uniform(params->scale) * make_uniform(1.44269504089f); // Prepare MMA tiles - constexpr short UQ = 16; - constexpr short UD = 32; + constexpr short kU = 16; constexpr int kNWarps = WM * WN; static_assert( - BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0, + BQ >= (kNWarps * kU) && BQ % (kNWarps * kU) == 0, "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); // Q seq frags per warp - constexpr int TQ = BQ / (kNWarps * UQ); + constexpr int TQ = BQ / (kNWarps * kU); // HeadDim frags (all warps load the same frags) - constexpr int TD = BD / UD; + constexpr int TD = BD / kU; + // KV seq frags per warp + constexpr short TK = BK / kU; static_assert(TQ == 1, "Check TQ"); - - using OSubTile = NAXSubTile; - NAXTile Otile; + using otile_t = NAXTile; + otile_t Otile; Otile.clear(); // Prepare mma tile offsets - const short2 simd_coord = OSubTile::NAXFrag_t::get_coord(); + const short tm = kU * TQ * simd_group_id; + Q += tm * int(params->Q_strides[2]); + + const short2 simd_coord = otile_t::NAXFrag_t::get_coord(); const short sm = simd_coord.y; const short sn = simd_coord.x; - const short tm = UQ * TQ * simd_group_id; - - Q += (tm + sm) * int(params->Q_strides[2]) + sn; - K += sm * int(params->K_strides[2]) + sn; - V += sm * int(params->V_strides[2]) + sn; // Init row reduction variables - constexpr short kRowsPT = decltype(Otile)::kRowsPerThread; + constexpr short kRowsPT = otile_t::kRowsPerThread; metal::vec max_score; metal::vec sum_score{0}; @@ -187,40 +185,30 @@ template < // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); const bool is_last_q = is_last_bq; - const short lim_rows_q = params->qL_rem - (tm + sm); - const short lim_rows_k = params->kL_rem - sm; + const short lim_rows_q = params->qL_rem - tm; + const short lim_rows_k = params->kL_rem; // Loop over KV seq length for (int kb = 0; kb < kb_lim; kb++) { const int is_last_k = (kb == (params->NK_aligned)); // Do S = Q @ K.T - constexpr short UDs = 16; - constexpr short UKs = 32; - - constexpr short TDs = BD / UDs; - constexpr short TKs = BK / UKs; - - using SSubTile = NAXSubTile; - using QSubTile = NAXSubTile; - using KSubTile = NAXSubTile; - - NAXTile Stile; + using stile_t = NAXTile; + stile_t Stile; Stile.clear(); STEEL_PRAGMA_UNROLL for (short iq = 0; iq < TQ; iq++) { STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TKs; ik++) { + for (short ik = 0; ik < TK; ik += 2) { STEEL_PRAGMA_UNROLL - for (short id = 0; id < TDs; id++) { - NAXTile Qtile; - NAXTile Ktile; + for (short id = 0; id < TD; id++) { + NAXTile Qtile; + NAXTile Ktile; - const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs; - const int K_load_off = - ik * UKs * int(params->K_strides[2]) + id * UDs; + const int Q_load_off = iq * kU * int(params->Q_strides[2]) + id * kU; + const int K_load_off = ik * kU * int(params->K_strides[2]) + id * kU; if (!align_Q && is_last_q) { // Qtile.load_rows( @@ -230,7 +218,7 @@ template < Qtile.load_safe( Q + Q_load_off, int(params->Q_strides[2]), - short2(BD, lim_rows_q - iq * UQ)); + short2(BD, lim_rows_q - iq * kU)); } else { Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); } @@ -243,16 +231,18 @@ template < Ktile.load_safe( K + K_load_off, int(params->K_strides[2]), - short2(BD, lim_rows_k - ik * UKs)); + short2(BD, lim_rows_k - ik * kU)); } else { Ktile.load(K + K_load_off, int(params->K_strides[2])); } - subtile_matmad_nax( - Stile.subtile_at(iq, ik), - Qtile.subtile_at(0, 0), + stile_t::NAXFrag_t::mma( + Stile.frag_at(iq, ik), + Stile.frag_at(iq, ik + 1), + Qtile.frag_at(0, 0), metal::false_type{}, - Ktile.subtile_at(0, 0), + Ktile.frag_at(0, 0), + Ktile.frag_at(1, 0), metal::true_type{}); } } @@ -260,22 +250,10 @@ template < // Scale S STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + for (short ii = 0; ii < stile_t::kElemsPerTile; ii++) { Stile.elems()[ii] *= float(scale2); } - // Scale and Retile S - constexpr short UK = 16; - constexpr short TK = BK / UK; - using PSubTile = NAXSubTile; - - NAXTile Ptile; - - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { - Ptile.elems()[ii] = Stile.elems()[ii]; - } - // Mask out length sequence if (!align_K && is_last_k) { constexpr auto neg_inf = Limits::finite_min; @@ -284,15 +262,15 @@ template < for (short iq = 0; iq < TQ; iq++) { STEEL_PRAGMA_UNROLL for (short ik = 0; ik < TK; ik++) { - const short col_pos = sn + ik * UK; + const short col_pos = ik * kU + sn; - thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + thread auto& fg = Stile.frag_at(iq, ik); STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + for (short ii = 0; ii < stile_t::kFragThrRows; ii++) { STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { - const auto loc = ii * PSubTile::kFragThrCols + jj; + for (short jj = 0; jj < stile_t::kFragThrCols; jj++) { + const auto loc = ii * stile_t::kFragThrCols + jj; fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc]; } } @@ -311,18 +289,18 @@ template < for (short iq = 0; iq < TQ; iq++) { STEEL_PRAGMA_UNROLL for (short ik = 0; ik < TK; ik++) { - const short row_pos = base_row + iq * UQ; - const short col_pos = base_col + ik * UK; + const short row_pos = base_row + iq * kU; + const short col_pos = base_col + ik * kU; - thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + thread auto& fg = Stile.frag_at(iq, ik); STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + for (short ii = 0; ii < stile_t::kFragThrRows; ii++) { STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { - const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm; + for (short jj = 0; jj < stile_t::kFragThrCols; jj++) { + const auto r = row_pos + ii * stile_t::kFragRowsJump + sm; const auto c = col_pos + jj + sn; - const auto loc = ii * PSubTile::kFragThrCols + jj; + const auto loc = ii * stile_t::kFragThrCols + jj; fg[loc] = (r < c) ? neg_inf : fg[loc]; } } @@ -339,17 +317,19 @@ template < constexpr bool is_bool = is_same_v; using melem_t = typename metal::conditional_t; - using MSubTile = NAXSubTile; + using mtile_t = NAXTile; + using mfrag_t = typename mtile_t::frag_type; STEEL_PRAGMA_UNROLL for (short iq = 0; iq < TQ; iq++) { STEEL_PRAGMA_UNROLL for (short ik = 0; ik < TK; ik++) { - const short row_pos = base_row + iq * UQ + sm; - const short col_pos = base_col + ik * UK + sn; + const short row_pos = base_row + iq * kU; + const short col_pos = base_col + ik * kU; - MSubTile mfrag; - mfrag.load_safe( + mfrag_t mfrag; + mtile_t::NAXFrag_t::load_safe( + mfrag, mask, int64_t(mask_params->M_strides[2]), Int<1>{}, @@ -358,14 +338,14 @@ template < row_pos, col_pos); - thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + thread auto& fg = Stile.frag_at(iq, ik); STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) { + for (short jj = 0; jj < mtile_t::kElemsPerFrag; jj++) { if constexpr (is_bool) { - fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf; + fg[jj] = mfrag[jj] ? fg[jj] : neg_inf; } else { - fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]); + fg[jj] += M_LOG2E_F * AccumType(mfrag[jj]); } } } @@ -383,10 +363,10 @@ template < } // Row max - Ptile.template row_reduce(new_max); + Stile.template row_reduce(new_max); // exp(Si - rowmax(Si)) - Ptile.template row_bin_op(new_max); + Stile.template row_bin_op(new_max); // Factor exp(rowmax(Si) - rowmax(Si-1)) STEEL_PRAGMA_UNROLL @@ -401,7 +381,7 @@ template < sum_score[i] = sum_score[i] * factor[i]; } - Ptile.template row_reduce(sum_score); + Stile.template row_reduce(sum_score); // Update O Otile.template row_bin_op(factor); @@ -412,19 +392,18 @@ template < STEEL_PRAGMA_UNROLL for (short iq = 0; iq < TQ; iq++) { STEEL_PRAGMA_UNROLL - for (short id = 0; id < TD; id++) { + for (short id = 0; id < TD; id += 2) { if constexpr (BD == 128) { - if (id == 2) { + if (id == 4) { threadgroup_barrier(mem_flags::mem_none); } } STEEL_PRAGMA_UNROLL for (short ik = 0; ik < TK; ik++) { - using VSubTile = NAXSubTile; - NAXTile Vtile; + NAXTile Vtile; - const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD; + const int V_load_off = ik * kU * int(params->V_strides[2]) + id * kU; if (!align_K && is_last_k) { // Vtile.load_rows( @@ -434,17 +413,19 @@ template < Vtile.load_safe( V + V_load_off, int(params->V_strides[2]), - short2(BD, lim_rows_k - ik * UK)); + short2(BD, lim_rows_k - ik * kU)); } else { Vtile.load(V + V_load_off, int(params->V_strides[2])); } - subtile_matmad_nax( - Otile.subtile_at(iq, id), - Ptile.subtile_at(iq, ik), - metal::bool_constant{}, - Vtile.subtile_at(0, 0), - metal::bool_constant{}); + otile_t::NAXFrag_t::mma( + Otile.frag_at(iq, id), + Otile.frag_at(iq, id + 1), + Stile.frag_at(iq, ik), + metal::false_type{}, + Vtile.frag_at(0, 0), + Vtile.frag_at(0, 1), + metal::false_type{}); } } } @@ -467,7 +448,7 @@ template < Otile.template row_bin_op(rcp); // Store results - O += (tm + sm) * int(params->O_strides[2]) + sn; + O += tm * int(params->O_strides[2]); if (!align_Q && is_last_q) { if (lim_rows_q <= 0) diff --git a/mlx/backend/metal/kernels/steel/attn/nax.h b/mlx/backend/metal/kernels/steel/attn/nax.h index c8f3ea5ef1..21499c781d 100644 --- a/mlx/backend/metal/kernels/steel/attn/nax.h +++ b/mlx/backend/metal/kernels/steel/attn/nax.h @@ -72,11 +72,13 @@ struct BaseNAXFrag { StrY str_y, OffX off_x = {}, OffY off_y = {}) { - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); + src += sc.y * str_x + sc.x * str_y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL @@ -109,13 +111,16 @@ struct BaseNAXFrag { LimX lim_x, OffX off_x = {}, OffY off_y = {}) { - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; - if (r < lim_x) { + if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { @@ -153,14 +158,18 @@ struct BaseNAXFrag { LimY lim_y, OffX off_x = {}, OffY off_y = {}) { - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); + src += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + auto ly = lim_y - sc.x; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { + if (r < lx && (c + j) < ly) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } else { @@ -186,11 +195,13 @@ struct BaseNAXFrag { OffY off_y = {}) { using U = pointer_element_t; - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL @@ -225,13 +236,16 @@ struct BaseNAXFrag { OffY off_y = {}) { using U = pointer_element_t; - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; - if (r < lim_x) { + if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { @@ -268,15 +282,19 @@ struct BaseNAXFrag { OffY off_y = {}) { using U = pointer_element_t; - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + auto ly = lim_y - sc.x; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { + if (r < lx && (c + j) < ly) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } @@ -308,7 +326,7 @@ struct BaseNAXFrag { OffY off_y = Int<0>{}) { using U = pointer_element_t; - const short2 sc = short2{0, 0}; // get_coord(); + const short2 sc = get_coord(); const_for_loop<0, kElemRows, 1>([&](auto idx_row) { const auto r = off_x + idx_row * Int{}; @@ -362,627 +380,413 @@ struct BaseNAXFrag { } } } -}; - -template < - typename T, - short kRows_, - short kCols_, - typename NAXFrag_ = BaseNAXFrag> -struct NAXSubTile { - using NAXFrag_t = NAXFrag_; - STEEL_CONST short kRows = kRows_; - STEEL_CONST short kCols = kCols_; - - STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; - STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; - STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; - - STEEL_CONST short kSubTileRows = kRows / kFragRows; - STEEL_CONST short kSubTileCols = kCols / kFragCols; - - STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; - STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; - - STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; - STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; - - STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; - STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; - STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; - - using frag_type = typename NAXFrag_t::template dtype_frag_t; - - frag_type val_frags[kNumFrags]; - METAL_FUNC constexpr void clear() { + template < + typename CType, + typename AType, + typename BType, + bool transpose_a = false, + bool transpose_b = false> + METAL_FUNC static constexpr void mma( + thread dtype_frag_t& Cn0, + thread dtype_frag_t& Cn1, + const thread dtype_frag_t& A, + metal::bool_constant, + const thread dtype_frag_t& Bn0, + const thread dtype_frag_t& Bn1, + metal::bool_constant) { + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + 16, + 32, + 16, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op + .template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers STEEL_PRAGMA_UNROLL - for (short i = 0; i < kNumFrags; ++i) { - val_frags[i] = frag_type(0); + for (short i = 0; i < kElemsPerFrag; i++) { + ct_a[i] = A[i]; } - } - - METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC constexpr const thread frag_type& frag_at( - const short i, - const short j) const { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr thread frag_type& frag_at() { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr const thread frag_type& frag_at() const { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC thread T* elems() { - return reinterpret_cast(val_frags); - } - - METAL_FUNC const thread T* elems() const { - return reinterpret_cast(val_frags); - } - template - METAL_FUNC void row_reduce(thread metal::vec& vals) const { - thread T* vptr = (thread T*)(&vals); + // Load B into right operand registers STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_reduce( - frag_at(i, j), &vptr[i * kFragThrRows]); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_b[i] = Bn0[i]; + ct_b[kElemsPerFrag + i] = Bn1[i]; } - } - template - METAL_FUNC void row_bin_op(thread metal::vec& vals) { - thread T* vptr = (thread T*)(&vals); + // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_bin_op( - frag_at(i, j), &vptr[i * kFragThrRows]); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_c[i] = Cn0[i]; + ct_c[kElemsPerFrag + i] = Cn1[i]; } - } - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load( - SrcPtrType src, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load( - frag_at(i, j), - src, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } - } - } + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store( - DstPtrType dst, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) const { + // Copy out results STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store( - frag_at(i, j), - dst, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } + for (short i = 0; i < kElemsPerFrag; i++) { + Cn0[i] = ct_c[i]; + Cn1[i] = ct_c[kElemsPerFrag + i]; } } template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_rows( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { + typename CType, + typename AType, + typename BType, + bool transpose_a = false, + bool transpose_b = false> + METAL_FUNC static constexpr void mma( + thread dtype_frag_t& Cm0, + thread dtype_frag_t& Cm1, + const thread dtype_frag_t& Am0, + const thread dtype_frag_t& Am1, + metal::bool_constant, + const thread dtype_frag_t& B, + metal::bool_constant) { + // Create Matmul descriptor + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + 16, + 32, + 16, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op + .template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_rows( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_a[i] = Am0[i]; + ct_a[kElemsPerFrag + i] = Am1[i]; } - } - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_safe( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { + // Load B into right operand registers STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_safe( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_b[i] = B[i]; } - } - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_rows( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) const { + // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_c[i] = Cm0[i]; + ct_c[kElemsPerFrag + i] = Cm1[i]; } - } - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_safe( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) const { + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); + + // Copy out results STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + Cm0[i] = ct_c[i]; + Cm1[i] = ct_c[kElemsPerFrag + i]; } } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_slice( - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) const { - const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { - const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { - NAXFrag_t::store_slice( - frag_at(), - dst, - str_x, - str_y, - start_x, - stop_x, - start_y, - stop_y, - off_x + idx_row * Int{}, - off_y + idx_col * Int{}); - }); - }); - } }; template < - short RC, - short CC, - short RA, - short CA, - short RB, - short CB, - typename CType, - typename AType, - typename BType, - bool transpose_a, - bool transpose_b, - typename NAXFrag_t = BaseNAXFrag> -METAL_FUNC void subtile_matmad_nax( - thread NAXSubTile& C, - thread NAXSubTile& A, - metal::bool_constant, - thread NAXSubTile& B, - metal::bool_constant) { - // Static checks - constexpr short FMa = transpose_a ? CA : RA; - constexpr short FMc = RC; - static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); - - constexpr short FNb = transpose_b ? RB : CB; - constexpr short FNc = CC; - static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); - - constexpr short FKa = transpose_a ? RA : CA; - constexpr short FKb = transpose_b ? CB : RB; - static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); - - constexpr short FM = FMc; - constexpr short FN = FNc; - constexpr short FK = FKa; - - constexpr int TM = FM / 16; - constexpr int TN = FN / 16; - constexpr int TK = FK / 16; - - constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( - FM, - FN, - FK, - transpose_a, - transpose_b, - true, - mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); - - mpp::tensor_ops::matmul2d gemm_op; - - auto ct_a = - gemm_op.template get_left_input_cooperative_tensor(); - auto ct_b = - gemm_op - .template get_right_input_cooperative_tensor(); - auto ct_c = gemm_op.template get_destination_cooperative_tensor< - decltype(ct_a), - decltype(ct_b), - CType>(); - - STEEL_PRAGMA_UNROLL - for (short mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_a ? kk : mm; - const short fj = transpose_a ? mm : kk; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; - } - } - } - - STEEL_PRAGMA_UNROLL - for (short nn = 0; nn < TN; nn++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_b ? nn : kk; - const short fj = transpose_b ? kk : nn; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; - } - } - } - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - ct_c[i] = C.elems()[i]; - } - - gemm_op.run(ct_a, ct_b, ct_c); - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - C.elems()[i] = ct_c[i]; - } -} - -template + typename T, + short kTileRows_, + short kTileCols_, + class NAXFrag_ = BaseNAXFrag> struct NAXTile { - using NAXSubTile_t = NAXSubTile_; + using NAXFrag_t = NAXFrag_; using elem_type = T; - STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; - STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; - STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; STEEL_CONST short kTileRows = kTileRows_; STEEL_CONST short kTileCols = kTileCols_; - STEEL_CONST short kRows = kTileRows * kSubTileRows; - STEEL_CONST short kCols = kTileCols * kSubTileCols; + STEEL_CONST short kRows = kTileRows * kFragRows; + STEEL_CONST short kCols = kTileCols * kFragCols; + + STEEL_CONST short kNumFrags = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kNumFrags * kElemsPerFrag; - STEEL_CONST short kSubTiles = kTileRows * kTileCols; - STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; - STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + STEEL_CONST short kRowsPerThread = kTileRows * NAXFrag_t::kElemRows; + STEEL_CONST short kColsPerThread = kTileCols * NAXFrag_t::kElemCols; - STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + typedef typename NAXFrag_t::template dtype_frag_t frag_type; - NAXSubTile_t val_subtiles[kSubTiles]; + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; METAL_FUNC NAXTile() thread {} METAL_FUNC constexpr void clear() { STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTiles; ++i) { - val_subtiles[i].clear(); + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); } } - METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( - const short i, - const short j) { - return val_subtiles[i * kTileCols + j]; + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; } - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + METAL_FUNC constexpr const thread frag_type& frag_at( const short i, const short j) const { - return val_subtiles[i * kTileCols + j]; + return val_frags[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kTileCols + j]; } template - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { - return val_subtiles[i * kTileCols + j]; + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& + frag_at(const short i, const short j, metal::bool_constant) { + if constexpr (transpose) { + return frag_at(j, i); + } else { + return frag_at(i, j); + } + } + + template + METAL_FUNC constexpr const thread frag_type& + frag_at(const short i, const short j, metal::bool_constant) const { + if constexpr (transpose) { + return frag_at(j, i); + } else { + return frag_at(i, j); + } + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + if constexpr (transpose) { + return frag_at(); + } else { + return frag_at(); + } + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + if constexpr (transpose) { + return frag_at(); + } else { + return frag_at(); + } } METAL_FUNC thread elem_type* elems() { - return reinterpret_cast(val_subtiles[0].elems()); + return reinterpret_cast(val_frags); } METAL_FUNC const thread elem_type* elems() const { - return reinterpret_cast(val_subtiles[0].elems()); + return reinterpret_cast(val_frags); } template METAL_FUNC void row_reduce(thread metal::vec& vals) const { - auto sub_rows = (thread metal::vec*)(&vals); + auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_reduce(sub_rows[i]); + NAXFrag_t::template row_reduce( + frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void row_bin_op(thread metal::vec& vals) { - auto sub_rows = (thread metal::vec*)(&vals); + auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_bin_op(sub_rows[i]); + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void load(const threadgroup U* src) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load( + frag_at(), src, Int{}, Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store(threadgroup U* dst) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store( + frag_at(), dst, Int{}, Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void load(const device U* src, const int ld) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( - &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); - } - } + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load( + frag_at(), + src, + ld, + Int<1>{}, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store(device U* dst, const int ld) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( - &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); - } - } + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store( + frag_at(), + dst, + ld, + Int<1>{}, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load_rows( + frag_at(), + src, + ld, + Int<1>{}, + n_rows, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void load_safe(const device U* src, const int ld, const short2 src_tile_dims) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_safe( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load_safe( + frag_at(), src, ld, Int<1>{}, src_tile_dims.y, src_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template - METAL_FUNC void - load_rows(const device U* src, const int ld, const short n_rows) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_rows( - &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_rows( + frag_at(), + dst, ld, Int<1>{}, - n_rows - i * kSubTileRows); - } - } + n_rows, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_safe( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_safe( + frag_at(), dst, ld, Int<1>{}, dst_tile_dims.y, dst_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) - const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_rows( - &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], - ld, - Int<1>{}, - n_rows - i * kSubTileRows); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template @@ -993,7 +797,8 @@ struct NAXTile { const short2 stop) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { - subtile_at().store_slice( + NAXFrag_t::store_slice( + frag_at(), dst, ld, Int<1>{}, @@ -1001,8 +806,8 @@ struct NAXTile { stop.y, start.x, stop.x, - idx_row * Int{}, - idx_col * Int{}); + idx_row * Int{}, + idx_col * Int{}); }); }); } @@ -1022,51 +827,54 @@ METAL_FUNC void tile_matmad_nax( metal::bool_constant) { // Static checks constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; - constexpr short TMc = CTile::kTileRows; - static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); - - constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; - constexpr short FMc = CTile::kSubTileRows; - static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + constexpr short TM = CTile::kTileRows; + static_assert(TMa == TM, "MXU tile matmul: M dimensions do not match"); constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; - constexpr short TNc = CTile::kTileCols; - static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); - - constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; - constexpr short FNc = CTile::kSubTileCols; - static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + constexpr short TN = CTile::kTileCols; + static_assert(TNb == TN, "MXU tile matmul: N dimensions do not match"); constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; - constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; - static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); - - constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; - constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; - static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + constexpr short TK = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TK, "MXU tile matmul: K dimensions do not match"); - constexpr short TM = TMc; - constexpr short TN = TNc; - constexpr short TK = TKa; + constexpr auto ta = metal::bool_constant{}; + constexpr auto tb = metal::bool_constant{}; - // Do matmul here - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; ++i) { + if constexpr (TN == 1 && TM % 2 == 0) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; ++j) { + for (short mm = 0; mm < TM; mm += 2) { STEEL_PRAGMA_UNROLL - for (short k = 0; k < TK; ++k) { - const short ra = transpose_a ? k : i; - const short ca = transpose_a ? i : k; - const short rb = transpose_b ? j : k; - const short cb = transpose_b ? k : j; - - subtile_matmad_nax( - C.subtile_at(i, j), - A.subtile_at(ra, ca), - metal::bool_constant{}, - B.subtile_at(rb, cb), - metal::bool_constant{}); + for (short nn = 0; nn < TN; ++nn) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; ++kk) { + CTile::NAXFrag_t::mma( + C.frag_at(mm, nn), + C.frag_at(mm + 1, nn), + A.frag_at(mm, kk, ta), + A.frag_at(mm + 1, kk, ta), + metal::bool_constant{}, + B.frag_at(kk, nn, tb), + metal::bool_constant{}); + } + } + } + } else if constexpr (TN % 2 == 0) { + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; ++mm) { + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn += 2) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; ++kk) { + CTile::NAXFrag_t::mma( + C.frag_at(mm, nn), + C.frag_at(mm, nn + 1), + A.frag_at(mm, kk, ta), + metal::bool_constant{}, + B.frag_at(kk, nn, tb), + B.frag_at(kk, nn + 1, tb), + metal::bool_constant{}); + } } } } diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h index 26645cfb5f..40066be7a4 100644 --- a/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h @@ -22,9 +22,6 @@ template < bool kAlignedM, bool kAlignedN, bool kAlignedK, - short UM, - short UN, - short UK, typename AccumType = float> auto gemm_loop( const device T* A, @@ -35,9 +32,9 @@ auto gemm_loop( int gemm_k_iterations_aligned, const short sgp_sm, const short sgp_sn) { - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; constexpr int RA = transpose_a ? TK : TM; constexpr int CA = transpose_a ? TM : TK; @@ -45,13 +42,7 @@ auto gemm_loop( constexpr int RB = transpose_b ? TN : TK; constexpr int CB = transpose_b ? TK : TN; - using DSubTile = NAXSubTile; - using ASubTile = - NAXSubTile; - using BSubTile = - NAXSubTile; - - NAXTile Dtile; + NAXTile Dtile; Dtile.clear(); int gemm_k_iterations_ = gemm_k_iterations_aligned; @@ -62,8 +53,8 @@ auto gemm_loop( STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; + NAXTile Atile; + NAXTile Btile; const int k = kk1; volatile int compiler_barrier; @@ -108,46 +99,29 @@ auto gemm_loop( STEEL_PRAGMA_NO_UNROLL for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - STEEL_PRAGMA_UNROLL - for (int mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (int nn = 0; nn < TN; nn++) { - STEEL_PRAGMA_UNROLL - for (int kk = 0; kk < TK; kk++) { - const int m = mm * UM; - const int n = nn * UN; - const int k = kk1 + kk * UK; - const short psk = max(0, rem_bk - k); - - const int A_offset = transpose_a ? (m + k * lda) : (m * lda + k); - const int B_offset = transpose_b ? (k + n * ldb) : (k * ldb + n); - - { - const short psm = kAlignedM ? SM : max(0, sgp_sm - m); - const short rmax = transpose_a ? psk : psm; - const short cmax = transpose_a ? psm : psk; - Atile.load_safe(A + A_offset, lda, short2(cmax, rmax)); - } - - { - const short psn = kAlignedN ? SN : max(0, sgp_sn - n); - const short rmax = transpose_b ? psn : psk; - const short cmax = transpose_b ? psk : psn; - Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax)); - } - - subtile_matmad_nax( - Dtile.subtile_at(mm, nn), - Atile.subtile_at(0, 0), - metal::bool_constant{}, - Btile.subtile_at(0, 0), - metal::bool_constant{}); - } - } - } + NAXTile Atile; + NAXTile Btile; + + const int k = kk1; + const short psk = max(0, rem_bk - k); + + const short2 Aklims = + transpose_a ? short2(sgp_sm, psk) : short2(psk, sgp_sm); + const short2 Bklims = + transpose_b ? short2(psk, sgp_sn) : short2(sgp_sn, psk); + + const int A_offset = transpose_a ? k * lda : k; + const int B_offset = transpose_b ? k : k * ldb; + + Atile.load_safe(A + A_offset, lda, Aklims); + Btile.load_safe(B + B_offset, ldb, Bklims); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h index 4ff9260650..9b25f0c9c6 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h @@ -15,7 +15,7 @@ constant bool align_K [[function_constant(202)]]; template < bool kAlignedM, bool kAlignedN, - typename NAXTile_t, + class NAXTile_t, typename T> void gemm_epilogue( thread NAXTile_t& Dtile, @@ -27,37 +27,42 @@ void gemm_epilogue( (void)params; - constexpr short UM = NAXTile_t::kSubTileRows; - constexpr short UN = NAXTile_t::kSubTileCols; - using CSubTile = NAXSubTile; - using V = typename NAXTile_t::elem_type; constexpr short TM = NAXTile_t::kTileRows; constexpr short TN = NAXTile_t::kTileCols; - constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile; + constexpr short kElemsPerFrag = NAXTile_t::kElemsPerFrag; + + using CFrag = typename NAXTile_t::NAXFrag_t; + using cfrag_t = typename CFrag::template dtype_frag_t; STEEL_PRAGMA_UNROLL for (short mm = 0; mm < TM; mm++) { STEEL_PRAGMA_UNROLL for (short nn = 0; nn < TN; nn++) { - const short m = mm * UM; - const short n = nn * UN; + const short m = mm * CFrag::kFragRows; + const short n = nn * CFrag::kFragCols; - CSubTile CTile; + cfrag_t celems; if constexpr (kAlignedM && kAlignedN) { - CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n); + CFrag::load(celems, C, addmm_params->ldc, addmm_params->fdc, m, n); } else { - CTile.load_safe( - C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n); + CFrag::load_safe( + celems, + C, + addmm_params->ldc, + addmm_params->fdc, + sgp_sm, + sgp_sn, + m, + n); } - auto delems = Dtile.subtile_at(mm, nn).elems(); - auto celems = CTile.elems(); + auto delems = Dtile.frag_at(mm, nn); STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemsPerSubTile; i++) { + for (short i = 0; i < kElemsPerFrag; i++) { if (do_axpby) { delems[i] = addmm_params->alpha * delems[i] + addmm_params->beta * static_cast(celems[i]); @@ -144,15 +149,12 @@ template < C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; } - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; const short tm = SM * (simd_group_id / WN); const short tn = SN * (simd_group_id % WN); @@ -175,8 +177,7 @@ template < C += tm * addmm_params->ldc + tn * addmm_params->fdc; } - using DSubTile = NAXSubTile; - NAXTile Dtile; + NAXTile Dtile; dispatch_bool(align_K, [&](auto kAlignedK) { dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { @@ -192,9 +193,6 @@ template < kAlignedM.value, kAlignedN.value, kAlignedK.value, - UM, - UN, - UK, AccumType>( A, B, diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h index 67cd737846..7fec270609 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h @@ -25,14 +25,11 @@ gather_mm_rhs_nax( const constant GEMMParams* params [[buffer(4)]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]]) { - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; if (params->tiles_n <= static_cast(tid.x) || params->tiles_m <= static_cast(tid.y)) { @@ -88,13 +85,12 @@ gather_mm_rhs_nax( } threadgroup_barrier(mem_flags::mem_none); - using DSubTile = NAXSubTile; - NAXTile Ctile; + NAXTile Ctile; dispatch_bool(align_K, [&](auto kAlignedK) { dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { - auto do_gemm = gemm_loop< + auto do_gemm = gemm_loop< // Matmul for partial BM, full BN and full K T, SM, SN, @@ -105,9 +101,6 @@ gather_mm_rhs_nax( kAlignedM.value, kAlignedN.value, kAlignedK.value, - UM, - UN, - UK, AccumType>; Ctile = do_gemm( A, diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h index 1b6b828084..982c770c67 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h @@ -72,15 +72,12 @@ template < (c_row_long * params->ldc + c_col_long); // NAX tile configuration - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; constexpr short SM = BM / WM; constexpr short SN = BN / WN; constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; // Calculate simdgroup offsets and alignment const short tm = SM * (simd_group_id / WN); @@ -100,8 +97,7 @@ template < B += transpose_b ? (tn * params->ldb) : tn; C += tm * params->ldc + tn; - using DSubTile = NAXSubTile; - NAXTile Dtile; + NAXTile Dtile; // gemm_loop through the partition // Check K-alignment at runtime (partition-specific) @@ -123,9 +119,6 @@ template < kAlignedM.value, kAlignedN.value, kAlignedK.value, - UM, - UN, - UK, AccumType>( A, B, diff --git a/mlx/backend/metal/kernels/steel/gemm/nax.h b/mlx/backend/metal/kernels/steel/gemm/nax.h index 5839176c28..21499c781d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -7,7 +7,6 @@ #include #include "mlx/backend/metal/kernels/steel/defines.h" -#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" #include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" #include @@ -74,10 +73,12 @@ struct BaseNAXFrag { OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); + src += sc.y * str_x + sc.x * str_y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL @@ -111,12 +112,15 @@ struct BaseNAXFrag { OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; - if (r < lim_x) { + if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { @@ -155,13 +159,17 @@ struct BaseNAXFrag { OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); + src += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + auto ly = lim_y - sc.x; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { + if (r < lx && (c + j) < ly) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } else { @@ -188,10 +196,12 @@ struct BaseNAXFrag { using U = pointer_element_t; const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL @@ -227,12 +237,15 @@ struct BaseNAXFrag { using U = pointer_element_t; const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; - if (r < lim_x) { + if (r < lx) { if constexpr (metal::is_same_v>) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { @@ -270,14 +283,18 @@ struct BaseNAXFrag { using U = pointer_element_t; const short2 sc = get_coord(); + dst += sc.y * str_x + sc.x * str_y; + auto lx = lim_x - sc.y; + auto ly = lim_y - sc.x; + STEEL_PRAGMA_UNROLL for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; + const auto r = off_x + i * kElemRowsJump; + const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { + if (r < lx && (c + j) < ly) { dst[r * str_x + (c + j) * str_y] = static_cast(src[i * kElemCols + j]); } @@ -363,634 +380,413 @@ struct BaseNAXFrag { } } } -}; - -template < - typename T, - short kRows_, - short kCols_, - typename NAXFrag_t = BaseNAXFrag> -struct NAXSubTile { - STEEL_CONST short kRows = kRows_; - STEEL_CONST short kCols = kCols_; - - STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; - STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; - STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; - - STEEL_CONST short kSubTileRows = kRows / kFragRows; - STEEL_CONST short kSubTileCols = kCols / kFragCols; - - STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; - STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; - - STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; - STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; - - STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; - STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; - STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; - - using frag_type = typename NAXFrag_t::template dtype_frag_t; - - frag_type val_frags[kNumFrags]; - METAL_FUNC constexpr void clear() { + template < + typename CType, + typename AType, + typename BType, + bool transpose_a = false, + bool transpose_b = false> + METAL_FUNC static constexpr void mma( + thread dtype_frag_t& Cn0, + thread dtype_frag_t& Cn1, + const thread dtype_frag_t& A, + metal::bool_constant, + const thread dtype_frag_t& Bn0, + const thread dtype_frag_t& Bn1, + metal::bool_constant) { + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + 16, + 32, + 16, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op + .template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers STEEL_PRAGMA_UNROLL - for (short i = 0; i < kNumFrags; ++i) { - val_frags[i] = frag_type(0); + for (short i = 0; i < kElemsPerFrag; i++) { + ct_a[i] = A[i]; } - } - - METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC constexpr const thread frag_type& frag_at( - const short i, - const short j) const { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr thread frag_type& frag_at() { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr const thread frag_type& frag_at() const { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC thread T* elems() { - return reinterpret_cast(val_frags); - } - - METAL_FUNC const thread T* elems() const { - return reinterpret_cast(val_frags); - } - template - METAL_FUNC void row_reduce(thread metal::vec& vals) const { + // Load B into right operand registers STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_reduce( - frag_at(i, j), &vals[i * kFragThrRows]); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_b[i] = Bn0[i]; + ct_b[kElemsPerFrag + i] = Bn1[i]; } - } - template - METAL_FUNC void row_bin_op(thread metal::vec& vals) { + // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_bin_op( - frag_at(i, j), &vals[i * kFragThrRows]); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_c[i] = Cn0[i]; + ct_c[kElemsPerFrag + i] = Cn1[i]; } - } - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load( - SrcPtrType src, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load( - frag_at(i, j), - src, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } - } - } + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store( - DstPtrType dst, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) const { + // Copy out results STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store( - frag_at(i, j), - dst, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } + for (short i = 0; i < kElemsPerFrag; i++) { + Cn0[i] = ct_c[i]; + Cn1[i] = ct_c[kElemsPerFrag + i]; } } template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_rows( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { + typename CType, + typename AType, + typename BType, + bool transpose_a = false, + bool transpose_b = false> + METAL_FUNC static constexpr void mma( + thread dtype_frag_t& Cm0, + thread dtype_frag_t& Cm1, + const thread dtype_frag_t& Am0, + const thread dtype_frag_t& Am1, + metal::bool_constant, + const thread dtype_frag_t& B, + metal::bool_constant) { + // Create Matmul descriptor + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + 16, + 32, + 16, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op + .template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_rows( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_a[i] = Am0[i]; + ct_a[kElemsPerFrag + i] = Am1[i]; } - } - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_safe( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { + // Load B into right operand registers STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_safe( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_b[i] = B[i]; } - } - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_safe( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) const { + // Load C into output registers (op handles accumulation) STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + ct_c[i] = Cm0[i]; + ct_c[kElemsPerFrag + i] = Cm1[i]; } - } - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_rows( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) const { + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); + + // Copy out results STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } + for (short i = 0; i < kElemsPerFrag; i++) { + Cm0[i] = ct_c[i]; + Cm1[i] = ct_c[kElemsPerFrag + i]; } } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_slice( - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) const { - const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { - const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { - NAXFrag_t::store_slice( - frag_at(), - dst, - str_x, - str_y, - start_x, - stop_x, - start_y, - stop_y, - off_x + idx_row * Int{}, - off_y + idx_col * Int{}); - }); - }); - } }; template < - short RC, - short CC, - short RA, - short CA, - short RB, - short CB, - typename CType, - typename AType, - typename BType, - bool transpose_a, - bool transpose_b, - typename NAXFrag_t = BaseNAXFrag> -METAL_FUNC void subtile_matmad_nax( - thread NAXSubTile& C, - thread NAXSubTile& A, - metal::bool_constant, - thread NAXSubTile& B, - metal::bool_constant) { - // Static checks - constexpr short FMa = transpose_a ? CA : RA; - constexpr short FMc = RC; - static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); - - constexpr short FNb = transpose_b ? RB : CB; - constexpr short FNc = CC; - static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); - - constexpr short FKa = transpose_a ? RA : CA; - constexpr short FKb = transpose_b ? CB : RB; - static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); - - constexpr short FM = FMc; - constexpr short FN = FNc; - constexpr short FK = FKa; - - constexpr int TM = FM / 16; - constexpr int TN = FN / 16; - constexpr int TK = FK / 16; - - // Create Matmul descriptor - constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( - FM, - FN, - FK, - transpose_a, - transpose_b, - true, - mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); - - // Create matmul op - mpp::tensor_ops::matmul2d gemm_op; - - // Create matmul operands in registers - auto ct_a = - gemm_op.template get_left_input_cooperative_tensor(); - auto ct_b = - gemm_op - .template get_right_input_cooperative_tensor(); - - // Create matmul output in register - auto ct_c = gemm_op.template get_destination_cooperative_tensor< - decltype(ct_a), - decltype(ct_b), - CType>(); - - // Load A in to left operand registers - STEEL_PRAGMA_UNROLL - for (short mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_a ? kk : mm; - const short fj = transpose_a ? mm : kk; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; - } - } - } - - // Load B into right operand registers - STEEL_PRAGMA_UNROLL - for (short nn = 0; nn < TN; nn++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_b ? nn : kk; - const short fj = transpose_b ? kk : nn; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; - } - } - } - - // Load C into output registers (op handles accumulation) - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - ct_c[i] = C.elems()[i]; - } - - // Do matmul - gemm_op.run(ct_a, ct_b, ct_c); - - // Copy out results - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - C.elems()[i] = ct_c[i]; - } -} - -template + typename T, + short kTileRows_, + short kTileCols_, + class NAXFrag_ = BaseNAXFrag> struct NAXTile { - using NAXSubTile_t = NAXSubTile_; + using NAXFrag_t = NAXFrag_; using elem_type = T; - STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; - STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; - STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; STEEL_CONST short kTileRows = kTileRows_; STEEL_CONST short kTileCols = kTileCols_; - STEEL_CONST short kRows = kTileRows * kSubTileRows; - STEEL_CONST short kCols = kTileCols * kSubTileCols; + STEEL_CONST short kRows = kTileRows * kFragRows; + STEEL_CONST short kCols = kTileCols * kFragCols; - STEEL_CONST short kSubTiles = kTileRows * kTileCols; - STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + STEEL_CONST short kNumFrags = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kNumFrags * kElemsPerFrag; - STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; - STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + STEEL_CONST short kRowsPerThread = kTileRows * NAXFrag_t::kElemRows; + STEEL_CONST short kColsPerThread = kTileCols * NAXFrag_t::kElemCols; - NAXSubTile_t val_subtiles[kSubTiles]; + typedef typename NAXFrag_t::template dtype_frag_t frag_type; + + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; METAL_FUNC NAXTile() thread {} METAL_FUNC constexpr void clear() { STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTiles; ++i) { - val_subtiles[i].clear(); + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); } } - METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( - const short i, - const short j) { - return val_subtiles[i * kTileCols + j]; + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; } - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + METAL_FUNC constexpr const thread frag_type& frag_at( const short i, const short j) const { - return val_subtiles[i * kTileCols + j]; + return val_frags[i * kTileCols + j]; } template - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { - return val_subtiles[i * kTileCols + j]; + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& + frag_at(const short i, const short j, metal::bool_constant) { + if constexpr (transpose) { + return frag_at(j, i); + } else { + return frag_at(i, j); + } + } + + template + METAL_FUNC constexpr const thread frag_type& + frag_at(const short i, const short j, metal::bool_constant) const { + if constexpr (transpose) { + return frag_at(j, i); + } else { + return frag_at(i, j); + } + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + if constexpr (transpose) { + return frag_at(); + } else { + return frag_at(); + } + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + if constexpr (transpose) { + return frag_at(); + } else { + return frag_at(); + } } METAL_FUNC thread elem_type* elems() { - return reinterpret_cast(val_subtiles[0].elems()); + return reinterpret_cast(val_frags); } METAL_FUNC const thread elem_type* elems() const { - return reinterpret_cast(val_subtiles[0].elems()); + return reinterpret_cast(val_frags); } template METAL_FUNC void row_reduce(thread metal::vec& vals) const { - auto sub_rows = (thread metal::vec*)(&vals); + auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_reduce(sub_rows[i]); + NAXFrag_t::template row_reduce( + frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void row_bin_op(thread metal::vec& vals) { - auto sub_rows = (thread metal::vec*)(&vals); + auto vptr = (thread T*)(&vals); STEEL_PRAGMA_UNROLL for (short i = 0; i < kTileRows; ++i) { STEEL_PRAGMA_UNROLL for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_bin_op(sub_rows[i]); + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vptr[i * kFragThrRows]); } } } template METAL_FUNC void load(const threadgroup U* src) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load( + frag_at(), src, Int{}, Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store(threadgroup U* dst) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store( + frag_at(), dst, Int{}, Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void load(const device U* src, const int ld) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( - &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); - } - } + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load( + frag_at(), + src, + ld, + Int<1>{}, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store(device U* dst, const int ld) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( - &dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); - } - } + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store( + frag_at(), + dst, + ld, + Int<1>{}, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void load_rows(const device U* src, const int ld, const short n_rows) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_rows( - &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load_rows( + frag_at(), + src, ld, Int<1>{}, - n_rows - i * kSubTileRows); - } - } + n_rows, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void load_safe(const device U* src, const int ld, const short2 src_tile_dims) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_safe( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::load_safe( + frag_at(), src, ld, Int<1>{}, src_tile_dims.y, src_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_rows( - &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_rows( + frag_at(), + dst, ld, Int<1>{}, - n_rows - i * kSubTileRows); - } - } + n_rows, + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template METAL_FUNC void store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_safe( + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_safe( + frag_at(), dst, ld, Int<1>{}, dst_tile_dims.y, dst_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } + idx_row * Int{}, + idx_col * Int{}); + }); + }); } template @@ -1001,7 +797,8 @@ struct NAXTile { const short2 stop) const { const_for_loop<0, kTileRows, 1>([&](auto idx_row) { const_for_loop<0, kTileCols, 1>([&](auto idx_col) { - subtile_at().store_slice( + NAXFrag_t::store_slice( + frag_at(), dst, ld, Int<1>{}, @@ -1009,8 +806,8 @@ struct NAXTile { stop.y, start.x, stop.x, - idx_row * Int{}, - idx_col * Int{}); + idx_row * Int{}, + idx_col * Int{}); }); }); } @@ -1030,51 +827,54 @@ METAL_FUNC void tile_matmad_nax( metal::bool_constant) { // Static checks constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; - constexpr short TMc = CTile::kTileRows; - static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); - - constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; - constexpr short FMc = CTile::kSubTileRows; - static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + constexpr short TM = CTile::kTileRows; + static_assert(TMa == TM, "MXU tile matmul: M dimensions do not match"); constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; - constexpr short TNc = CTile::kTileCols; - static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); - - constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; - constexpr short FNc = CTile::kSubTileCols; - static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + constexpr short TN = CTile::kTileCols; + static_assert(TNb == TN, "MXU tile matmul: N dimensions do not match"); constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; - constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; - static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); - - constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; - constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; - static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + constexpr short TK = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TK, "MXU tile matmul: K dimensions do not match"); - constexpr short TM = TMc; - constexpr short TN = TNc; - constexpr short TK = TKa; + constexpr auto ta = metal::bool_constant{}; + constexpr auto tb = metal::bool_constant{}; - // Do matmul here - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; ++i) { + if constexpr (TN == 1 && TM % 2 == 0) { STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; ++j) { + for (short mm = 0; mm < TM; mm += 2) { STEEL_PRAGMA_UNROLL - for (short k = 0; k < TK; ++k) { - const short ra = transpose_a ? k : i; - const short ca = transpose_a ? i : k; - const short rb = transpose_b ? j : k; - const short cb = transpose_b ? k : j; - - subtile_matmad_nax( - C.subtile_at(i, j), - A.subtile_at(ra, ca), - metal::bool_constant{}, - B.subtile_at(rb, cb), - metal::bool_constant{}); + for (short nn = 0; nn < TN; ++nn) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; ++kk) { + CTile::NAXFrag_t::mma( + C.frag_at(mm, nn), + C.frag_at(mm + 1, nn), + A.frag_at(mm, kk, ta), + A.frag_at(mm + 1, kk, ta), + metal::bool_constant{}, + B.frag_at(kk, nn, tb), + metal::bool_constant{}); + } + } + } + } else if constexpr (TN % 2 == 0) { + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; ++mm) { + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn += 2) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; ++kk) { + CTile::NAXFrag_t::mma( + C.frag_at(mm, nn), + C.frag_at(mm, nn + 1), + A.frag_at(mm, kk, ta), + metal::bool_constant{}, + B.frag_at(kk, nn, tb), + B.frag_at(kk, nn + 1, tb), + metal::bool_constant{}); + } } } } From 8482629077086cad0b506ee9ce032f772fae9aab Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Mon, 16 Mar 2026 16:05:46 -0700 Subject: [PATCH 2/3] Fix attention row loading --- .../steel/attn/kernels/steel_attention_nax.h | 36 ++++++++----------- mlx/backend/metal/kernels/steel/attn/nax.h | 9 +++-- mlx/backend/metal/kernels/steel/gemm/nax.h | 9 +++-- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index a696462707..b3dd8328e3 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -174,11 +174,16 @@ template < } int kb_lim = params->NK; + int kb_min_causal = params->NK; if (do_causal) { int q_max = (tid.x + 1) * BQ + params->qL_off; kb_lim = (q_max + BK - 1) / BK; kb_lim = min(params->NK, kb_lim); + + int q_min = tid.x * BQ + params->qL_off; + kb_min_causal = (q_min / BK) - int(!align_K); + kb_min_causal = max(0, kb_min_causal); } const bool is_last_bq = int(tid.x) == (params->NQ_aligned); @@ -211,27 +216,19 @@ template < const int K_load_off = ik * kU * int(params->K_strides[2]) + id * kU; if (!align_Q && is_last_q) { - // Qtile.load_rows( - // Q + Q_load_off, - // int(params->Q_strides[2]), - // lim_rows_q - iq * UQ); - Qtile.load_safe( + Qtile.load_rows( Q + Q_load_off, int(params->Q_strides[2]), - short2(BD, lim_rows_q - iq * kU)); + lim_rows_q - iq * kU); } else { Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); } if (!align_K && is_last_k) { - // Ktile.load_rows( - // K + K_load_off, - // int(params->K_strides[2]), - // lim_rows_k - ik * UKs); - Ktile.load_safe( + Ktile.load_rows( K + K_load_off, int(params->K_strides[2]), - short2(BD, lim_rows_k - ik * kU)); + lim_rows_k - ik * kU); } else { Ktile.load(K + K_load_off, int(params->K_strides[2])); } @@ -271,7 +268,7 @@ template < STEEL_PRAGMA_UNROLL for (short jj = 0; jj < stile_t::kFragThrCols; jj++) { const auto loc = ii * stile_t::kFragThrCols + jj; - fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc]; + fg[loc] = ((col_pos + jj) < params->kL_rem) ? fg[loc] : neg_inf; } } } @@ -279,7 +276,7 @@ template < } // Mask out if causal - if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + if (do_causal && kb >= kb_min_causal) { constexpr auto neg_inf = Limits::finite_min; const int base_row = tid.x * BQ + params->qL_off + tm; @@ -406,14 +403,10 @@ template < const int V_load_off = ik * kU * int(params->V_strides[2]) + id * kU; if (!align_K && is_last_k) { - // Vtile.load_rows( - // V + V_load_off, - // int(params->V_strides[2]), - // lim_rows_k - ik * UK); - Vtile.load_safe( + Vtile.load_rows( V + V_load_off, int(params->V_strides[2]), - short2(BD, lim_rows_k - ik * kU)); + lim_rows_k - ik * kU); } else { Vtile.load(V + V_load_off, int(params->V_strides[2])); } @@ -454,8 +447,7 @@ template < if (lim_rows_q <= 0) return; - // Otile.store_rows(O, params->O_strides[2], lim_rows_q); - Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q)); + Otile.store_rows(O, int(params->O_strides[2]), lim_rows_q); } else { Otile.store(O, int(params->O_strides[2])); } diff --git a/mlx/backend/metal/kernels/steel/attn/nax.h b/mlx/backend/metal/kernels/steel/attn/nax.h index 21499c781d..072afe3ba3 100644 --- a/mlx/backend/metal/kernels/steel/attn/nax.h +++ b/mlx/backend/metal/kernels/steel/attn/nax.h @@ -112,7 +112,7 @@ struct BaseNAXFrag { OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); - dst += sc.y * str_x + sc.x * str_y; + src += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; STEEL_PRAGMA_UNROLL @@ -135,7 +135,10 @@ struct BaseNAXFrag { } } else { - dst = dtype_frag_t(0); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = T(0); + } } } } @@ -169,7 +172,7 @@ struct BaseNAXFrag { const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { - if (r < lx && (c + j) < ly) { + if ((r < lx) && ((c + j) < ly)) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } else { diff --git a/mlx/backend/metal/kernels/steel/gemm/nax.h b/mlx/backend/metal/kernels/steel/gemm/nax.h index 21499c781d..072afe3ba3 100644 --- a/mlx/backend/metal/kernels/steel/gemm/nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -112,7 +112,7 @@ struct BaseNAXFrag { OffX off_x = {}, OffY off_y = {}) { const short2 sc = get_coord(); - dst += sc.y * str_x + sc.x * str_y; + src += sc.y * str_x + sc.x * str_y; auto lx = lim_x - sc.y; STEEL_PRAGMA_UNROLL @@ -135,7 +135,10 @@ struct BaseNAXFrag { } } else { - dst = dtype_frag_t(0); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = T(0); + } } } } @@ -169,7 +172,7 @@ struct BaseNAXFrag { const auto c = off_y; STEEL_PRAGMA_UNROLL for (short j = 0; j < kElemCols; j++) { - if (r < lx && (c + j) < ly) { + if ((r < lx) && ((c + j) < ly)) { dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j) * str_y]); } else { From 764cbfb197b5298d32eef43340551d64e37cc9ab Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Tue, 17 Mar 2026 11:51:22 -0700 Subject: [PATCH 3/3] Share behavior between nax and pre-nax attention --- .../metal/kernels/steel/attn/kernels/steel_attention.h | 7 ++++++- .../metal/kernels/steel/attn/kernels/steel_attention_nax.h | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 491830949f..0d9628e834 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -234,11 +234,16 @@ template < } int kb_lim = params->NK; + int kb_min_causal = params->NK; if (do_causal) { int q_max = (tid.x + 1) * BQ + params->qL_off; kb_lim = (q_max + BK - 1) / BK; kb_lim = min(params->NK, kb_lim); + + int q_min = tid.x * BQ + params->qL_off; + q_min = max(0, q_min); + kb_min_causal = (q_min / BK); } // Loop over KV seq length @@ -298,7 +303,7 @@ template < } // Mask out if causal - if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + if (do_causal && kb >= kb_min_causal) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; constexpr auto neg_inf = Limits::finite_min; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index b3dd8328e3..084ced6954 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -182,8 +182,8 @@ template < kb_lim = min(params->NK, kb_lim); int q_min = tid.x * BQ + params->qL_off; - kb_min_causal = (q_min / BK) - int(!align_K); - kb_min_causal = max(0, kb_min_causal); + q_min = max(0, q_min); + kb_min_causal = (q_min / BK); } const bool is_last_bq = int(tid.x) == (params->NQ_aligned);