16 changes: 10 additions & 6 deletions csrc/kernels/splitkv_mla.cu
@@ -1035,6 +1035,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0;
int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
+int n_block = end_block_idx - 1;
const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN);

int rRightBorderForQSeq[2];
@@ -1083,10 +1084,12 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
>{});

// Copy K0 and K1
-launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);
-if (start_block_idx+1 < end_block_idx) {
-launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
-launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
+if (n_block >= start_block_idx) {
+launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);
+if (start_block_idx+1 < end_block_idx) {
+launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
+launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
+}
}

Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V / 2>>{}); // ((2, 2, 32), 1, 1)
@@ -1106,7 +1109,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
Tensor rQ8 = make_tensor<InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
retrieve_rP_from_sP<T>(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup);

-if (warpgroup_idx == 0) {
+if (warpgroup_idx == 0 && n_block >= start_block_idx) {
// Warpgroup 0
Tensor rP0 = make_tensor<float>((typename T::rP0Layout){});

@@ -1146,7 +1149,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
LAUNCH_WG0_SUBROUTINE(true, false);
}

-} else {
+}
+if (warpgroup_idx != 0 && n_block >= start_block_idx) {
// Warpgroup 1
Tensor rP1 = make_tensor<float>((typename T::rP0Layout){});

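
For context, a minimal host-side sketch (Python, not the CUDA kernel) of the index arithmetic behind the new `n_block >= start_block_idx` guard. `kBlockN = 64` and `start_block_idx = 0` are illustrative assumptions, not values taken from this diff:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

kBlockN = 64  # illustrative KV block size; the real value comes from the kernel traits

for seqlen_k in (0, 1, 130):
    start_block_idx = 0                        # assume the split starts at block 0
    end_block_idx = ceil_div(seqlen_k, kBlockN)
    n_block = end_block_idx - 1                # mirrors `int n_block = end_block_idx - 1;`
    has_kv = n_block >= start_block_idx
    print(f"{seqlen_k=}: {end_block_idx=}, {n_block=}, {has_kv=}")

# seqlen_k == 0 gives end_block_idx == 0 and n_block == -1, so the guard is False:
# the K0/K1 TMA prefetch and both warpgroup branches are skipped instead of reading
# a block-table entry for a page that does not exist.
```
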
5 changes: 3 additions & 2 deletions csrc/kernels_fp8/flash_fwd_mla_kernel.h
@@ -321,7 +321,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
flash::Softmax<2 * size<1>(tOrO)> softmax;

int warp_group_idx = cutlass::canonical_warp_group_idx();
-if (warp_group_idx == 0) {
+if (warp_group_idx == 0 && n_block >= n_block_min) {
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
@@ -410,7 +410,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cute::copy(softmax.row_max, tRow_maxsRow_max);
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
-} else {
+}
+if (warp_group_idx != 0 && n_block >= n_block_min) {
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);

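
The FP8 kernel gets the same treatment with its own bounds (`n_block_min` instead of `start_block_idx`): the `if / else` over warp groups becomes two independently guarded branches, so neither warp group does any work when the split contains no KV blocks. A rough sketch of just the control-flow change, in Python pseudocode; `do_warpgroup0_work` and `do_warpgroup1_work` are hypothetical stand-ins for the real MMA/softmax and KV-load paths:

```python
def do_warpgroup0_work() -> None:
    print("warp group 0: Q*K^T, softmax")  # stand-in only

def do_warpgroup1_work() -> None:
    print("warp group 1: paged KV loads")  # stand-in only

def old_flow(warp_group_idx: int, n_block: int, n_block_min: int) -> None:
    # Before: one of the two branches always ran, even for an empty split.
    if warp_group_idx == 0:
        do_warpgroup0_work()
    else:
        do_warpgroup1_work()

def new_flow(warp_group_idx: int, n_block: int, n_block_min: int) -> None:
    # After: each branch is additionally gated on "is there at least one KV block?".
    if warp_group_idx == 0 and n_block >= n_block_min:
        do_warpgroup0_work()
    if warp_group_idx != 0 and n_block >= n_block_min:
        do_warpgroup1_work()
```
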
28 changes: 26 additions & 2 deletions tests/test_flash_mla.py
@@ -36,6 +36,9 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None:
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
+# Note(hc): attn_out is all zeros when seqlen_k=0 and batch=1.
+if (x * x + y * y).sum().item() < 1e-20:
+return
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
if use_fp8:
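
An aside on why the early return above is needed rather than tolerating the zeros: with both tensors all-zero, the cosine-style metric below degenerates to 1.0 and would report a spurious mismatch. A self-contained illustration:

```python
import torch

x = torch.zeros(4, dtype=torch.float64)
y = torch.zeros(4, dtype=torch.float64)
# Same formula as cal_diff: the zero numerator over the 1e-12 floor gives cos_diff == 1.0,
# i.e. a "total mismatch" even though x and y are identical -- hence the early return
# when (x * x + y * y).sum() is below 1e-20.
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
print(cos_diff)  # 1.0
```
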
@@ -45,7 +48,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -


@torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype, test_zero_seqlen: bool = False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
)
@@ -55,9 +58,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtyp
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
+if test_zero_seqlen and i % 7 == 0:
+cache_seqlens[i] = 0
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
+# Note(hc): When testing with seqlen_k = 0 and batch = 1, max_seqlen becomes 0,
+# causing blocked_k to be constructed as an empty tensor. This leads to an
+# invalid params.k_ptr being used in `auto tma_K = cute::make_tma_copy`,
+# resulting in "Error: Failed to initialize the TMA descriptor 1".
+# Therefore, we force max_seqlen to be at least 1 to work around it.
+max_seqlen = max(max_seqlen, 1)
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
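
A quick illustration of the clamp above (illustrative only; `cdiv` stands in for `triton.cdiv`): without forcing `max_seqlen` to at least 1, an all-zero `cache_seqlens` pads to 0, `blocked_k` becomes an empty tensor, and the TMA descriptor setup is handed an invalid `k_ptr`:

```python
def cdiv(a: int, b: int) -> int:  # same arithmetic as triton.cdiv
    return (a + b - 1) // b

for max_seqlen in (0, 1, 257):
    clamped = max(max_seqlen, 1)
    max_seqlen_pad = cdiv(clamped, 256) * 256
    print(f"{max_seqlen=} -> {clamped=}, {max_seqlen_pad=}")

# max_seqlen == 0 would otherwise pad to 0, so blocked_k would have zero pages and
# cute::make_tma_copy would fail with "Failed to initialize the TMA descriptor".
```
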

@@ -139,13 +150,18 @@ def ref_mla():
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out", use_fp8)
+# Note(hc): lse_flash is 'inf' when the current seqlen_k is 0
+lse_flash = torch.where(torch.isinf(lse_flash), torch.tensor(float(0.0)), lse_flash)
+# Note(hc): lse_torch is '-inf' when the current seqlen_k is 0
+lse_torch = torch.where(torch.isneginf(lse_torch), torch.tensor(float(0.0)), lse_torch)
cal_diff(lse_flash, lse_torch, "lse")

t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
+logger_prefix = "" if not test_zero_seqlen else "**** test corner case: "
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
f"{logger_prefix} {t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)


@@ -169,6 +185,14 @@ def main(torch_dtype):
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype)

print("start testing ctx_len=0 corner case: ")
for b in [1, 4, 8, 128]:
for s in [4096]:
for h_q in [128]:
for s_q in [1]:
for varlen in [True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype, test_zero_seqlen=True)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
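
Finally, a self-contained illustration of the LSE sanitization in the test above: for a batch element with `seqlen_k == 0` the kernel reports `lse = inf` while the reference reports `-inf` (per the notes in the diff), so both are mapped to 0.0 before the `cal_diff(lse_flash, lse_torch, "lse")` comparison. The numbers here are made up for illustration:

```python
import torch

# Hypothetical LSE vectors for a 3-element batch whose first element has seqlen_k == 0.
lse_flash = torch.tensor([float("inf"), 1.7, 2.3])    # kernel side
lse_torch = torch.tensor([float("-inf"), 1.7, 2.3])   # reference side

# Same sanitization as in the test: map the infinities to 0.0 so only finite values are compared.
lse_flash = torch.where(torch.isinf(lse_flash), torch.tensor(0.0), lse_flash)
lse_torch = torch.where(torch.isneginf(lse_torch), torch.tensor(0.0), lse_torch)

assert torch.equal(lse_flash, lse_torch)
```
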