From 4b45f4720bad3ed4fd07d46094baa44fd10ddc37 Mon Sep 17 00:00:00 2001
From: hongchao
Date: Tue, 9 Sep 2025 08:54:52 +0000
Subject: [PATCH] support ctx_len=0

---
 csrc/kernels/splitkv_mla.cu             | 16 ++++++++------
 csrc/kernels_fp8/flash_fwd_mla_kernel.h |  5 +++--
 tests/test_flash_mla.py                 | 28 +++++++++++++++++++++++--
 3 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu
index 03a208b..e75a3b0 100644
--- a/csrc/kernels/splitkv_mla.cu
+++ b/csrc/kernels/splitkv_mla.cu
@@ -1035,6 +1035,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
     const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0;
     int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
+    int n_block = end_block_idx - 1;
     const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN);
 
     int rRightBorderForQSeq[2];
@@ -1083,10 +1084,12 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     >{});
 
     // Copy K0 and K1
-    launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);
-    if (start_block_idx+1 < end_block_idx) {
-        launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
-        launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
+    if (n_block >= start_block_idx) {
+        launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);
+        if (start_block_idx+1 < end_block_idx) {
+            launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
+            launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
+        }
     }
 
     Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape, Int>{});    // ((2, 2, 32), 1, 1)
@@ -1106,7 +1109,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     Tensor rQ8 = make_tensor(Shape, _1, _4>{});
     retrieve_rP_from_sP(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup);
 
-    if (warpgroup_idx == 0) {
+    if (warpgroup_idx == 0 && n_block >= start_block_idx) {
         // Warpgroup 0
 
         Tensor rP0 = make_tensor((typename T::rP0Layout){});
@@ -1146,7 +1149,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
             LAUNCH_WG0_SUBROUTINE(true, false);
         }
 
-    } else {
+    }
+    if (warpgroup_idx != 0 && n_block >= start_block_idx) {
         // Warpgroup 1
 
         Tensor rP1 = make_tensor((typename T::rP0Layout){});
diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h
index 9ba4035..c1ef147 100644
--- a/csrc/kernels_fp8/flash_fwd_mla_kernel.h
+++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h
@@ -321,7 +321,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     flash::Softmax<2 * size<1>(tOrO)> softmax;
 
     int warp_group_idx = cutlass::canonical_warp_group_idx();
-    if (warp_group_idx == 0) {
+    if (warp_group_idx == 0 && n_block >= n_block_min) {
         typename Kernel_traits::TiledMma tiled_mma;
         auto thr_mma = tiled_mma.get_thread_slice(tidx);
         Tensor tSrQ = thr_mma.partition_fragment_A(sQ);    // (MMA,MMA_M,MMA_K)
@@ -410,7 +410,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         cute::copy(softmax.row_max, tRow_maxsRow_max);
         cute::copy(softmax.row_sum, tRow_sumsRow_sum);
         cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady));
-    } else {
+    }
+    if (warp_group_idx != 0 && n_block >= n_block_min) {
         const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
         int cur_block_table = __ldg(&block_table[n_block]);
 
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 931b798..9f53547 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -36,6 +36,9 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None:
     x, y = x.double(), y.double()
     RMSE = ((x - y) * (x - y)).mean().sqrt().item()
+    # Note(hc): attn_out is full zeros when seqlen_k=0 and batch=1.
+    if (x * x + y * y).sum().item() < 1e-20:
+        return
     cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
     amax_diff = (x - y).abs().max().item()
     if use_fp8:
@@ -45,7 +48,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
 
 
 @torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype, test_zero_seqlen: bool = False):
     print(
         f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
     )
@@ -55,9 +58,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtyp
     if varlen:
         for i in range(b):
             cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
+            if test_zero_seqlen and i % 7 == 0:
+                cache_seqlens[i] = 0
     total_seqlens = cache_seqlens.sum().item()
     mean_seqlens = cache_seqlens.float().mean().int().item()
     max_seqlen = cache_seqlens.max().item()
+    # Note(hc): When testing with seqlen_k = 0 and batch = 1, max_seqlen becomes 0,
+    # causing blocked_k to be constructed as an empty tensor. This leads to an
+    # invalid params.k_ptr being used in `auto tma_K = cute::make_tma_copy`,
+    # resulting in "Error: Failed to initialize the TMA descriptor 1".
+    # Therefore, we force max_seqlen to be at least 1 to workaround it.
+    max_seqlen = max(max_seqlen, 1)
     max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
     # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
 
@@ -139,13 +150,18 @@ def ref_mla():
     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
     cal_diff(out_flash, out_torch, "out", use_fp8)
+    # Note(hc): the lse_flash is 'inf' when current seqlen_k is 0
+    lse_flash = torch.where(torch.isinf(lse_flash), torch.tensor(float(0.0)), lse_flash)
+    # Note(hc): the lse_torch is '-inf' when current seqlen_k is 0
+    lse_torch = torch.where(torch.isneginf(lse_torch), torch.tensor(float(0.0)), lse_torch)
     cal_diff(lse_flash, lse_torch, "lse")
 
     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
     bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
+    logger_prefix = "" if not test_zero_seqlen else "**** test corner case: "
     print(
-        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
+        f"{logger_prefix} {t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
     )
 
 
@@ -169,6 +185,14 @@ def main(torch_dtype):
                     for varlen in [False, True]:
                         test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype)
 
+    print("start testing ctx_len=0 corner case: ")
+    for b in [1, 4, 8, 128]:
+        for s in [4096]:
+            for h_q in [128]:
+                for s_q in [1]:
+                    for varlen in [True]:
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype, test_zero_seqlen=True)
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
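
For reference, a minimal standalone sketch of the LSE masking that the updated test performs for zero-length sequences. The tensor values below are hypothetical; only the behavior stated in the patch notes (all-zero attention output, +inf LSE from the kernel and -inf from the reference when seqlen_k is 0) and the torch calls already used in the diff are assumed.

import torch

# Hypothetical LSE values: index 0 plays the role of a request with seqlen_k == 0,
# where the kernel reports +inf and the reference reports -inf, while the attention
# output itself is all zeros on both sides.
lse_flash = torch.tensor([float("inf"), 1.25, 0.5])
lse_torch = torch.tensor([float("-inf"), 1.25, 0.5])

# Neutralize the infinities before comparing, mirroring the test change above.
lse_flash = torch.where(torch.isinf(lse_flash), torch.zeros_like(lse_flash), lse_flash)
lse_torch = torch.where(torch.isneginf(lse_torch), torch.zeros_like(lse_torch), lse_torch)

assert torch.equal(lse_flash, lse_torch)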