Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 20 additions & 44 deletions mlx/backend/metal/kernels/fp_quantized_nax.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,13 @@ METAL_FUNC void fp_qmm_t_impl(
// Make the weight loader
loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr short TM = SM / 16;
constexpr short TN = SN / 16;
constexpr short TK = SK / 16;

const short tm = SM * (simd_gid / WN);
const short tn = SN * (simd_gid % WN);
Expand All @@ -273,12 +270,7 @@ METAL_FUNC void fp_qmm_t_impl(

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
using BSubTile = NAXSubTile<Wtype, UN, UK>;
using DSubTile = NAXSubTile<AccumType, UM, UN>;

NAXTile<AccumType, TM, TN, DSubTile> Dtile;

NAXTile<AccumType, TM, TN> Dtile;
Dtile.clear();

x += tm * K;
Expand All @@ -297,8 +289,8 @@ METAL_FUNC void fp_qmm_t_impl(

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<Wtype, TN, TK, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<Wtype, TN, TK> Btile;

volatile int compiler_barrier;

Expand Down Expand Up @@ -400,16 +392,13 @@ METAL_FUNC void fp_qmm_n_impl(
// const short num_outs = min(BN, N - y_col);
loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr short TM = SM / 16;
constexpr short TN = SN / 16;
constexpr short TK = SK / 16;

const short tm = SM * (simd_gid / WN);
const short tn = SN * (simd_gid % WN);
Expand All @@ -421,12 +410,7 @@ METAL_FUNC void fp_qmm_n_impl(

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
using BSubTile = NAXSubTile<T, UK, UN>;
using DSubTile = NAXSubTile<AccumType, UM, UN>;

NAXTile<AccumType, TM, TN, DSubTile> Dtile;

NAXTile<AccumType, TM, TN> Dtile;
Dtile.clear();

x += tm * K;
Expand All @@ -438,8 +422,8 @@ METAL_FUNC void fp_qmm_n_impl(

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<Wtype, TK, TN, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<Wtype, TK, TN> Btile;

volatile int compiler_barrier;

Expand Down Expand Up @@ -869,16 +853,13 @@ template <
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr short TM = SM / 16;
constexpr short TN = SN / 16;
constexpr short TK = SK / 16;

const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);
Expand All @@ -896,10 +877,6 @@ template <

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
using BSubTile = NAXSubTile<Wtype, transpose ? UN : UK, transpose ? UK : UN>;
using DSubTile = NAXSubTile<AccumType, UM, UN>;

// Do as many matmuls as necessary
uint32_t index;
short offset;
Expand All @@ -921,8 +898,7 @@ template <
threadgroup_barrier(mem_flags::mem_none);

// Prepare threadgroup mma operation
NAXTile<AccumType, TM, TN, DSubTile> Dtile;

NAXTile<AccumType, TM, TN> Dtile;
Dtile.clear();

const device T* xn = x + tm * K;
Expand Down Expand Up @@ -951,8 +927,8 @@ template <

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<Wtype, BR, BC, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<Wtype, BR, BC> Btile;

volatile int compiler_barrier;

Expand Down Expand Up @@ -991,8 +967,8 @@ template <

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<Wtype, BR, BC, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<Wtype, BR, BC> Btile;

volatile int compiler_barrier;

Expand Down
64 changes: 20 additions & 44 deletions mlx/backend/metal/kernels/quantized_nax.h
Original file line number Diff line number Diff line change
Expand Up @@ -986,16 +986,13 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
// Make the weight loader
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr short TM = SM / 16;
constexpr short TN = SN / 16;
constexpr short TK = SK / 16;

const short tm = SM * (simd_gid / WN);
const short tn = SN * (simd_gid % WN);
Expand All @@ -1013,12 +1010,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
using BSubTile = NAXSubTile<T, UN, UK>;
using DSubTile = NAXSubTile<AccumType, UM, UN>;

NAXTile<AccumType, TM, TN, DSubTile> Dtile;

NAXTile<AccumType, TM, TN> Dtile;
Dtile.clear();

x += tm * K;
Expand All @@ -1037,8 +1029,8 @@ METAL_FUNC void qmm_t_nax_tgp_impl(

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, TN, TK, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<T, TN, TK> Btile;

volatile int compiler_barrier;

Expand Down Expand Up @@ -1141,16 +1133,13 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
// const short num_outs = min(BN, N - y_col);
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr short TM = SM / 16;
constexpr short TN = SN / 16;
constexpr short TK = SK / 16;

const short tm = SM * (simd_gid / WN);
const short tn = SN * (simd_gid % WN);
Expand All @@ -1162,12 +1151,7 @@ METAL_FUNC void qmm_n_nax_tgp_impl(

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
using BSubTile = NAXSubTile<T, UK, UN>;
using DSubTile = NAXSubTile<AccumType, UM, UN>;

NAXTile<AccumType, TM, TN, DSubTile> Dtile;

NAXTile<AccumType, TM, TN> Dtile;
Dtile.clear();

x += tm * K;
Expand All @@ -1179,8 +1163,8 @@ METAL_FUNC void qmm_n_nax_tgp_impl(

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, TK, TN, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<T, TK, TN> Btile;

volatile int compiler_barrier;

Expand Down Expand Up @@ -1534,16 +1518,13 @@ template <
scales += transpose ? y_col_long * K_g : y_col / group_size;
biases += transpose ? y_col_long * K_g : y_col / group_size;

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr short TM = SM / 16;
constexpr short TN = SN / 16;
constexpr short TK = SK / 16;

const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);
Expand All @@ -1561,10 +1542,6 @@ template <

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
using BSubTile = NAXSubTile<T, transpose ? UN : UK, transpose ? UK : UN>;
using DSubTile = NAXSubTile<AccumType, UM, UN>;

// Do as many matmuls as necessary
uint32_t index;
short offset;
Expand All @@ -1585,8 +1562,7 @@ template <
}
threadgroup_barrier(mem_flags::mem_none);

NAXTile<AccumType, TM, TN, DSubTile> Dtile;

NAXTile<AccumType, TM, TN> Dtile;
Dtile.clear();

const device T* xn = x + tm * K;
Expand Down Expand Up @@ -1616,8 +1592,8 @@ template <

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, BR, BC, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<T, BR, BC> Btile;

volatile int compiler_barrier;

Expand Down Expand Up @@ -1654,8 +1630,8 @@ template <

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, BR, BC, BSubTile> Btile;
NAXTile<T, TM, TK> Atile;
NAXTile<T, BR, BC> Btile;

volatile int compiler_barrier;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,16 @@ template <
}

int kb_lim = params->NK;
int kb_min_causal = params->NK;

if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);

int q_min = tid.x * BQ + params->qL_off;
q_min = max(0, q_min);
kb_min_causal = (q_min / BK);
}

// Loop over KV seq length
Expand Down Expand Up @@ -298,7 +303,7 @@ template <
}

// Mask out if causal
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
if (do_causal && kb >= kb_min_causal) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;
Expand Down
Loading
Loading