From 1cb6b57dc4c7b798b1c45f992f2d2b58363246f6 Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 24 Feb 2025 00:25:25 -0800 Subject: [PATCH 01/23] replace c10 optional with std optional Signed-off-by: Lucas Wilkinson --- csrc/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 5a1cb8e..3184465 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -61,7 +61,7 @@ std::vector mha_fwd_kvcache_mla( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size - c10::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v + std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq From 596af687fdcc42de1b4789c09507487e38053101 Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 24 Feb 2025 01:58:53 -0800 Subject: [PATCH 02/23] support fp16 Signed-off-by: Lucas Wilkinson --- README.md | 2 +- csrc/flash_api.cpp | 9 +++- csrc/flash_fwd_mla_fp16_sm90.cu | 3 ++ csrc/flash_fwd_mla_kernel.h | 76 -------------------------------- csrc/flash_fwd_mla_metadata.cu | 77 +++++++++++++++++++++++++++++++++ setup.py | 2 + tests/test_flash_mla.py | 61 +++++++++++++++++++++----- 7 files changed, 139 insertions(+), 91 deletions(-) create mode 100644 csrc/flash_fwd_mla_fp16_sm90.cu create mode 100644 csrc/flash_fwd_mla_metadata.cu diff --git a/README.md b/README.md index bb55395..4027334 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. Currently released: -- BF16 +- BF16, FP16 - Paged kvcache with block size of 64 ## Quick start diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 3184465..b7b11f1 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -77,7 +77,7 @@ mha_fwd_kvcache_mla( at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -186,7 +186,12 @@ mha_fwd_kvcache_mla( auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); - run_mha_fwd_splitkv_mla(params, stream); + + if (q_dtype == torch::kBFloat16) { + run_mha_fwd_splitkv_mla(params, stream); + } else { + run_mha_fwd_splitkv_mla(params, stream); + } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu new file mode 100644 index 0000000..abdaf7b --- /dev/null +++ b/csrc/flash_fwd_mla_fp16_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 55f6811..d96acd8 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; run_flash_splitkv_fwd_mla>(params, stream); } - -static constexpr int MaxBatchSize = 4096; - -__global__ void __launch_bounds__(256, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { - int *seqlens_k_ptr = params.seqlens_k_ptr; - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; - int *num_splits_ptr = params.num_splits_ptr; - int batch_size = params.batch_size; - int block_size_n = params.block_size_n; - int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; - int num_sm_parts = params.num_sm_parts; - - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; - - int total_num_blocks = 0; - for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); - total_num_blocks += num_blocks + fixed_overhead_num_blocks; - num_blocks_shared[i] = num_blocks; - } - for (int offset = 16; offset >= 1; offset /= 2) { - total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); - } - __syncwarp(); - - if (threadIdx.x == 0) { - int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; - - int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; - num_splits_shared[0] = 0; - for (int i = 0; i < num_sm_parts; ++i) { - int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; - tile_scheduler_metadata1 = now_n_split_idx; - int remain_payload = payload; - while (now_idx < batch_size) { - int num_blocks = num_blocks_shared[now_idx]; - int now_remain_blocks = num_blocks - now_block; - if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { - cum_num_splits += now_n_split_idx + 1; - num_splits_shared[now_idx + 1] = cum_num_splits; - remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; - ++now_idx; - now_block = 0; - now_n_split_idx = 0; - } else { - if (remain_payload - fixed_overhead_num_blocks > 0) { - now_block += remain_payload - fixed_overhead_num_blocks; - ++now_n_split_idx; - remain_payload = 0; - } - break; - } - } - tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; - *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; - } - FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); - } - __syncwarp(); - - for (int i = threadIdx.x; i <= batch_size; i += 32) { - num_splits_ptr[i] = num_splits_shared[i]; - } -} - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); - CHECK_CUDA_KERNEL_LAUNCH(); -} diff --git a/csrc/flash_fwd_mla_metadata.cu b/csrc/flash_fwd_mla_metadata.cu new file mode 100644 index 0000000..82f5b5a --- /dev/null +++ b/csrc/flash_fwd_mla_metadata.cu @@ -0,0 +1,77 @@ +#include "flash_fwd_mla_kernel.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/setup.py b/setup.py index 0a3bd17..662a301 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,8 @@ def append_nvcc_threads(nvcc_extra_args): sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_fp16_sm90.cu", + "csrc/flash_fwd_mla_metadata.cu", ], extra_compile_args={ "cxx": cxx_args, diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 8db5db0..e676fa7 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -1,10 +1,11 @@ +import argparse import math import random import torch import triton -from flash_mla import get_mla_metadata, flash_mla_with_kvcache +from flash_mla import flash_mla_with_kvcache, get_mla_metadata def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + ) cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: @@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + float("nan") + ) blocked_v = blocked_k[..., :dv] - tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, s_q * h_q // h_kv, h_kv + ) def flash_mla(): return flash_mla_with_kvcache( - q, blocked_k, block_table, cache_seqlens, dv, - tile_scheduler_metadata, num_splits, causal=causal, + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, ) def ref_mla(): @@ -91,14 +106,17 @@ def ref_mla(): t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( + torch.finfo(q.dtype).bits // 8 + ) + print( + f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" + ) -if __name__ == "__main__": - dtype = torch.bfloat16 +def main(torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + torch.set_default_dtype(torch_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) @@ -114,3 +132,22 @@ def ref_mla(): for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + choices=["bf16", "fp16"], + default="bf16", + help="Data type to use for testing (bf16 or fp16)", + ) + + args = parser.parse_args() + + torch_dtype = torch.bfloat16 + if args.dtype == "fp16": + torch_dtype = torch.float16 + + main(torch_dtype) From a9fab8656cacafc9777d3bd76e7a0b480484e6b8 Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 24 Feb 2025 10:01:59 -0800 Subject: [PATCH 03/23] add flag to disable FP16 compile Signed-off-by: Lucas Wilkinson --- csrc/flash_api.cpp | 9 +++++++-- setup.py | 29 +++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index b7b11f1..d2567fe 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -77,7 +77,6 @@ mha_fwd_kvcache_mla( at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -189,9 +188,15 @@ mha_fwd_kvcache_mla( if (q_dtype == torch::kBFloat16) { run_mha_fwd_splitkv_mla(params, stream); - } else { + } + #ifndef FLASH_MLA_DISABLE_FP16 + else if (q_dtype == torch::kHalf) { run_mha_fwd_splitkv_mla(params, stream); } + #endif + else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/setup.py b/setup.py index 662a301..6377b1e 100644 --- a/setup.py +++ b/setup.py @@ -11,11 +11,29 @@ IS_WINDOWS, ) +DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] +def get_sources(): + sources = [ + "csrc/flash_api.cpp", + "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_metadata.cu", + ] + + if not DISABLE_FP16: + sources.append("csrc/flash_fwd_mla_fp16_sm90.cu") + + return sources + +def get_features_args(): + features_args = [] + if DISABLE_FP16: + features_args.append("-DFLASH_MLA_DISABLE_FP16") + return features_args subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) @@ -34,14 +52,9 @@ def append_nvcc_threads(nvcc_extra_args): ext_modules.append( CUDAExtension( name="flash_mla_cuda", - sources=[ - "csrc/flash_api.cpp", - "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_fp16_sm90.cu", - "csrc/flash_fwd_mla_metadata.cu", - ], + sources=get_sources(), extra_compile_args={ - "cxx": cxx_args, + "cxx": cxx_args + get_features_args(), "nvcc": append_nvcc_threads( [ "-O3", @@ -59,7 +72,7 @@ def append_nvcc_threads(nvcc_extra_args): "--ptxas-options=-v,--register-usage-level=10" ] + cc_flag - ), + ) + get_features_args(), }, include_dirs=[ Path(this_dir) / "csrc", From 6f802fd9b04fcfdf344711c11e9db4d22533d533 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 25 Feb 2025 09:18:11 +0800 Subject: [PATCH 04/23] Style fix Signed-off-by: Lucas Wilkinson --- tests/test_flash_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index e676fa7..0abe9d2 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -60,7 +60,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( float("nan") ) blocked_v = blocked_k[..., :dv] From 79ac6067e94e9f92671c21649bd97224d0ff1f75 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Wed, 26 Feb 2025 00:05:57 +0800 Subject: [PATCH 05/23] cuda12.8 recommendation Signed-off-by: Lucas Wilkinson --- README.md | 5 +++-- setup.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4027334..6d0bcb6 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ python setup.py install python tests/test_flash_mla.py ``` -Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6. +Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. ### Usage @@ -42,6 +42,7 @@ for i in range(num_layers): - Hopper GPUs - CUDA 12.3 and above + - **But we highly recommend 12.8 or above for the best performance** - PyTorch 2.0 and above ## Acknowledgement @@ -52,7 +53,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ```bibtex @misc{flashmla2025, - title={FlashMLA: Efficient MLA decoding kernel}, + title={FlashMLA: Efficient MLA decoding kernels}, author={Jiashi Li}, year={2025}, publisher = {GitHub}, diff --git a/setup.py b/setup.py index 6377b1e..cd311f2 100644 --- a/setup.py +++ b/setup.py @@ -13,10 +13,12 @@ DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] + def get_sources(): sources = [ "csrc/flash_api.cpp", @@ -29,12 +31,14 @@ def get_sources(): return sources + def get_features_args(): features_args = [] if DISABLE_FP16: features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) cc_flag = [] From 81399b0de483f02cd776962b2b01ba7626dda675 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 16:38:15 +0000 Subject: [PATCH 06/23] update to use torch library Signed-off-by: Lucas Wilkinson Signed-off-by: Lucas Wilkinson --- csrc/flash_api.cpp | 21 ++++++-- csrc/pytorch_shim.h | 123 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 5 deletions(-) create mode 100644 csrc/pytorch_shim.h diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index d2567fe..22baab0 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -1,6 +1,6 @@ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp -#include +// #include #include #include #include @@ -206,8 +206,19 @@ mha_fwd_kvcache_mla( return {out, softmax_lse}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashMLA"; - m.def("get_mla_metadata", &get_mla_metadata); - m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); +#include +#include "pytorch_shim.h" + +TORCH_LIBRARY(_flashmla_C, m) { + m.def("get_mla_metadata", make_pytorch_shim(&get_mla_metadata)); + m.impl("get_mla_metadata", torch::kCUDA, make_pytorch_shim(&get_mla_metadata)); + + m.def("fwd_kvcache_mla", make_pytorch_shim(&mha_fwd_kvcache_mla)); + m.impl("fwd_kvcache_mla", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache_mla)); } + +PyMODINIT_FUNC PyInit__flashmla_C() { + static struct PyModuleDef module = { + PyModuleDef_HEAD_INIT, "_flashmla_C", nullptr, 0, nullptr}; + return PyModule_Create(&module); +} \ No newline at end of file diff --git a/csrc/pytorch_shim.h b/csrc/pytorch_shim.h new file mode 100644 index 0000000..779d40b --- /dev/null +++ b/csrc/pytorch_shim.h @@ -0,0 +1,123 @@ +#pragma once + +#include + +/** + * PyBind and PyTorch Library apis generally require different type signatures. + * This file provides a shim to (mostly, there may be missing conversions) to + * convert from function designed to be used with PyBind to one that can be used + * with PyTorch Library. This is done using `make_pytorch_shim` which creates a + * lambda that exponses the API using PyTorch compatible types to the types. + * This is useful when trying to ingergate PyBind based external libraries into + * vLLM. + * + * Example: + * + * PYBIND11_MODULE(NAME, m) { + * m.def("foo", &foo); + * } + * + * could be replaced with (using the shim): + * TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + * m.def("foo", make_pytorch_shim(&foo)); + * m.impl("foo", torch::kCUDA, make_pytorch_shim(&foo)); + * } + * + * The `pytorch_library_compatible_type` struct is used to map from the + * flash_attn ops types to a PyTorch library compatible one. The main issues is + * that the following types are not support by PyTorch library bindings: + * - `int` + * - `float` + * - `c10::optional &` + * - `c10::optional &` + * So we convert them to (respectively): + * - `int64_t` + * - `double` + * - `const c10::optional&` + * - `const c10::optional&` + */ + +template +struct pytorch_library_compatible_type { + using type = T; + static T convert_from_type(T arg) { return arg; } +}; + +template +using pytorch_library_compatible_type_t = + typename pytorch_library_compatible_type::type; + +template +T convert_from_pytorch_compatible_type( + pytorch_library_compatible_type_t arg) { + return pytorch_library_compatible_type::convert_from_type(arg); +} + +// Map `c10::optional &` -> `const c10::optional&` +// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate +// the optional container) +template +struct pytorch_library_compatible_type&> { + using type = const c10::optional&; + static c10::optional& convert_from_type(const c10::optional& arg) { + return const_cast&>(arg); + } +}; + +// Map `c10::optional` -> +// `c10::optional>` +// (NOTE: tested for `c10::optional` -> `c10::optional`) +template +struct pytorch_library_compatible_type> { + using type = c10::optional>; + static c10::optional> convert_from_type( + c10::optional arg) { + return arg; + } +}; + +// Map `c10::optional&` -> `const c10::optional&` +template <> +struct pytorch_library_compatible_type&> { + using type = const c10::optional&; + static c10::optional& convert_from_type( + const c10::optional& arg) { + return const_cast&>( + reinterpret_cast&>(arg)); + } +}; + +// Map `int` -> `int64_t` +template <> +struct pytorch_library_compatible_type { + using type = int64_t; + static int convert_from_type(int64_t arg) { + TORCH_CHECK(arg <= std::numeric_limits::max(), + "int64_t value is too large to be converted to int"); + TORCH_CHECK(arg >= std::numeric_limits::min(), + "int64_t value is too small to be converted to int"); + return arg; + } +}; + +// Map `float` -> `double` +template <> +struct pytorch_library_compatible_type { + using type = double; + static float convert_from_type(double arg) { + TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), + "double value is too large to be converted to float"); + return arg; + } +}; + +// +// Shim Utils +// + +template +auto make_pytorch_shim(Ret (*fun)(Args... args)) { + return [fun](pytorch_library_compatible_type_t... args) { + return fun(convert_from_pytorch_compatible_type(args)...); + }; +} From f2a572c4f2d8605c29ccbec54e65e0d3efea3d70 Mon Sep 17 00:00:00 2001 From: hpp Date: Wed, 26 Feb 2025 11:26:42 +0800 Subject: [PATCH 07/23] add Community Support of [MetaX] and [Moore Threads] Signed-off-by: Lucas Wilkinson --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 6d0bcb6..95aed11 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,17 @@ for i in range(num_layers): FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. +## Community Support + +### MetaX + +For the MetaX GPU【https://www.metax-tech.com】, the corresponding FlashMLA version link is as follows: +GitHub - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) + +### Moore Threads (WIP) +For the Moore Threads GPU【https://www.mthreads.com/】, the corresponding FlashMLA version link is as follows: +GitHub - [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) + ## Citation ```bibtex From a6e070f0e0f0d1e31394b6903353466395c37c25 Mon Sep 17 00:00:00 2001 From: "yangsijia.614" Date: Tue, 25 Feb 2025 23:52:54 +0800 Subject: [PATCH 08/23] fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them Signed-off-by: Lucas Wilkinson --- benchmark/bench_flash_mla.py | 18 ++++++++++++------ benchmark/visualize.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py index 7b0e7b4..14e1352 100644 --- a/benchmark/bench_flash_mla.py +++ b/benchmark/bench_flash_mla.py @@ -1,15 +1,16 @@ # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a +import argparse import math import random +import flashinfer import torch import triton import triton.language as tl -import argparse # pip install flashinfer-python -from flash_mla import get_mla_metadata, flash_mla_with_kvcache -import flashinfer +from flash_mla import flash_mla_with_kvcache, get_mla_metadata + def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): query = query.float() @@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") + return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): @@ -501,7 +503,8 @@ def get_args(): if __name__ == "__main__": args = get_args() - with open("all_perf.csv", "w") as fout: + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: fout.write("name,batch,seqlen,head,bw\n") for shape in shape_configs: if args.all: @@ -509,6 +512,9 @@ def get_args(): perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') elif args.compare: - compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n') + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') elif args.one: - compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) \ No newline at end of file + perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') \ No newline at end of file diff --git a/benchmark/visualize.py b/benchmark/visualize.py index db62519..c1fb37e 100644 --- a/benchmark/visualize.py +++ b/benchmark/visualize.py @@ -1,7 +1,17 @@ +import argparse + import matplotlib.pyplot as plt import pandas as pd -file_path = 'all_perf.csv' + +def parse_args(): + parser = argparse.ArgumentParser(description='Visualize benchmark results') + parser.add_argument('--file', type=str, default='all_perf.csv', + help='Path to the CSV file with benchmark results (default: all_perf.csv)') + return parser.parse_args() + +args = parse_args() +file_path = args.file df = pd.read_csv(file_path) @@ -16,4 +26,4 @@ plt.ylabel('bw (GB/s)') plt.legend() -plt.savefig('bandwidth_vs_seqlen.png') \ No newline at end of file +plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png') \ No newline at end of file From 2d38d8561972ca114a47351a27d6d4acc934b30a Mon Sep 17 00:00:00 2001 From: Jiashi Li <31004720+beginlner@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:30:45 +0800 Subject: [PATCH 09/23] Fix readme Signed-off-by: Lucas Wilkinson --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 95aed11..252bae1 100644 --- a/README.md +++ b/README.md @@ -53,12 +53,12 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ### MetaX -For the MetaX GPU【https://www.metax-tech.com】, the corresponding FlashMLA version link is as follows: -GitHub - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) +For [MetaX](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +- [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads (WIP) -For the Moore Threads GPU【https://www.mthreads.com/】, the corresponding FlashMLA version link is as follows: -GitHub - [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) +For [Moore Threads](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +- [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) ## Citation From 8c3a04591176e47f437173421399478a5840878a Mon Sep 17 00:00:00 2001 From: Jiashi Li <31004720+beginlner@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:32:39 +0800 Subject: [PATCH 10/23] fix readme Signed-off-by: Lucas Wilkinson --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 252bae1..0bb3e52 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ For [MetaX](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads (WIP) -For [Moore Threads](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +For [Moore Threads](https://www.mthreads.com) GPUs, the corresponding FlashMLA version can be found at: - [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) ## Citation From 11ba97db8fddec1353e9152f43b8feb000dcadc5 Mon Sep 17 00:00:00 2001 From: hpp Date: Thu, 27 Feb 2025 09:39:18 +0800 Subject: [PATCH 11/23] add Community Support of [Hygon DCU] [Intellifusion] [Iluvatar Corex] Signed-off-by: Lucas Wilkinson --- README.md | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0bb3e52..ae9f900 100644 --- a/README.md +++ b/README.md @@ -53,12 +53,41 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ### MetaX -For [MetaX](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). + +The corresponding FlashMLA version can be found at: - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) -### Moore Threads (WIP) -For [Moore Threads](https://www.mthreads.com) GPUs, the corresponding FlashMLA version can be found at: -- [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) + +### Moore Threads +For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). + +The corresponding FlashMLA version is available on GitHub: +[MooreThreads/MT-flashMLA](GitHub - MooreThreads/MT-flashMLA: Fork from https://github.com/deepseek-ai/FlashMLA). + + +### Hygon DCU + +For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/). + +The corresponding FlashMLA version is available here: +[OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). + + +### Intellifusion + +For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com). + +The corresponding FlashMLA version is available on Gitee: +[Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). + + +### Iluvatar Corex + +For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com). + +The corresponding FlashMLA version is available on GitHub: +[Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) ## Citation From 31d22559606ff434c4da52797308369bd32912ec Mon Sep 17 00:00:00 2001 From: hpp Date: Thu, 27 Feb 2025 09:40:47 +0800 Subject: [PATCH 12/23] add Community Support of [Hygon DCU] [Intellifusion] [Iluvatar Corex] Signed-off-by: Lucas Wilkinson --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ae9f900..2d2cfd6 100644 --- a/README.md +++ b/README.md @@ -56,14 +56,14 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). The corresponding FlashMLA version can be found at: -- [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) +[MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). The corresponding FlashMLA version is available on GitHub: -[MooreThreads/MT-flashMLA](GitHub - MooreThreads/MT-flashMLA: Fork from https://github.com/deepseek-ai/FlashMLA). +[MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). ### Hygon DCU From 400a3135a0eef4801ef0ab347178f1b7ecfb3d6e Mon Sep 17 00:00:00 2001 From: hpp Date: Thu, 27 Feb 2025 09:42:09 +0800 Subject: [PATCH 13/23] reformat Community Support section Signed-off-by: Lucas Wilkinson --- README.md | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 2d2cfd6..b79757c 100644 --- a/README.md +++ b/README.md @@ -51,43 +51,34 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ## Community Support -### MetaX - +### MetaX For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). -The corresponding FlashMLA version can be found at: -[MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) +The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). -The corresponding FlashMLA version is available on GitHub: -[MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). +The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). ### Hygon DCU - For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/). -The corresponding FlashMLA version is available here: -[OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). +The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). ### Intellifusion - For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com). -The corresponding FlashMLA version is available on Gitee: -[Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). +The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). ### Iluvatar Corex - For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com). -The corresponding FlashMLA version is available on GitHub: -[Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) +The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) ## Citation From c37449b8c32fd569880ccadf2bd2be2d3cefb47d Mon Sep 17 00:00:00 2001 From: Jiashi Li <31004720+beginlner@users.noreply.github.com> Date: Sat, 1 Mar 2025 17:55:58 +0800 Subject: [PATCH 14/23] add community support for [AMD] Signed-off-by: Lucas Wilkinson --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index b79757c..1dad9ef 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,12 @@ For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://ww The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) + +### AMD Instinct +For AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html). + +The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py) + ## Citation ```bibtex From ae4fe847d093d789ebc861a4eb69ef43d27f8f48 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Sat, 1 Mar 2025 18:24:24 +0800 Subject: [PATCH 15/23] add missing copyright Signed-off-by: Lucas Wilkinson --- csrc/flash_api.cpp | 3 +++ csrc/softmax.h | 3 +++ csrc/utils.h | 3 +++ 3 files changed, 9 insertions(+) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 22baab0..91ef3db 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -1,4 +1,7 @@ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ // #include #include diff --git a/csrc/softmax.h b/csrc/softmax.h index 4ab6ae9..17e293a 100644 --- a/csrc/softmax.h +++ b/csrc/softmax.h @@ -1,4 +1,7 @@ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ #pragma once diff --git a/csrc/utils.h b/csrc/utils.h index 3b8dd52..50295f7 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -1,4 +1,7 @@ // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ #pragma once From 0fbcce4d9d0864835e5946d3d0c7303471d5b78c Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Tue, 22 Apr 2025 17:50:57 +0800 Subject: [PATCH 16/23] Performance Update (2025.04.22) (#71) * Fix benchmark script * Performance optimization for compute-bound cases * Add new testcase (s_k = 16384) * Update README.md * Update comment * Update README.md * Add the deep-dive blog * Add background color for MLA Kernel Sched.drawio.svg * Use relative path for the schedule image * Move flash_mla.h to kernels/params.h Signed-off-by: Lucas Wilkinson --- .gitignore | 1 + README.md | 30 +- benchmark/bench_flash_mla.py | 2 +- csrc/flash_api.cpp | 146 +- csrc/flash_fwd_mla_bf16_sm90.cu | 3 - csrc/flash_fwd_mla_fp16_sm90.cu | 3 - csrc/flash_fwd_mla_kernel.h | 603 -------- csrc/kernels/config.h | 13 + .../get_mla_metadata.cu} | 25 +- csrc/kernels/get_mla_metadata.h | 5 + csrc/kernels/mla_combine.cu | 207 +++ csrc/kernels/mla_combine.h | 6 + csrc/{flash_mla.h => kernels/params.h} | 25 +- csrc/kernels/splitkv_mla.cu | 1350 +++++++++++++++++ csrc/kernels/splitkv_mla.h | 6 + csrc/kernels/traits.h | 106 ++ csrc/{static_switch.h => kernels/utils.h} | 37 +- csrc/named_barrier.h | 15 - csrc/softmax.h | 200 --- csrc/utils.h | 241 --- docs/20250422-new-kernel-deep-dive.md | 77 + docs/assets/MLA Kernel Sched.drawio.svg | 856 +++++++++++ flash_mla/flash_mla_interface.py | 1 - setup.py | 25 +- tests/test_flash_mla.py | 2 +- 25 files changed, 2757 insertions(+), 1228 deletions(-) delete mode 100644 csrc/flash_fwd_mla_bf16_sm90.cu delete mode 100644 csrc/flash_fwd_mla_fp16_sm90.cu delete mode 100644 csrc/flash_fwd_mla_kernel.h create mode 100644 csrc/kernels/config.h rename csrc/{flash_fwd_mla_metadata.cu => kernels/get_mla_metadata.cu} (79%) create mode 100644 csrc/kernels/get_mla_metadata.h create mode 100644 csrc/kernels/mla_combine.cu create mode 100644 csrc/kernels/mla_combine.h rename csrc/{flash_mla.h => kernels/params.h} (71%) create mode 100644 csrc/kernels/splitkv_mla.cu create mode 100644 csrc/kernels/splitkv_mla.h create mode 100644 csrc/kernels/traits.h rename csrc/{static_switch.h => kernels/utils.h} (57%) delete mode 100644 csrc/named_barrier.h delete mode 100644 csrc/softmax.h delete mode 100644 csrc/utils.h create mode 100644 docs/20250422-new-kernel-deep-dive.md create mode 100644 docs/assets/MLA Kernel Sched.drawio.svg diff --git a/.gitignore b/.gitignore index 5f9e980..982daef 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__/ dist/ *perf.csv *.png +/.vscode diff --git a/README.md b/README.md index 1dad9ef..6de1640 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,28 @@ # FlashMLA +## Performance Update (2025.04.22) + +We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀 + +Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up here: + +The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. + +## Introduction + FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. Currently released: - BF16, FP16 - Paged kvcache with block size of 64 +## Requirements + +- Hopper GPUs +- CUDA 12.3 and above + - **But we highly recommend 12.8 or above for the best performance** +- PyTorch 2.0 and above + ## Quick start ### Install @@ -20,7 +37,9 @@ python setup.py install python tests/test_flash_mla.py ``` -Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. +It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. + +Note. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. ### Usage @@ -38,13 +57,6 @@ for i in range(num_layers): ... ``` -## Requirements - -- Hopper GPUs -- CUDA 12.3 and above - - **But we highly recommend 12.8 or above for the best performance** -- PyTorch 2.0 and above - ## Acknowledgement FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. @@ -91,7 +103,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c ```bibtex @misc{flashmla2025, title={FlashMLA: Efficient MLA decoding kernels}, - author={Jiashi Li}, + author={Jiashi Li, Shengyu Liu}, year={2025}, publisher = {GitHub}, howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py index 14e1352..95c75f2 100644 --- a/benchmark/bench_flash_mla.py +++ b/benchmark/bench_flash_mla.py @@ -435,7 +435,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flash_infer", "flash_mla_triton"]: + if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]: # flash_infer has a different lse return value # flash_mla_triton doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 91ef3db..d6b96c4 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -10,8 +10,11 @@ #include -#include "flash_mla.h" -#include "static_switch.h" +#include "kernels/config.h" +#include "kernels/get_mla_metadata.h" +#include "kernels/mla_combine.h" +#include "kernels/params.h" +#include "kernels/splitkv_mla.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -23,11 +26,6 @@ get_mla_metadata( const int num_heads_per_head_k, const int num_heads_k ) { - // This should match the logic in the MLA kernel. - static constexpr int block_size_m = 64; - static constexpr int block_size_n = 64; - static constexpr int fixed_overhead_num_blocks = 5; - CHECK_DEVICE(seqlens_k); TORCH_CHECK(seqlens_k.is_contiguous()); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); @@ -38,7 +36,7 @@ get_mla_metadata( auto dprops = at::cuda::getCurrentDeviceProperties(); int sm_count = dprops->multiProcessorCount; - int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m); + int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M); auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); auto num_splits = torch::empty({batch_size + 1}, options); @@ -52,10 +50,10 @@ get_mla_metadata( params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; params.num_splits_ptr = num_splits_ptr; params.batch_size = batch_size; - params.block_size_n = block_size_n; - params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.block_size_n = Config::PAGE_BLOCK_SIZE; + params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS; params.num_sm_parts = num_sm_parts; - get_mla_metadata_func(params, stream); + run_get_mla_metadata_kernel(params, stream); return {tile_scheduler_metadata, num_splits}; } @@ -64,7 +62,6 @@ std::vector mha_fwd_kvcache_mla( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size - std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq @@ -73,138 +70,141 @@ mha_fwd_kvcache_mla( const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits // batch_size + 1 ) { + // Check the architecture auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; TORCH_CHECK(is_sm90); - at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; - + // Check data types auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + // Check device + CHECK_DEVICE(q); + CHECK_DEVICE(kcache); + CHECK_DEVICE(seqlens_k); + CHECK_DEVICE(block_table); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_DEVICE(num_splits); + // Check layout TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + CHECK_CONTIGUOUS(seqlens_k); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + CHECK_CONTIGUOUS(num_splits); const auto sizes = q.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; - const int num_heads_ori = sizes[2]; - const int head_size = sizes[3]; - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(batch_size > 0, "batch size must be postive"); - TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q_ori == 1) { is_causal = false; } - const int ngroups = num_heads_ori / num_heads_k; - const int seqlen_q = seqlen_q_ori * ngroups; + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; const int num_heads = num_heads_k; - q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3) - .reshape({batch_size, seqlen_q, num_heads, head_size}); + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); - int head_size_k = head_size; - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); - if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - - - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); - CHECK_DEVICE(seqlens_k); - CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_SHAPE(num_splits, batch_size+1); at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); - at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse); Flash_fwd_mla_params params = {}; // Set the sizes. params.b = batch_size; - params.seqlen_q = seqlen_q; - params.cu_seqlens_k = seqlens_k.data_ptr(); - params.h = num_heads; - params.h_h_k_ratio = num_heads / num_heads_k; - params.ngroups = ngroups; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; params.is_causal = is_causal; - params.d = head_size; + params.d = head_size_k; params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = kcache.data_ptr(); - params.v_ptr = vcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); // All stride are in elements, not bytes. params.q_batch_stride = q.stride(0); params.k_batch_stride = kcache.stride(0); - params.v_batch_stride = vcache.stride(0); params.o_batch_stride = out.stride(0); params.q_row_stride = q.stride(-3); params.k_row_stride = kcache.stride(-3); - params.v_row_stride = vcache.stride(-3); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = kcache.stride(-2); - params.v_head_stride = vcache.stride(-2); params.o_head_stride = out.stride(-2); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; - - TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); - TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); - CHECK_DEVICE(tile_scheduler_metadata); - CHECK_CONTIGUOUS(tile_scheduler_metadata); + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); params.num_sm_parts = tile_scheduler_metadata.size(0); - TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - CHECK_DEVICE(num_splits); - CHECK_CONTIGUOUS(num_splits); params.num_splits_ptr = num_splits.data_ptr(); - at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse_accum); + CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(head_size == 576); - + TORCH_CHECK(head_size_k == 576); if (q_dtype == torch::kBFloat16) { - run_mha_fwd_splitkv_mla(params, stream); - } - #ifndef FLASH_MLA_DISABLE_FP16 - else if (q_dtype == torch::kHalf) { - run_mha_fwd_splitkv_mla(params, stream); - } - #endif - else { + run_flash_splitkv_mla_kernel(params, stream); + run_flash_mla_combine_kernel(params, stream); + } else if (q_dtype == torch::kHalf) { +#ifdef FLASH_MLA_DISABLE_FP16 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA."); +#else + run_flash_splitkv_mla_kernel(params, stream); + run_flash_mla_combine_kernel(params, stream); +#endif + } else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } - out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) - .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); - softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) - .reshape({batch_size, num_heads_ori, seqlen_q_ori}); + out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); return {out, softmax_lse}; } diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu deleted file mode 100644 index 35691f2..0000000 --- a/csrc/flash_fwd_mla_bf16_sm90.cu +++ /dev/null @@ -1,3 +0,0 @@ -#include "flash_fwd_mla_kernel.h" - -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu deleted file mode 100644 index abdaf7b..0000000 --- a/csrc/flash_fwd_mla_fp16_sm90.cu +++ /dev/null @@ -1,3 +0,0 @@ -#include "flash_fwd_mla_kernel.h" - -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h deleted file mode 100644 index d96acd8..0000000 --- a/csrc/flash_fwd_mla_kernel.h +++ /dev/null @@ -1,603 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -using namespace cute; - -#include "named_barrier.h" -#include "utils.h" -#include "softmax.h" -#include "static_switch.h" -#include "flash_mla.h" - - -template -constexpr auto getSmemLayoutK() { - constexpr int headSizeBytes = sizeof(PrecType) * DIM; - constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; - - if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { - return GMMA::Layout_K_SW64_Atom{}; - } else { - return GMMA::Layout_K_SW32_Atom{}; - } -} - -template -struct Flash_fwd_kernel_traits_mla { - using Element = elem_type; - using ElementAccum = float; - using index_t = int64_t; - - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - static constexpr int kNWarpsS = 4; - static constexpr int kNThreadsS = kNWarpsS * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; - static_assert(kHeadDimV % 32 == 0); - static_assert(kHeadDimV <= kHeadDim); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - - using TiledMma = decltype(make_tiled_mma( - cute::GMMA::ss_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::K>(), - Layout, _1, _1>>{})); - - static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; - using TiledMmaO = decltype(make_tiled_mma( - cute::GMMA::rs_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::MN>(), - Layout, Int, _1>>{})); - - using SmemLayoutQ = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - - using SmemLayoutK = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - - using SmemLayoutV = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - - using SmemLayoutP = Layout, Int, _1, Int>>; - using SmemLayoutRow = Layout>, Stride<_1, _2>>; - - using SmemLayoutAtomO = decltype(composition( - Swizzle{}, - Layout, Int>, Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; - static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; - static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - - using GmemLayoutAtom = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopy = decltype(make_tiled_copy( - Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - - using GmemLayoutAtomO = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom, Element>{}, - GmemLayoutAtomO{}, - Layout>{})); // Val layout, 8 vals per store - - static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); - static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; - using GmemLayoutAtomOaccum = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyOaccum = decltype(make_tiled_copy( - Copy_Atom, ElementAccum>{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store -}; - -namespace flash { - -using namespace cute; - -template -struct SharedStorageMLA { - union { - struct { - cute::array_aligned> smem_q; - cute::array_aligned * 2> smem_k; // Double buffer - cute::array_aligned> smem_p; - cute::array_aligned> smem_scale; - }; - struct { - cute::array_aligned> smem_max; - cute::array_aligned> smem_sum; - cute::array_aligned> smem_o; - }; - }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, - SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int tidx = threadIdx.x; - - typename Kernel_traits::TiledMmaO tiled_mma_o; - auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - - // Epilogue - - const int split_offset = __ldg(params.num_splits_ptr + bidb); - - Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); - - using ElementO = std::conditional_t; - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum - >; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); - auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(tOrO); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - __syncthreads(); - - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - - const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), - Shape>{}, Stride<_1>{}); - - using GmemTiledCopyO = std::conditional_t; - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - - __syncthreads(); - - if (tidx >= kNThreadsS) { return; } - - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) - Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { -#pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } - } - } - - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM - ); -} - -template -__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, - const int bidb, const int bidh, const int m_block, - const int n_split_idx, const int seqlen_k, - const int n_block_min, const int n_block_max, const bool NoSplit, - SharedStorage &shared_storage) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreads = Kernel_traits::kNThreads; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - static_assert(kNThreads == 256 and kNThreadsS == 128); - using Element = typename Kernel_traits::Element; - using index_t = typename Kernel_traits::index_t; - - const int tidx = threadIdx.x; - int n_block = n_block_max - 1; - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); - - Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); - Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); - Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); - Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); - Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); - - typename Kernel_traits::TiledMmaO tiled_mma_o; - auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) - Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - clear(tOrO); - - flash::Softmax<2 * size<1>(tOrO)> softmax; - - int warp_group_idx = cutlass::canonical_warp_group_idx(); - if (warp_group_idx == 0) { - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - - if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; -#pragma unroll 1 - for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { - __syncthreads(); - - Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); - - const bool is_masking_step = masking_step > 0; - const bool is_first_masking_step = masking_step == n_masking_steps; - - if (is_masking_step) { - Tensor cS = make_identity_tensor(Shape, Int>{}); - Tensor tScS = thr_mma.partition_C(cS); -#pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col - if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; - } else { - // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - int row = int(get<0>(tScS(i))); - int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; - if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; - } - } - } - - // We have key_padding_mask so we'll need to Check_inf - Tensor scale_o = is_first_masking_step - ? softmax.template softmax(tSrS, params.scale_softmax_log2) - : is_masking_step ? - softmax.template softmax(tSrS, params.scale_softmax_log2) - : softmax.template softmax(tSrS, params.scale_softmax_log2); - - Tensor rP = flash::convert_type(tSrS); - cute::copy(rP, tPsP); - cute::copy(scale_o, tScale_osScale_o); - - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - cute::copy(softmax.row_max, tRow_maxsRow_max); - cute::copy(softmax.row_sum, tRow_sumsRow_sum); - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - } else { - const int *block_table = params.block_table + bidb * params.block_table_batch_stride; - int cur_block_table = __ldg(&block_table[n_block]); - - const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); - Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.seqlen_q - m_block * kBlockM); - - const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; - auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); - Tensor tKgK = gmem_thr_copy_K.partition_S(gK); - Tensor tKsK = gmem_thr_copy_K.partition_D(sK); - Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); - - if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tKsK.data() = tKsK.data() + sK_offset; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need to clear the sK smem tiles because K is V. - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, - seqlen_k - n_block * kBlockN); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - - if (n_block - 1 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 1]); - } - -#pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - flash::cp_async_wait<0>(); - __syncthreads(); - - if (n_block - 1 >= n_block_min) { - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tKsK.data() = tKsK.data() + sK_offset; - - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); - - if (n_block - 2 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 2]); - } - - typename Kernel_traits::TiledMma tiled_mma; - auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); - Tensor rP = make_tensor(tSrS_layout); - Tensor scale_o = make_tensor(Shape<_2>{}); - cute::copy(tScale_osScale_o, scale_o); - cute::copy(tPsP, rP); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - cute::copy(tRow_maxsRow_max, softmax.row_max); - cute::copy(tRow_sumsRow_sum, softmax.row_sum); - } - - if (NoSplit) - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); - else - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); -} - -template -__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kBlockN = Kernel_traits::kBlockN; - const int m_block = blockIdx.x; - const int bidh = blockIdx.y; - const int partition_idx = blockIdx.z; - - extern __shared__ char shared_memory[]; - auto &shared_storage = *reinterpret_cast(shared_memory); - - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int begin_seqlen = tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int end_seqlen = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); - -#pragma unroll 1 - for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { - const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; - const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); - const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; - const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); - const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); - if (batch_id > begin_idx) { - __syncthreads(); // Barrier between two tiles. - } - flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void __launch_bounds__(256, 1, 1) -flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kNThreads = 128; - - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - const int hs = params.h * params.seqlen_q; - const int batch_idx = bidx / hs; - const int hs_idx = bidx % hs; - - const int split_offset = __ldg(params.num_splits_ptr + batch_idx); - const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; - FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); - if (actual_num_splits == 1) return; - - __shared__ ElementAccum sLseScale[kMaxSplits]; - - const index_t row_offset_lseaccum = split_offset * hs + hs_idx; - const index_t row_offset_lse = bidx; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), - Shape>{}, make_stride(hs)); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape<_1>{}, Stride<_1>{}); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - if (warp_idx == 0) { - constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); - - float local_lse[kNLsePerThread]; - for (int i = 0; i < kNLsePerThread; ++i) { - const int split = i * 32 + tidx; - local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; - } - - float max_lse = -INFINITY; - for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); - for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); - max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf - - float sum_lse = 0; - for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); - for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); - - float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse; - if (tidx == 0) gLSE(0) = global_lse; - - for (int i = 0; i < kNLsePerThread; ++i) { - const int split = i * 32 + tidx; - if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); - } - } - __syncthreads(); - - static_assert(kHeadDimV % kNThreads == 0); - constexpr int Elements = kHeadDimV / kNThreads; - const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape>{}, Stride<_1>{}); - using GmemTiledCopyOaccum = decltype(make_tiled_copy( - Copy_Atom, ElementAccum>{}, - Layout>>{}, - Layout>>{})); - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - for (int split = 0; split < actual_num_splits; ++split) { - cute::copy(tOgOaccum, tOrOaccum); - ElementAccum lse_scale = sLseScale[split]; - for (int i = 0; i < size(tOrO); ++i) { - tOrO(i) += lse_scale * tOrOaccum(i); - } - tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; - } - - Tensor rO = flash::convert_type(tOrO); - const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; - const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; - Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); - cute::copy(rO, gO); -} - -} // namespace flash - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); - const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - auto kernel = &flash::flash_fwd_splitkv_mla_kernel; - constexpr size_t smem_size = sizeof(SharedStorage); - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - kernel<<>>(params); - }); - CHECK_CUDA_KERNEL_LAUNCH(); - - dim3 grid_combine(params.b * params.h * params.seqlen_q); - MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { - auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< - typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; - combine_kernel<<>>(params); - }); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -template -void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { - static_assert(Headdim == 576); - FLASH_ASSERT(params.d_v == 512); - FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV - using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; - run_flash_splitkv_fwd_mla>(params, stream); -} diff --git a/csrc/kernels/config.h b/csrc/kernels/config.h new file mode 100644 index 0000000..c9ce159 --- /dev/null +++ b/csrc/kernels/config.h @@ -0,0 +1,13 @@ +#pragma once + +namespace Config { + +static constexpr int BLOCK_SIZE_M = 64; +static constexpr int PAGE_BLOCK_SIZE = 64; + +static constexpr int HEAD_DIM_K = 576; +static constexpr int HEAD_DIM_V = 512; + +static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5; + +} diff --git a/csrc/flash_fwd_mla_metadata.cu b/csrc/kernels/get_mla_metadata.cu similarity index 79% rename from csrc/flash_fwd_mla_metadata.cu rename to csrc/kernels/get_mla_metadata.cu index 82f5b5a..6b78f9b 100644 --- a/csrc/flash_fwd_mla_metadata.cu +++ b/csrc/kernels/get_mla_metadata.cu @@ -1,8 +1,11 @@ -#include "flash_fwd_mla_kernel.h" +#include "get_mla_metadata.h" -static constexpr int MaxBatchSize = 4096; +#include +#include -__global__ void __launch_bounds__(256, 1, 1) +#include "utils.h" + +__global__ void __launch_bounds__(32, 1, 1) get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { int *seqlens_k_ptr = params.seqlens_k_ptr; int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; @@ -12,8 +15,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; int num_sm_parts = params.num_sm_parts; - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; + extern __shared__ int shared_mem[]; + int* num_blocks_shared = shared_mem; // [batch_size] + int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] int total_num_blocks = 0; for (int i = threadIdx.x; i < batch_size; i += 32) { @@ -27,7 +31,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { __syncwarp(); if (threadIdx.x == 0) { - int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks); int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; num_splits_shared[0] = 0; @@ -70,8 +74,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { } } -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); +void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream) { + int smem_size = sizeof(int) * (params.batch_size*2+1); + CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params); CHECK_CUDA_KERNEL_LAUNCH(); -} \ No newline at end of file +} diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/kernels/get_mla_metadata.h new file mode 100644 index 0000000..5130581 --- /dev/null +++ b/csrc/kernels/get_mla_metadata.h @@ -0,0 +1,5 @@ +#pragma once + +#include "params.h" + +void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/mla_combine.cu b/csrc/kernels/mla_combine.cu new file mode 100644 index 0000000..b6ba8f8 --- /dev/null +++ b/csrc/kernels/mla_combine.cu @@ -0,0 +1,207 @@ +#include "mla_combine.h" + +#include +#include +#include +#include + +#include "params.h" +#include "utils.h" +#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V + +using namespace cute; + +template +__global__ void __launch_bounds__(NUM_THREADS) +flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] + // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result + static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m + const int batch_idx = blockIdx.x; + const int m_block_idx = blockIdx.y; + const int warp_idx = threadIdx.x / 32; + const int lane_idx = threadIdx.x % 32; + + const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); + const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); + const int my_num_splits = end_split_idx - start_split_idx; + FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); + if (my_num_splits == 1) { + return; + } + + const int num_q_seqs = params.q_seq_per_hk * params.h_k; + const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M); + Tensor gLseAccum = make_tensor( + make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), + Shape, Int>{}, + make_stride(num_q_seqs, _1{}) + ); + Tensor gLse = make_tensor( + make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), + Shape>{}, + Stride<_1>{} + ); + + extern __shared__ float smem_buf[]; + Tensor sLseScale = make_tensor( + make_smem_ptr(smem_buf), + Shape, Int>{}, + Stride, _1>{} // +1 to avoid bank conflict + ); + + // Wait for the previous kernel (the MLA kernel) to finish + cudaGridDependencySynchronize(); + + // Read gLseAccum into sLseScale + { + #pragma unroll 4 + for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) { + int split_idx = elem_idx / BLOCK_SIZE_M; + int seq_idx = elem_idx % BLOCK_SIZE_M; + sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY; + } + __syncthreads(); + } + + if (warp_idx >= num_cur_valid_q_seqs) + return; + + // Warp #i gathers LseAccum for seq #i + { + constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32); + float local_lse[NUM_LSE_PER_THREAD]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { + const int split_idx = i*32 + lane_idx; + local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY; + } + + float max_lse = -INFINITY; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) + max_lse = max(max_lse, local_lse[i]); + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) + max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) + sum_lse = sum_lse + exp2f(local_lse[i] - max_lse); + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) + sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse; + if (lane_idx == 0) + gLse(warp_idx) = global_lse / (float)M_LOG2E; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { + const int split_idx = i*32 + lane_idx; + if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse); + } + } + + __syncwarp(); + + // Warp #i accumulates activation for seq #i + { + const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V; + Tensor gOaccum = make_tensor( + make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + make_stride(num_q_seqs*HEAD_DIM_V, _1{}) + ); + + static_assert(HEAD_DIM_V % 32 == 0); + constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32; + float result[ELEMS_PER_THREAD]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) + result[i] = 0.0f; + + #pragma unroll 2 + for (int split = 0; split < my_num_splits; ++split) { + float lse_scale = sLseScale(warp_idx, split); + if (lse_scale != 0.f) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) { + result[i] += lse_scale * gOaccum(split, lane_idx + i*32); + } + } + } + + cudaTriggerProgrammaticLaunchCompletion(); + + const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx; + const int k_head_idx = q_seq_idx / params.q_seq_per_hk; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride; + Tensor gO = make_tensor( + make_gmem_ptr(o_ptr), + Shape>{}, + Stride<_1>{} + ); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) + gO(lane_idx+i*32) = (ElementT)result[i]; + } +} + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() + + +template +void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { + constexpr int BLOCK_SIZE_M = 8; + constexpr int NUM_THREADS = BLOCK_SIZE_M*32; + constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); + auto combine_kernel = &flash_fwd_mla_combine_kernel; + CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + cudaLaunchConfig_t combine_kernel_config = { + dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1), + dim3(NUM_THREADS, 1, 1), + smem_size, + stream, + attribute, + 1 + }; + cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +#ifndef FLASH_MLA_DISABLE_FP16 +template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +#endif \ No newline at end of file diff --git a/csrc/kernels/mla_combine.h b/csrc/kernels/mla_combine.h new file mode 100644 index 0000000..69035e9 --- /dev/null +++ b/csrc/kernels/mla_combine.h @@ -0,0 +1,6 @@ +#pragma once + +#include "params.h" + +template +void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_mla.h b/csrc/kernels/params.h similarity index 71% rename from csrc/flash_mla.h rename to csrc/kernels/params.h index 2994cb7..3b4e254 100644 --- a/csrc/flash_mla.h +++ b/csrc/kernels/params.h @@ -5,39 +5,41 @@ struct Flash_fwd_mla_params { using index_t = int64_t; - int b, seqlen_q, d, d_v; - int h, h_h_k_ratio, ngroups; + int b; // batch size + int s_q; + int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q + int d, d_v; // K/V dimension + int h_q, h_k; // The number of Q/K heads + int num_blocks; // Number of blocks in total + int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; - int *__restrict__ cu_seqlens_k; - + void *__restrict__ q_ptr; void *__restrict__ k_ptr; - void *__restrict__ v_ptr; void *__restrict__ o_ptr; void *__restrict__ softmax_lse_ptr; index_t q_batch_stride; index_t k_batch_stride; - index_t v_batch_stride; index_t o_batch_stride; index_t q_row_stride; index_t k_row_stride; - index_t v_row_stride; index_t o_row_stride; index_t q_head_stride; index_t k_head_stride; - index_t v_head_stride; index_t o_head_stride; int *__restrict__ block_table; index_t block_table_batch_stride; int page_block_size; + int *__restrict__ seqlens_k_ptr; int *__restrict__ tile_scheduler_metadata_ptr; int num_sm_parts; int *__restrict__ num_splits_ptr; + int total_num_splits; void *__restrict__ softmax_lseaccum_ptr; void *__restrict__ oaccum_ptr; }; @@ -45,11 +47,6 @@ struct Flash_fwd_mla_params { static constexpr int TileSchedulerMetaDataSize = 8; // [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); - struct Mla_metadata_params { int *__restrict__ seqlens_k_ptr; int *__restrict__ tile_scheduler_metadata_ptr; @@ -59,5 +56,3 @@ struct Mla_metadata_params { int fixed_overhead_num_blocks; int num_sm_parts; }; - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu new file mode 100644 index 0000000..ff29305 --- /dev/null +++ b/csrc/kernels/splitkv_mla.cu @@ -0,0 +1,1350 @@ +#include + +#include "params.h" +#include "utils.h" +#include "config.h" +#include "traits.h" + +using namespace cute; +using cutlass::arch::NamedBarrier; + +// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking +// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) +// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM +static constexpr float MAX_INIT_VAL_SM = -1e30f; +static constexpr float MAX_INIT_VAL = -1e33f; + + +__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx + // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + +// Launch TMA copy for a range of KV tile +// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64 +template< + int START_HEAD_DIM_TILE_IDX, + int END_HEAD_DIM_TILE_IDX, + typename TMA_K_OneTile, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void launch_kv_tiles_copy_tma( + Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled + TMA_K_OneTile &tma_K, + TMABarrier* barriers_K, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + auto thr_tma = tma_K.get_slice(_0{}); + Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); + Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int{}); + cute::copy(tma_K.with(reinterpret_cast(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV); + if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { + launch_kv_tiles_copy_tma(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup); + } + } +} + +// Prefetch some KV tiles +// Currently this is not used because it leads to performance degradation +template< + int START_HEAD_DIM_TILE_IDX, + int END_HEAD_DIM_TILE_IDX, + typename TMA_K_OneTile, + typename Engine0, typename Layout0 +> +__forceinline__ __device__ void prefetch_kv_tiles( + Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + TMA_K_OneTile &tma_K, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + auto thr_tma = tma_K.get_slice(_0{}); + Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); + cute::prefetch(tma_K, cur_gKV); + if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { + prefetch_kv_tiles(gKV, tma_K, idx_in_warpgroup); + } + } +} + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +// * Copyright (c) 2024, Tri Dao. +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + + +// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64) +// The Q-tile should be in shared memory +template< + typename TiledMMA, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void qkt_gemm_one_tile_sQ( + TiledMMA &tiled_mma, + Tensor const &thr_mma_sQ_tile, // (MMA, 1, 4) + Tensor const &thr_mma_sKV_tile, // (MMA, 1, 4) + Tensor &rP, // ((2, 2, 8), 1, 1) + TMABarrier* barrier, + bool &cur_phase, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + barrier->arrive_and_expect_tx(64*64*2); + } + barrier->wait(cur_phase ? 1 : 0); + + warpgroup_fence_operand(rP); + warpgroup_arrive(); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP); + warpgroup_commit_batch(); + warpgroup_fence_operand(rP); +} + +template< + typename TiledMMA, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void qkt_gemm_one_tile_rQ( + TiledMMA &tiled_mma, + Tensor const &thr_mma_rQ_tile, // (MMA, 1, 4) + Tensor const &thr_mma_sKV_tile, // (MMA, 1, 4) + Tensor &rP, // ((2, 2, 8), 1, 1) + TMABarrier* barrier, + bool &cur_phase, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + barrier->arrive_and_expect_tx(64*64*2); + } + barrier->wait(cur_phase ? 1 : 0); + + warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile)); + warpgroup_fence_operand(rP); + warpgroup_arrive(); + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP); + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP); + warpgroup_commit_batch(); + warpgroup_fence_operand(rP); + warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile)); +} + +// Pipelined TMA wait and Q K^T gemm +// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows: +// - Wait for the 0-th tile to be ready using `barrier.wait()` +// - Compute Q K^T for the 0-th tile +// - Wait for the 1-st tile to be ready +// - Compute Q K^T for the 1-st tile +// ... +// This gives latter tiles more time to be ready, and thus can overlap the memory copy and computation +template< + typename T, // Traits + int PHASE_IDX, // See comments in the code + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3 +> +__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm( + Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + Tensor &rP, // ((2, 2, 8), 1, 1) + Tensor &rQ8, // The 8-th tile of Q. We store it separately to leave some room for storing sP1 + TMABarrier* barriers, + bool &cur_phase, + int idx_in_warpgroup +) { + Tensor sQ_tiled = flat_divide(sQ, Shape, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9) + Tensor sKV_tiled = flat_divide(sKV, Shape, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9) + TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){}; + ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup); + Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 4, 9) + Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 4, 9) + TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){}; + + #define QKT_GEMM_ONE_TILE(TILE_IDX) \ + if constexpr(TILE_IDX != 8) { \ + qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int{}), thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \ + } else { \ + qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \ + } + + if constexpr (PHASE_IDX == 0) { + // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; + tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(0); + QKT_GEMM_ONE_TILE(1); + QKT_GEMM_ONE_TILE(2); + QKT_GEMM_ONE_TILE(3); + } else if constexpr (PHASE_IDX == 1) { + // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; + tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(4); + QKT_GEMM_ONE_TILE(5); + QKT_GEMM_ONE_TILE(6); + QKT_GEMM_ONE_TILE(7); + QKT_GEMM_ONE_TILE(8); + QKT_GEMM_ONE_TILE(0); + QKT_GEMM_ONE_TILE(1); + QKT_GEMM_ONE_TILE(2); + QKT_GEMM_ONE_TILE(3); + cur_phase ^= 1; + } else { + // In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles + static_assert(PHASE_IDX == 2); + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One; + tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(4); + QKT_GEMM_ONE_TILE(5); + QKT_GEMM_ONE_TILE(6); + QKT_GEMM_ONE_TILE(7); + QKT_GEMM_ONE_TILE(8); + cur_phase ^= 1; + } +} + + +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline( + Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &sKV, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &rP, // ((2, 2, 8), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/16=36) + Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/16=36) + gemm(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP); +} + + +// Compute O += PV, where P resides in register +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP( + Tensor &rP, // ((2, 2, 8), 1, 1), fragment A layout + Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) + Tensor &rO, // ((2, 2, 32), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor rP_retiled = make_tensor(rP.data(), Layout< + Shape, _1, _4>, + Stride, _0, _8> + >{}); + Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4) + gemm(tiled_mma, rP_retiled, thr_mma_sKV_half, rO); +} + + +// Compute O += PV, where P resides in shared memory +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP( + Tensor &sP, + Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) + Tensor &rO, // ((2, 2, 32), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP); + Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4) + gemm(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO); +} + + +template< + typename T, + bool DO_OOB_FILLING, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4 +> +__forceinline__ __device__ void wg0_bunch_0( + Tensor &rPb, // ((2, 2, 8), 1, 1) + Tensor &rP0, // ((2, 2, 8), 1, 1) + Tensor &rO0, // ((2, 2, 32), 1, 1) + Tensor &sScale0, // (BLOCK_SIZE_M) + Tensor &sM, // (BLOCK_SIZE_M) + float rL[2], + int rRightBorderForQSeq[2], + float scale_softmax_log2, + int start_token_idx, + int idx_in_warpgroup +) { + // This piece of code is tightly coupled [Accumulate's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png) + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + + // Mask, and get row-wise max + float cur_max = MAX_INIT_VAL; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { + if constexpr (DO_OOB_FILLING) { + int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; + rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL; + rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL; + } + cur_max = max(cur_max, max(rP0(i), rP0(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + + // Update sM and sL + cur_max *= scale_softmax_log2; + float new_max = max(sM(row_idx), cur_max); + float scale_for_old = exp2f(sM(row_idx) - new_max); + __syncwarp(); // Make sure all reads have finished before updating sM + if (idx_in_warpgroup%4 == 0) { + sScale0(row_idx) = scale_for_old; + sM(row_idx) = new_max; + } + + // Scale-O + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { + rO0(i) *= scale_for_old; + rO0(i+1) *= scale_for_old; + } + + // Scale, exp, and get row-wise expsum + float cur_sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { + rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max); + rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max); + rPb(i) = (typename T::InputT)rP0(i); + rPb(i+1) = (typename T::InputT)rP0(i+1); + cur_sum += rP0(i) + rP0(i+1); + } + rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; + } +} + + +template< + typename T, + bool IS_BLK0_LAST, + bool IS_BLK1_LAST, + bool IS_BLK2_LAST, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5 +> +__forceinline__ __device__ void wg1_bunch_0( + Tensor &rP1b, // ((2, 2, 8), 1, 1) + Tensor &sScale1, // (BLOCK_SIZE_M) + Tensor &rO1, // ((2, 2, 32), 1, 1) + Tensor &sM, // (BLOCK_SIZE_M) + float rL[2], + int rRightBorderForQSeq[2], + Tensor const &sScale0, // (BLOCK_SIZE_M) + Tensor &rP1, // ((2, 2, 8), 1, 1) + float scale_softmax_log2, + int start_token_idx, + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + + // Mask, and get row-wise max + float cur_max = MAX_INIT_VAL; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { + if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) { + // Need to apply the mask when either this block is the last one, or + // the next block is the last one (because of the causal mask) + int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; + rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL; + rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP1(i+1) : MAX_INIT_VAL; + } else if constexpr (IS_BLK0_LAST) { + rP1(i) = rP1(i+1) = MAX_INIT_VAL; + } + cur_max = max(cur_max, max(rP1(i), rP1(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + cur_max *= scale_softmax_log2; + + float old_max = sM(row_idx); + float new_max = max(old_max, cur_max); + float scale_for_old = exp2f(old_max - new_max); + __syncwarp(); + if (idx_in_warpgroup%4 == 0) { + sM(row_idx) = new_max; + sScale1(row_idx) = scale_for_old; + } + + // Scale, exp, and get row-wise expsum + float cur_sum = 0; + if constexpr (!IS_BLK0_LAST) { + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { + rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max); + rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max); + rP1b(i) = (typename T::InputT)rP1(i); + rP1b(i+1) = (typename T::InputT)rP1(i+1); + cur_sum += rP1(i) + rP1(i+1); + } + } + + // Scale O + float cur_scale_for_o1 = scale_for_old * sScale0(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO1); i += 4) { + rO1(i) *= cur_scale_for_o1; + rO1(i+1) *= cur_scale_for_o1; + } + + // Update rL + rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum; + } +} + + +// Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void save_rPb_to_sP( + Tensor &rPb, + Tensor &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + (typename T::TiledMMA_QK_sQ){} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +// Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void retrieve_rP_from_sP( + Tensor &rPb, + Tensor const &sP, + int idx_in_warpgroup +) { + TiledCopy s2r_copy = make_tiled_copy_A( + Copy_Atom{}, + (typename T::TiledMMA_PV_LocalP){} + ); + ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_sP = thr_copy.partition_S(sP); + Tensor thr_copy_rPb = thr_copy.retile_D(rPb); + cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); +} + + +// Rescale rP0 and save the result to rPb +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void wg0_scale_rP0( + Tensor const &sScale1, // (BLOCK_M) + Tensor const &rP0, // ((2, 2, 8), 1, 1) + Tensor &rPb, // ((2, 2, 8), 1, 1) + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + float scale_factor = sScale1(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { + rPb(i) = (typename T::InputT)(rP0(i)*scale_factor); + rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor); + } + } +} + + +// Rescale rO0 according to sScale1 +template< + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void wg0_rescale_rO0( + Tensor &rO0, + Tensor &sScale1, + float rL[2], + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + float scale_factor = sScale1(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { + rO0(i) *= scale_factor; + rO0(i+1) *= scale_factor; + } + rL[local_row_idx] *= scale_factor; + } +} + + +// Fill out-of-bound V with 0.0 +// We must fill it since it may contain NaN, which may propagate to the final result +template< + typename T, + typename Engine0, typename Layout0 +> +__forceinline__ __device__ void fill_oob_V( + Tensor &sV, // tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape, Int>{}, LayoutRight{} ); + int valid_window_size, + int idx_in_warpgroup +) { + Tensor sV_int64 = make_tensor( + make_smem_ptr((int64_t*)(sV.data().get().get())), + tile_to_shape( + GMMA::Layout_MN_SW128_Atom{}, + Shape, Int>{}, + LayoutRight{} + ) + ); + valid_window_size = max(valid_window_size, 0); + int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should holds + for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) { + sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0; + } +} + + +// Store O / OAccum +template< + typename T, + bool IS_NO_SPLIT, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void store_o( + Tensor &rO, // ((2, 2, 32), 1, 1) + Tensor &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) + float rL[2], + char* sO_addr, + TMAParams &tma_params, + int batch_idx, + int k_head_idx, + int m_block_idx, + int num_valid_seq_q, + int warpgroup_idx, + int idx_in_warpgroup +) { + using InputT = typename T::InputT; + if constexpr (IS_NO_SPLIT) { + // Should convert the output to bfloat16 / float16, and save it to O + Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + Tensor rOb = make_tensor_like(rO); + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); ++idx) { + rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]); + } + + Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); + TiledCopy r2s_tiled_copy = make_tiled_copy_C( + Copy_Atom{}, + (typename T::TiledMMA_PV_LocalP){} + ); + ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); + Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); + Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); + cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); + cutlass::arch::fence_view_async_shared(); + + __syncthreads(); + + if (threadIdx.x == 0) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, m_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sOutputBuf), + thr_tma.partition_D(my_tma_gO) + ); + cute::tma_store_arrive(); + } + } else { + // Should save the result to OAccum + Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout< + Shape<_64, _512>, + Stride, _1> // We use stride = 520 here to avoid bank conflict + >{}); + + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); idx += 2) { + int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0); + int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; + *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 { + rO(idx) / rL[idx%4 >= 2], + rO(idx+1) / rL[idx%4 >= 2], + }; + } + cutlass::arch::fence_view_async_shared(); + + __syncthreads(); + + int row = threadIdx.x; + if (row < num_valid_seq_q) { + SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float)); + cute::tma_store_arrive(); + } + } +} + +template< + typename T, + typename TmaParams, typename Tensor0 +> +__forceinline__ __device__ void launch_q_copy( + TmaParams const &tma_params, + int batch_idx, + int m_block_idx, + int k_head_idx, + Tensor0 &sQ, + TMABarrier* barrier_Q +) { + if (threadIdx.x == 0) { + Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) + auto thr_tma = tma_params.tma_Q.get_slice(_0{}); + Tensor my_tma_gQ = flat_divide(tma_gQ, Shape, Int>{})(_, _, m_block_idx, _0{}); + cute::copy( + tma_params.tma_Q.with(reinterpret_cast(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), + thr_tma.partition_S(my_tma_gQ), + thr_tma.partition_D(sQ) + ); + barrier_Q->arrive_and_expect_tx(64*576*2); + } +} + +template< + typename T, + bool IS_R, + typename Engine0, typename Layout0 +> +__forceinline__ __device__ auto get_half_V( + Tensor &sK +) { + Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){}); + return flat_divide(sV, Shape, Int>{})(_, _, Int<(int)IS_R>{}, _0{}); +} + +template< + typename T, + bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ... + bool IS_BLK1_LAST, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5, + typename Engine6, typename Layout6, + typename Engine7, typename Layout7, + typename Engine8, typename Layout8, + typename Engine9, typename Layout9, + typename Engine10, typename Layout10, + typename Engine11, typename Layout11 +> +__forceinline__ __device__ void wg0_subroutine( + Tensor &tma_gK, + Tensor &sQ, + Tensor &sK0, + Tensor &sK1, + Tensor &sP0, + Tensor &sP1, + Tensor &sM, + Tensor &sScale0, + Tensor &sScale1, + Tensor &rQ8, + Tensor &rP0, + Tensor &rO0, + float rL[2], + int rRightBorderForQSeq[2], + TMABarrier barriers_K0[9], + TMABarrier barriers_K1[9], + bool &cur_phase_K0, + const TMAParams &tma_params, + const Flash_fwd_mla_params ¶ms, + int* block_table_ptr, + int seqlen_k, + int block_idx, + int end_block_idx, + int idx_in_warpgroup +) { + int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; + #define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 0 : __ldg(block_table_ptr + (block_idx))) + int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); + int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); + + Tensor sV0L = get_half_V(sK0); + Tensor sV1L = get_half_V(sK1); + + Tensor rPb = make_tensor(Shape, _1, _4>{}); + // Calc P0 = softmax(P0) + wg0_bunch_0(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready); + + // Issue rO0 += rPb @ sV0L + if constexpr (IS_BLK0_LAST) { + fill_oob_V(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + warpgroup_cooperative_pv_gemm_localP(rPb, sV0L, rO0, idx_in_warpgroup); + + // Wait for rO0, launch TMA for the next V0L + cute::warpgroup_wait<0>(); + + // Wait for warpgroup 1, rescale P0, notify warpgroup 1 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready); + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + // Put it here seems to be faster, don't know why + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); + } + wg0_scale_rP0(sScale1, rP0, rPb, idx_in_warpgroup); + save_rPb_to_sP(rPb, sP0, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready); + + // Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L + if constexpr (!IS_BLK0_LAST) { + if constexpr (IS_BLK1_LAST) { + fill_oob_V(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); + wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup); + warpgroup_cooperative_pv_gemm_remoteP(sP1, sV1L, rO0, idx_in_warpgroup); + } + + // Issue P0 = Q @ K0^T + // Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline. + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); + } + + // Wait for rO0 += rPb @ sV1L, launch TMA + if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) { + cute::warpgroup_wait<4>(); + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); + } + + // Issue P0 = Q @ K0^T + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); + } + + // Wait for P0 = Q @ K0^T + cute::warpgroup_wait<0>(); +} + + +template< + typename T, + bool IS_BLK0_LAST, + bool IS_BLK1_LAST, + bool IS_BLK2_LAST, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5, + typename Engine6, typename Layout6, + typename Engine7, typename Layout7, + typename Engine8, typename Layout8, + typename Engine9, typename Layout9, + typename Engine10, typename Layout10, + typename Engine11, typename Layout11 +> +__forceinline__ __device__ void wg1_subroutine( + Tensor &tma_gK, + Tensor &sQ, + Tensor &sK0, + Tensor &sK1, + Tensor &sP0, + Tensor &sP1, + Tensor &sM, + Tensor &sScale0, + Tensor &sScale1, + Tensor &rQ8, + Tensor &rP1, + Tensor &rO1, + float rL[2], + int rRightBorderForQSeq[2], + TMABarrier barriers_K0[9], + TMABarrier barriers_K1[9], + bool &cur_phase_K1, + const TMAParams &tma_params, + const Flash_fwd_mla_params ¶ms, + int* block_table_ptr, + int seqlen_k, + int block_idx, + int end_block_idx, + int idx_in_warpgroup +) { + int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; + int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); + int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); + + Tensor rP1b = make_tensor(Shape, _1, _4>{}); + + Tensor sV0R = get_half_V(sK0); + Tensor sV1R = get_half_V(sK1); + + // Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready); + wg1_bunch_0(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready); + + // Save rPb to sP, and issue rO1 += rP1b @ sV1R + // We do this after notifying warpgroup 1, since both "saving rPb to sP" and "issuing" WGMMA are high-latency operations + if constexpr (!IS_BLK0_LAST) { + save_rPb_to_sP(rP1b, sP1, idx_in_warpgroup); + } + if constexpr (!IS_BLK0_LAST) { + if constexpr (IS_BLK1_LAST) { + fill_oob_V(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + warpgroup_cooperative_pv_gemm_localP(rP1b, sV1R, rO1, idx_in_warpgroup); + if constexpr (!IS_BLK1_LAST) { + // We use this proxy for making sP1 visible to the async proxy + // We skip it if IS_BLK1_LAST, since in that case we have already put a fence + cutlass::arch::fence_view_async_shared(); + } + } + + // Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready); + if constexpr (IS_BLK0_LAST) { + fill_oob_V(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + warpgroup_cooperative_pv_gemm_remoteP(sP0, sV0R, rO1, idx_in_warpgroup); + if constexpr (!IS_BLK0_LAST) { + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); + } + + // Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { + cute::warpgroup_wait<1>(); + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); + } + + // Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + cute::warpgroup_wait<0>(); + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); + } + + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { + // Issue rP1 = sQ @ sK1, wait + warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); + } + + // We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise + // nvcc cannot correctly analyse the loop, and will think that we are using accumulator + // registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized. + // This is also the reason why we put QK^T here, instead of the first operation in the loop + cute::warpgroup_wait<0>(); +} + +// A helper function for determining the length of the causal mask for one q token +__forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params ¶ms, int m_block_idx, int local_seq_q_idx) { + int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; + if (global_seq_q_idx < params.q_seq_per_hk) { + int s_q_idx = global_seq_q_idx / params.q_head_per_hk; + return params.s_q - s_q_idx - 1; + } else { + // Out-of-bound request, regard as no masks + return 0; + } +} + +template +__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) { + // grid shape: [ + // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), + // num_kv_heads, + // num_sm_parts + // ] + // An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx). + // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) + // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). + + const int m_block_idx = blockIdx.x; + const int k_head_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int warpgroup_idx = threadIdx.x / 128; + const int idx_in_warpgroup = threadIdx.x % 128; + + // Define shared tensors + extern __shared__ char wksp_buf[]; + using SharedMemoryPlan = typename T::SharedMemoryPlan; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){}); + Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){}); + Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){}); + Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){}); + Tensor sP1 = flat_divide(sQ, Shape, Int>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile + Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int{})); + Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{})); + Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int{})); + Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int{})); + char* sO_addr = (char*)plan.smem_sK0.data(); // Overlap with sK0 and sK1 + + // Prefetch TMA descriptors + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + // Define TMA stuffs + Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _); + TMABarrier* barriers_K0 = plan.barriers_K0; + TMABarrier* barriers_K1 = plan.barriers_K1; + TMABarrier* barrier_Q = &(plan.barrier_Q); + + // Initialize TMA barriers + if (threadIdx.x == 0) { + barrier_Q->init(1); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 9; ++i) { + barriers_K0[i].init(1); + barriers_K1[i].init(1); + } + cutlass::arch::fence_view_async_shared(); + } + __syncthreads(); + bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0; + + // Programmatic Dependent Launch: Wait for the previous kernel to finish + cudaGridDependencySynchronize(); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + // Copy the first Q + launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + constexpr int kBlockN = T::PAGE_BLOCK_SIZE; + const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; + int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); + const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0; + int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN); + + int rRightBorderForQSeq[2]; + if (params.is_causal) { + // The causal mask looks like: + // XXXX + // XXXX + // ... + // XXXX + // XXX + // XXX + // ... + // XXX + // XX + // XX + // ... + // XX + // Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a "reduction in the length of the k-sequence.", and adjust end_block_idx based on it, to save some calculation. + // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks + // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling + int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); + end_block_idx = batch_idx == end_idx ? cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE); + } + } else { + rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k; + } + + // Define global tensors + using InputT = typename T::InputT; + InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1) + float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) + int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1) + + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout< + Shape>, + Stride<_1> + >{}); + + // Copy K0 and K1 + launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x); + if (start_block_idx+1 < end_block_idx) { + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); + } + + Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape, Int>{}); // ((2, 2, 32), 1, 1) + float rL[2]; + rL[0] = rL[1] = 0.0f; + + // Clear buffers + cute::fill(rO, 0.); + if (threadIdx.x < size(sM)) { + sM[threadIdx.x] = MAX_INIT_VAL_SM; + } + + // Wait for Q + barrier_Q->wait(cur_phase_Q); + cur_phase_Q ^= 1; + + Tensor rQ8 = make_tensor(Shape, _1, _4>{}); + retrieve_rP_from_sP(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup); + + if (warpgroup_idx == 0) { + // Warpgroup 0 + Tensor rP0 = make_tensor((typename T::rP0Layout){}); + + // NOTE We don't use the pipelined version of Q K^T here since it leads + // to a slow-down (or even register spilling, thanks to the great NVCC) + // Wait for K0 + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 9; ++i) { + if (idx_in_warpgroup == 0) + barriers_K0[i].arrive_and_expect_tx(64*64*2); + barriers_K0[i].wait(cur_phase_K0); + } + cur_phase_K0 ^= 1; + + // Issue P0 = Q @ K0^T, wait + warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + cute::warpgroup_wait<0>(); + + #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ + wg0_subroutine( \ + tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ + rQ8, rP0, rO, rL, rRightBorderForQSeq, \ + barriers_K0, barriers_K1, cur_phase_K0, \ + tma_params, params, \ + block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ + ); + + int block_idx = start_block_idx; + #pragma unroll 1 + for (; block_idx < end_block_idx-2; block_idx += 2) { + LAUNCH_WG0_SUBROUTINE(false, false); + } + + if (block_idx+1 < end_block_idx) { + LAUNCH_WG0_SUBROUTINE(false, true); + } else if (block_idx < end_block_idx) { + LAUNCH_WG0_SUBROUTINE(true, false); + } + + } else { + // Warpgroup 1 + Tensor rP1 = make_tensor((typename T::rP0Layout){}); + + if (start_block_idx+1 < end_block_idx) { + // Issue rP1 = sQ @ sK1, wait + warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); + cute::warpgroup_wait<0>(); + } + + #define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \ + wg1_subroutine( \ + tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ + rQ8, rP1, rO, rL, rRightBorderForQSeq, \ + barriers_K0, barriers_K1, cur_phase_K1, \ + tma_params, params, \ + block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ + ); + + int block_idx = start_block_idx; + #pragma unroll 1 + for (; block_idx < end_block_idx-3; block_idx += 2) { + LAUNCH_WG1_SUBROUTINE(false, false, false); + } + + if (block_idx+2 < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(false, false, true); + block_idx += 2; + LAUNCH_WG1_SUBROUTINE(true, false, false); + } else if (block_idx+1 < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(false, true, false); + } else if (block_idx < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(true, false, false); + } + } + + // Reduce rL across threads within the same warp + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + + // Reduce rL across warpgroups + int my_row = get_AorC_row_idx(0, idx_in_warpgroup); + if (idx_in_warpgroup%4 == 0) { + sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0]; + sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1]; + } + __syncthreads(); + if (warpgroup_idx == 0) { + rL[0] += sL_reduction_wksp[my_row + 64]; + rL[1] += sL_reduction_wksp[my_row + 8 + 64]; + } else { + if (idx_in_warpgroup%4 == 0) { + sL_reduction_wksp[my_row] += rL[0]; + sL_reduction_wksp[my_row + 8] += rL[1]; + } + __syncwarp(); + rL[0] = sL_reduction_wksp[my_row]; + rL[1] = sL_reduction_wksp[my_row+8]; + } + + // Prune out when rL is 0.0f or NaN + // rL may be 0.0f if there are large values (~10^12) in QK^T, which leads + // to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error. + // When this happens, we set rL to 1.0f. This aligns with the old version + // of the MLA kernel. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i]; + + // Copy Q for the next batch + if (batch_idx+1 <= end_idx) { + launch_q_copy(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q); + } else { + // Allow the next kernel (the combine kernel) to launch + // The next kernel MUST be the combine kernel + cudaTriggerProgrammaticLaunchCompletion(); + } + + int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M); + if (is_no_split) { + store_o(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL_reduction_wksp[i]; + gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E; + } + + cute::tma_store_wait<0>(); + } else { + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout< + Shape>, + Stride<_1> + >{}); + store_o(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL_reduction_wksp[i]; + gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? -INFINITY : log2f(cur_L) + sM(i); + } + + cute::tma_store_wait<0>(); + } + + if (batch_idx != end_idx) + __syncthreads(); + } +} + + +template +void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + using T = Traits; + auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((InputT*)params.q_ptr), + make_layout( + shape_Q, + make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride) + ) + ), + tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + ) + ); + auto shape_K = make_shape(Int{}, Int{}, params.h_k, params.num_blocks); + auto tma_K = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((InputT*)params.k_ptr), + make_layout( + shape_K, + make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride) + ) + ), + tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Layout< + Shape, Int<64>>, + Stride, _1> + >{} + ) + ); + auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((InputT*)params.o_ptr), + make_layout( + shape_O, + make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride) + ) + ), + tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + ) + ); + TmaParams tma_params = { + shape_Q, tma_Q, + shape_K, tma_K, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) + const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); + cudaLaunchAttribute mla_kernel_attributes[1]; + mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; + cudaLaunchConfig_t mla_kernel_config = { + dim3(num_m_block, params.h_k, params.num_sm_parts), + dim3(T::NUM_THREADS, 1, 1), + smem_size, + stream, + mla_kernel_attributes, + 1 + }; + cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +#ifndef FLASH_MLA_DISABLE_FP16 +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +#endif diff --git a/csrc/kernels/splitkv_mla.h b/csrc/kernels/splitkv_mla.h new file mode 100644 index 0000000..479fb50 --- /dev/null +++ b/csrc/kernels/splitkv_mla.h @@ -0,0 +1,6 @@ +#pragma once + +#include "params.h" + +template +void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/traits.h b/csrc/kernels/traits.h new file mode 100644 index 0000000..31c1388 --- /dev/null +++ b/csrc/kernels/traits.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +#include "config.h" + +using TMABarrier = cutlass::arch::ClusterTransactionBarrier; +using namespace cute; + +template +struct Traits { + using InputT = InputT_; + + static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; + static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; + static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; + static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; + + static constexpr int NUM_THREADS = 256; + + static_assert(std::is_same_v || std::is_same_v); + + using TiledMMA_QK_sQ = decltype(make_tiled_mma( + GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), + Layout>{} + )); + + using TiledMMA_QK_rQ = decltype(make_tiled_mma( + GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), + Layout>{} + )); + + using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), + Layout>{} + )); + + using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), + Layout>{} + )); + + using SmemLayoutQ = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + using SmemLayoutK = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + using SmemLayoutV = decltype(composition( + SmemLayoutK{}, + make_layout(Shape, Int>{}, GenRowMajor{}) + )); // A transposed version of SmemLayoutK + + using SmemLayoutP0 = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + using rP0Layout = decltype(layout(partition_fragment_C( + TiledMMA_QK_sQ{}, + Shape, Int>{} + ))); + + struct SharedMemoryPlan { + cute::array_aligned> smem_sQ; + cute::array_aligned> smem_sK0; + cute::array_aligned> smem_sK1; + cute::array_aligned> smem_sP0; + cute::array_aligned smem_sM; + cute::array_aligned sL_reduction_wksp; + cute::array_aligned smem_sScale0; + cute::array_aligned smem_sScale1; + TMABarrier barriers_K0[HEAD_DIM_K/64]; + TMABarrier barriers_K1[HEAD_DIM_K/64]; + TMABarrier barrier_Q; + }; + +}; + +template< + typename ShapeQ, typename TMA_Q, + typename ShapeK, typename TMA_K, + typename ShapeO, typename TMA_O +> +struct TmaParams { + ShapeQ shape_Q; + TMA_Q tma_Q; + ShapeK shape_K; + TMA_K tma_K; + ShapeO shape_O; + TMA_O tma_O; +}; + +enum NamedBarriers : int { + sScale0Ready = 0, + sScale1Ready = 1, + sP0Ready = 2, + rO1sP0sV0RIssued = 3 +}; diff --git a/csrc/static_switch.h b/csrc/kernels/utils.h similarity index 57% rename from csrc/static_switch.h rename to csrc/kernels/utils.h index f156adc..ae9d0fc 100644 --- a/csrc/static_switch.h +++ b/csrc/kernels/utils.h @@ -5,7 +5,7 @@ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ + exit(1); \ } \ } while(0) @@ -29,37 +29,4 @@ } \ } while(0) - -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr static bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() - - -#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ - [&] { \ - if (NUM_SPLITS <= 32) { \ - constexpr static int NAME = 32; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 64) { \ - constexpr static int NAME = 64; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 96) { \ - constexpr static int NAME = 96; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 128) { \ - constexpr static int NAME = 128; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 160) { \ - constexpr static int NAME = 160; \ - return __VA_ARGS__(); \ - } else { \ - FLASH_ASSERT(false); \ - } \ - }() +#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); } diff --git a/csrc/named_barrier.h b/csrc/named_barrier.h deleted file mode 100644 index cefa936..0000000 --- a/csrc/named_barrier.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "cutlass/barrier.h" - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Enumerates the reserved named barriers to avoid potential conflicts - -enum class NamedBarriers { - SReady = 1, - SoftmaxReady = 2, -}; - -} // flash diff --git a/csrc/softmax.h b/csrc/softmax.h deleted file mode 100644 index 17e293a..0000000 --- a/csrc/softmax.h +++ /dev/null @@ -1,200 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include - -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } - } -} - -template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++){ - dst(i) = Allreduce<4>::run(src(i), op); - } -} - -template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); -} - -template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); -} - -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - // The following macro will disable the use of fma. - // See: https://github.com/pytorch/pytorch/issues/121558 for more details - // This macro is set in PyTorch and not FlashAttention - #ifdef UNFUSE_FMA - tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); - #else - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - #endif - } - } - return tensor; -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; - sum(mi) = 0; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - SumOp sum_op; - sum(mi) = Allreduce<4>::run(sum(mi), sum_op); - } -} - -template -__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - #pragma unroll - for (int mi = 0; mi < size(scale_o); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax { - - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; - - __forceinline__ __device__ Softmax() {}; - - template - __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - TensorT scale_o; - clear(scale_o); - if (Is_first) { - flash::template reduce_max(scores, row_max); - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - flash::reduce_sum(scores, row_sum); - } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - flash::template reduce_max(scores, row_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = !Check_inf - ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scale_o(mi) = scores_scale; - row_sum(mi) *= scores_scale; - } - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - // We don't do the reduce across threads here since we don't need to use the row_sum. - // We do that reduce at the end when we need to normalize the softmax. - flash::reduce_sum(scores, row_sum); - } - return scale_o; - }; - - template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT lse = make_fragment_like(row_sum); - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); - float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } - } - return lse; - }; -}; - -} // namespace flash diff --git a/csrc/utils.h b/csrc/utils.h deleted file mode 100644 index 50295f7..0000000 --- a/csrc/utils.h +++ /dev/null @@ -1,241 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MaxOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } -}; - -template <> -struct MaxOp { -// This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ __forceinline__ T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Allreduce<2> { -template -static __device__ __forceinline__ T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { - constexpr bool Is_RS = !cute::is_base_of::value; - // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const - if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } - warpgroup_fence_operand(tCrC); - if constexpr (arrive) { - warpgroup_arrive(); - } - if constexpr (zero_init) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } else { - // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } - if constexpr (commit) { - warpgroup_commit_batch(); - } - if constexpr (wg_wait >= 0) { warpgroup_wait(); } - warpgroup_fence_operand(tCrC); - if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) -// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = acc_layout; - if constexpr (!Transposed) { - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); - } else { - return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); - } - - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - if constexpr (!Transposed) { - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); - } else { - return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. -// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) -// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) -template -__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { - using X = Underscore; - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); - if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { - auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) - return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); - } else { - static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); - static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); - static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); - auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) - // This combines the first two modes (<0, 0> and <0, 1>) into one mode. - // Will require register shuffling later to be correct. - return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), - get<1>(acc_layout), - coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) - // This combination is right but doesn't work with register shuffling. - // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), - // get<1>(acc_layout), - // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); - } - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); - static_assert(mma_shape_K == 8 || mma_shape_K == 16); - if constexpr (mma_shape_K == 8) { - return acc_layout; - } else { - auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) - return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ auto convert_type(Tensor const &tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = convert_op(*reinterpret_cast *>(tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Blocks until all but N previous cp.async.commit_group operations have committed. -// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all -// (which is equivalent to commit_group then wait_group 0). -// Instead we just call cp.async.wait_group 0, which is slightly faster. -// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 -template -CUTE_HOST_DEVICE -void cp_async_wait() { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, - Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, const int max_MN=0) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/docs/20250422-new-kernel-deep-dive.md b/docs/20250422-new-kernel-deep-dive.md new file mode 100644 index 0000000..555fd87 --- /dev/null +++ b/docs/20250422-new-kernel-deep-dive.md @@ -0,0 +1,77 @@ +# A Deep-Dive Into the New Flash MLA Kernel + +In the [previous version](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) of the Flash MLA kernel, we have achieved impressive performance: 3000 GB/s in memory-intensive settings and 580 TFlops in compute-bound settings. Now, we're pushing these numbers even further, reaching up to 660 TFlops. + +In this blog, we present a deep dive into the new kernel, explaining the optimizations and techniques behind this performance boost. We'll first explain why the MLA kernel is compute-bound despite being a decoding-stage attention kernel, then discuss our high-level kernel schedule design, and finally cover the technical details of the new kernel. + +## A Theoretical Analysis of the MLA Algorithm + +GPU kernels can be classified as either compute-bound (limited by floating-point operations per second, FLOPs) or memory-bound (limited by memory bandwidth). To identify the kernel's bottleneck, we calculate the ratio of FLOPs to memory bandwidth (FLOPs/byte) and compare it with the GPU's capacity. + +Assume the number of q heads is $h_q$, the number of q tokens per request is $s_q$ (should be 1 if MTP / speculative decoding is disabled), the number of kv tokens per request is $s_k\ (s_k \gg h_q s_q)$, and the head dimensions of K and V are $d_k$ and $d_v$ respectively. The number of FLOPs is roughly $2 (h_q s_q \cdot d_k \cdot s_k + h_q s_q \cdot s_k \cdot d_v) = 2 h_q s_q s_k (d_k+d_v)$, and the memory access volume (in bytes) is $\mathop{\text{sizeof}}(\text{bfloat16}) \times (h_q s_q d_k + s_k d_k + h_q s_q d_v) \approx 2s_k d_k$. Thus, the compute-memory ratio is $h_q s_q \cdot \frac{d_k+d_v}{d_k} \approx 2 h_q s_q$. + +An NVIDIA H800 SXM5 GPU has a peak memory bandwidth of 3.35 TB/s and peak FLOPs of 990 TFlops. However, due to throttling (reducing to ~1600 MHz in our case), the practical peak FLOPs drops to ~865 TFlops. Therefore, when $h_qs_q \ge \frac{1}{2} \cdot \frac{865}{3.35} = 128$, the kernel is compute-bound; otherwise, it's memory-bound. + +According to [the overview of DeepSeek's Online Inference System](https://github.com/deepseek-ai/open-infra-index/blob/main/202502OpenSourceWeek/day_6_one_more_thing_deepseekV3R1_inference_system_overview.md), we don't use Tensor Parallel for decoding instances, meaning $h_q$ is 128 and the kernel is compute-bound. Thus, we need to optimize the kernel for compute-bound settings. + +## High-Level Design of the New Kernel + +To fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's "schedule." + +[FlashAttention-3's paper](https://arxiv.org/abs/2205.14135) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possiblility of having two output matrices and letting them use CUDA Core and Tensor Core in a interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation. + +(You might pause here to ponder - perhaps you can find a better solution than ours!) + +Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows: + +0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0). +1. [0] Compute $\vec p_0 = \vec q K_0^\intercal / qk\_scale$. +2. [1] Compute $\vec p_1 = \vec q K_1^\intercal / qk\_scale$. +3. [0] Compute $mp_0 = \max(\vec p_0)$, $m\_new_0 = \max(m, mp_0)$, and $scale_0 = \exp(m\_new_0 - m)$. Update $m \gets m\_new_0$. +4. [0] Perform softmax on $\vec p_0$: $\vec p_0 \gets \exp(\vec p_0 - m\_new_0)$. +5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$. +6. [1] Compute $mp_1 = \max(\vec p_1)$, $m\_new_1 = \max(m, mp_1)$, and $scale_1 = \exp(m\_new_1 - m)$. Update $m \gets m\_new_1$. +7. [1] Perform softmax on $\vec p_1$: $\vec p_1 \gets \exp(\vec p_1 - m\_new_1)$. +8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$. +9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$. +10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$. +11. [0] Update $\vec o_L \gets \vec o_L \cdot scale_1 + \vec p_1 V_{1L}$. + +Note: We assume one q head for simplicity, so $\vec q$ and $\vec o$ are vectors. Bracketed numbers indicate the warpgroup performing the operation. Assume $\vec o_L$ resides in warpgroup 0's register and $\vec o_R$ resides in warpgroup 1's register. + +This schedule can be viewed as a "ping-pong" variant using one output matrix—we call it "seesaw" scheduling. It's mathematically equivalent to FlashAttention's online softmax algorithm. This schedule allows us to overlap CUDA Core and Tensor Core operations by interleaving the two warpgroups, and also allows us to overlap memory access with computation since we can launch the corresponding Tensor Memory Accelerator (TMA) instructions right after data is no longer needed. + +The complete schedule is shown below (remember that in MLA, $K$ and $V$ are the same with different names): + +![MLA Kernel Sched](assets/MLA%20Kernel%20Sched.drawio.svg) + +## Discussion of Technical Details + +This section covers technical details of the new kernel. + +First, although the kernel targets compute-bound scenarios (where memory bandwidth isn't the bottleneck), we can't ignore memory latency. If the data is not ready when we want to use it, we have to wait. To solve this problem, we employ the following techniques: + +- **Fine-grained TMA copy - GEMM pipelining:** For a $64 \times 576$ K block, we launch 9 TMA copies (each moving a $64 \times 64$ block). GEMM operations begin as soon as each TMA copy completes (When the first TMA copy is done, we can start the first GEMM operation, and so on), improving memory latency tolerance. +- **Cache hints:** Using `cute::TMA::CacheHintSm90::EVICT_FIRST` for TMA copies improves L2 cache hit rates, as shown by experiments. + +These optimizations achieve up to 80% Tensor Core utilization (of the throttled theoretical peak) and 3 TB/s memory bandwidth on an H800 SXM5 GPU. While slightly slower (~2%) than the old ping-pong buffer version in memory-bound settings, this is acceptable. + +Other performance improvements include: +- **Programmatic Dependent Launch.** We use [programmatic dependent launch](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) to overlap `splitkv_mla` and `combine` kernels. +- **Tile Scheduler.** We implement a tile scheduler to allocate jobs (requests and blocks) to SMs. This ensures a balanced load across SMs. + +## Acknowledgements + +FlashMLA's algorithm and scheduling is inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work. + +## Citation + +```bibtex +@misc{flashmla2025, + title={FlashMLA: Efficient MLA decoding kernels}, + author={Jiashi Li, Shengyu Liu}, + year={2025}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, +} +``` diff --git a/docs/assets/MLA Kernel Sched.drawio.svg b/docs/assets/MLA Kernel Sched.drawio.svg new file mode 100644 index 0000000..f3e94a5 --- /dev/null +++ b/docs/assets/MLA Kernel Sched.drawio.svg @@ -0,0 +1,856 @@ + + + + + + + + + + + + + + + + +
+
+
+ rP0 = sQ @ sK0 +
+
+
+
+ + rP0 = sQ @ sK0 + +
+
+
+ + + + + + + + + +
+
+
+ rP1 = sQ @ sK1 +
+
+
+
+ + rP1 = sQ @ sK1 + +
+
+
+ + + + + + + + + +
+
+
+
+
+ + Get sScale0 + +
+
+ + Update sM + +
+ + rPb = rP0 = Softmax(rP0) + +
+ + rO0 = Scale(rO0) + +
+
+ + Update rL + +
+
+
+
+
+
+ + Get sScale0... + +
+
+
+ + + + + + + + + +
+
+
+ Issue +
+
+
+
+ + Issue + +
+
+
+ + + + + + + + + +
+
+
+ rO0 += rPb @ sV0L +
+
+
+
+ + rO0 += rPb @ sV0L + +
+
+
+ + + + + + + + + +
+
+
+
+
+ + Get sScale1 + +
+
+ + Update sM + +
+ + rP1b = Softmax(rP1 + + ) + + +
+ + rO1 = Scale(rO1) + + +
+
+ + + Update rL + + +
+
+
+
+
+
+ + Get sScale1... + +
+
+
+ + + + + + + + + + + + + +
+
+
+ Issue +
+
+
+
+ + Issue + +
+
+
+ + + + + + + + + +
+
+
+ rO1 += rP1b @ sV1R +
+
+
+
+ + rO1 += rP1b @ sV1R + +
+
+
+ + + + + + + + + +
+
+
+ rO0 = Scale(rO0) +
+
+
+
+ + rO0 = Scale(rO0) + +
+
+
+ + + + + + + + + + + + + +
+
+
+ rO0 += sP1 @ sV1L +
+
+
+
+ + rO0 += sP1 @ sV1L + +
+
+
+ + + + + + + + + +
+
+
+ rO1 += sP0 @ sV0R +
+
+
+
+ + rO1 += sP0 @ sV0R + +
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
+
+
+ sP0 = Scale(rP0) +
+
+
+
+ + sP0 = Scale(rP0) + +
+
+
+ + + + + + + +
+
+
+ Tensor +
+
+
+
+ + Tensor + +
+
+
+ + + + + + + +
+
+
+ CUDA +
+
+
+
+ + CUDA + +
+
+
+ + + + + + + +
+
+
+ CUDA +
+
+
+
+ + CUDA + +
+
+
+ + + + + + + +
+
+
+ Tensor +
+
+
+
+ + Tensor + +
+
+
+ + + + + + + +
+
+
+ Warpgroup 0 +
+
+
+
+ + Warpgroup 0 + +
+
+
+ + + + + + + +
+
+
+ Warpgroup 1 +
+
+
+
+ + Warpgroup 1 + +
+
+
+ + + + + + + + + + + + + + + + +
+
+
+ Issue +
+
+
+
+ + Issue + +
+
+
+ + + + + + + + + + + + + + + + + + + + + +
+
+
+ Issue +
+
+
+
+ + Issue + +
+
+
+ + + + + + + + + +
+
+
+ Pipelined TMA wait and issue +
+
+
+
+ + Pipelined TMA wait and issue + +
+
+
+ + + + + + + + + + + + + +
+
+
+ Pipelined TMA wait and issue +
+
+
+
+ + Pipelined TMA wait and issue + +
+
+
+ + + + + + + + + +
+
+
+ sP1 = rP1b +
+
+
+
+ + sP1 = rP1b + +
+
+
+ + + + + + + + + + +
+
+
+ wg0-bunch-0 +
+
+
+
+ + wg0-bunch-0 + +
+
+
+ + + + + + + + + + +
+
+
+ wg1-bunch-0 +
+
+
+
+ + wg1-bunch-0 + +
+
+
+ + + + + + + + + +
+
+
+ Issue TMA (nxt V0L) +
+
+
+
+ + Issue TMA (nxt V0L) + +
+
+
+ + + + + + + + + +
+
+
+ Issue TMA (nxt V1L) +
+
+
+
+ + Issue TMA (nxt V1L) + +
+
+
+ + + + + + + + + + + + + + + + + +
+
+
+ Issue TMA (nxt V1R) +
+
+
+
+ + Issue TMA (nxt V1R) + +
+
+
+ + + + + + + + + +
+
+
+ Issue TMA (nxt V0R) +
+
+
+
+ + Issue TMA (nxt V0R) + +
+
+
+ + + + + + + + + + + + + + + + + + + +
+
+
+ sXX: Stored on shared memory +
+ rXX: Stored on register file +
+
+
+
+
+ + sXX: Stored on shared memory... + +
+
+
+ + + + + + + + + + + + + + + + +
+
+
+ + Loop boundary in our code + +
+ + (plz refer to comments in `wg1_subroutine`) + +
+
+
+
+
+ + Loop boundary in our code... + +
+
+
+
+ + + + + Text is not SVG - cannot display + + + +
\ No newline at end of file diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index b2922af..47637f8 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -55,7 +55,6 @@ def flash_mla_with_kvcache( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( q, k_cache, - None, head_dim_v, cache_seqlens, block_table, diff --git a/setup.py b/setup.py index cd311f2..131ceff 100644 --- a/setup.py +++ b/setup.py @@ -11,29 +11,13 @@ IS_WINDOWS, ) -DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" - - def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] - -def get_sources(): - sources = [ - "csrc/flash_api.cpp", - "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_metadata.cu", - ] - - if not DISABLE_FP16: - sources.append("csrc/flash_fwd_mla_fp16_sm90.cu") - - return sources - - def get_features_args(): features_args = [] + DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] if DISABLE_FP16: features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args @@ -56,7 +40,12 @@ def get_features_args(): ext_modules.append( CUDAExtension( name="flash_mla_cuda", - sources=get_sources(), + sources=[ + "csrc/flash_api.cpp", + "csrc/kernels/get_mla_metadata.cu", + "csrc/kernels/mla_combine.cu", + "csrc/kernels/splitkv_mla.cu", + ], extra_compile_args={ "cxx": cxx_args + get_features_args(), "nvcc": append_nvcc_threads( diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 0abe9d2..67c9d93 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -127,7 +127,7 @@ def main(torch_dtype): causal = True for b in [128]: - for s in [4096, 8192]: + for s in [4096, 8192, 16384]: for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: From fed73edb9c078ac20b115d09238080f788c8811c Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Tue, 22 Apr 2025 18:03:14 +0800 Subject: [PATCH 17/23] Update README.md (#72) Signed-off-by: Lucas Wilkinson --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6de1640..5d66f55 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀 -Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up here: +Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. From 44b59aec8bb29524cce39f4d7ab71f3a384cabee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?= =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?= Date: Wed, 23 Apr 2025 05:14:05 +0300 Subject: [PATCH 18/23] Minor fix to the docs to correct FlashAttention-3's paper link and typos (#73) Thank you for open source FlashMLA! Just read the write up and very amazing work! Found some very minor mistakes regarding to typos, and the link to the FlashAttention-3 paper is wrong as that is the original FlashAttention paper, so I just send the PR here. Thanks again! Signed-off-by: Hollow Man Signed-off-by: Lucas Wilkinson --- docs/20250422-new-kernel-deep-dive.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/20250422-new-kernel-deep-dive.md b/docs/20250422-new-kernel-deep-dive.md index 555fd87..8ad34bb 100644 --- a/docs/20250422-new-kernel-deep-dive.md +++ b/docs/20250422-new-kernel-deep-dive.md @@ -18,7 +18,7 @@ According to [the overview of DeepSeek's Online Inference System](https://github To fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's "schedule." -[FlashAttention-3's paper](https://arxiv.org/abs/2205.14135) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possiblility of having two output matrices and letting them use CUDA Core and Tensor Core in a interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation. +[FlashAttention-3's paper](https://arxiv.org/abs/2407.08608) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possibility of having two output matrices and letting them use CUDA Core and Tensor Core in a interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation. (You might pause here to ponder - perhaps you can find a better solution than ours!) @@ -62,7 +62,7 @@ Other performance improvements include: ## Acknowledgements -FlashMLA's algorithm and scheduling is inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work. +FlashMLA's algorithm and scheduling are inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work. ## Citation From c54af4bd6234a3a19dd3d4e86e04e74b54660698 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Wed, 23 Apr 2025 10:21:14 +0800 Subject: [PATCH 19/23] Fix LaTeX render error (#74) Signed-off-by: Lucas Wilkinson --- docs/20250422-new-kernel-deep-dive.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/20250422-new-kernel-deep-dive.md b/docs/20250422-new-kernel-deep-dive.md index 8ad34bb..da0b6dc 100644 --- a/docs/20250422-new-kernel-deep-dive.md +++ b/docs/20250422-new-kernel-deep-dive.md @@ -25,13 +25,13 @@ To fully utilize GPU compute resources, we need to overlap CUDA Core operations Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows: 0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0). -1. [0] Compute $\vec p_0 = \vec q K_0^\intercal / qk\_scale$. -2. [1] Compute $\vec p_1 = \vec q K_1^\intercal / qk\_scale$. -3. [0] Compute $mp_0 = \max(\vec p_0)$, $m\_new_0 = \max(m, mp_0)$, and $scale_0 = \exp(m\_new_0 - m)$. Update $m \gets m\_new_0$. -4. [0] Perform softmax on $\vec p_0$: $\vec p_0 \gets \exp(\vec p_0 - m\_new_0)$. +1. [0] Compute $`\vec p_0 = \vec q K_0^\intercal / qk\_scale`$. +2. [1] Compute $`\vec p_1 = \vec q K_1^\intercal / qk\_scale`$. +3. [0] Compute $mp_0 = \max(\vec p_0)$, $`m\_new_0 = \max(m, mp_0)`$, and $`scale_0 = \exp(m\_new_0 - m)`$. Update $`m \gets m\_new_0`$. +4. [0] Perform softmax on $\vec p_0$: $`\vec p_0 \gets \exp(\vec p_0 - m\_new_0)`$. 5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$. -6. [1] Compute $mp_1 = \max(\vec p_1)$, $m\_new_1 = \max(m, mp_1)$, and $scale_1 = \exp(m\_new_1 - m)$. Update $m \gets m\_new_1$. -7. [1] Perform softmax on $\vec p_1$: $\vec p_1 \gets \exp(\vec p_1 - m\_new_1)$. +6. [1] Compute $mp_1 = \max(\vec p_1)$, $`m\_new_1 = \max(m, mp_1)`$, and $`scale_1 = \exp(m\_new_1 - m)`$. Update $`m \gets m\_new_1`$. +7. [1] Perform softmax on $\vec p_1$: $`\vec p_1 \gets \exp(\vec p_1 - m\_new_1)`$. 8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$. 9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$. 10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$. From 287942839a692de43f2781237b78c23c8f05c550 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Mon, 28 Apr 2025 18:53:04 +0800 Subject: [PATCH 20/23] Fix synchronization issues Signed-off-by: Lucas Wilkinson --- .gitignore | 1 + csrc/kernels/splitkv_mla.cu | 10 +++++++--- csrc/kernels/traits.h | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 982daef..9b500a0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ *perf.csv *.png /.vscode +compile_commands.json diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu index ff29305..5e1fded 100644 --- a/csrc/kernels/splitkv_mla.cu +++ b/csrc/kernels/splitkv_mla.cu @@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params cudaGridDependencySynchronize(); int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. + int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); // Copy the first Q launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); @@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // Issue P0 = Q @ K0^T, wait warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 + NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); cute::warpgroup_wait<0>(); #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ @@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params cute::tma_store_wait<0>(); } else { - int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + // Don't use __ldg because of PDL and instruction reordering + int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx; float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< diff --git a/csrc/kernels/traits.h b/csrc/kernels/traits.h index 31c1388..5f915a6 100644 --- a/csrc/kernels/traits.h +++ b/csrc/kernels/traits.h @@ -102,5 +102,6 @@ enum NamedBarriers : int { sScale0Ready = 0, sScale1Ready = 1, sP0Ready = 2, - rO1sP0sV0RIssued = 3 + rO1sP0sV0RIssued = 3, + sMInitialized = 4, }; From 3fb059e461c860ebb482a176e5ba691e17dc0621 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 29 Apr 2025 12:02:57 +0800 Subject: [PATCH 21/23] update to cutlass 3.9 Signed-off-by: Lucas Wilkinson --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index afa1772..e94e888 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit e94e888df3551224738bfa505787b515eae8352f From b9c70e8f2593bd72c857cd3a43ccdc59d955084c Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 29 Apr 2025 12:03:15 +0800 Subject: [PATCH 22/23] update .gitignore Signed-off-by: Lucas Wilkinson --- .gitignore | 1 + setup.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 9b500a0..4535280 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.png /.vscode compile_commands.json +.cache diff --git a/setup.py b/setup.py index 131ceff..217f540 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,12 @@ IS_WINDOWS, ) + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] + def get_features_args(): features_args = [] DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] From 1e3cc1df0defe575c79fe8c251a22515684a2112 Mon Sep 17 00:00:00 2001 From: Zeyu WANG Date: Fri, 1 Aug 2025 17:21:27 +0800 Subject: [PATCH 23/23] Add more GPU architctures support (#76) * Add more GPU architctures support * Merge fmha and mla runner * add varlen & non varlen support, and add incontiguous tensor support * update readme * add varlen api --------- Co-authored-by: dianzhangc Signed-off-by: Lucas Wilkinson --- README.md | 12 +- csrc/sm100/collective/fmha_common.hpp | 127 ++ csrc/sm100/collective/fmha_fusion.hpp | 396 ++++ ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 234 +++ ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 1218 +++++++++++ .../sm100_fmha_load_tma_warpspecialized.hpp | 316 +++ ...a_mla_fwd_mainloop_tma_warpspecialized.hpp | 1225 +++++++++++ ...m100_fmha_mla_load_tma_warpspecialized.hpp | 340 +++ csrc/sm100/common/gather_tensor.hpp | 215 ++ csrc/sm100/common/helper.h | 72 + csrc/sm100/common/mask.cuh | 8 + csrc/sm100/common/pipeline_mla.hpp | 250 +++ csrc/sm100/common/pow_2.hpp | 92 + csrc/sm100/common/utils.hpp | 83 + csrc/sm100/device/fmha.hpp | 276 +++ csrc/sm100/device/fmha_device_bwd.hpp | 340 +++ csrc/sm100/fmha_cutlass_bwd_sm100.cu | 83 + csrc/sm100/fmha_cutlass_bwd_sm100.cuh | 200 ++ csrc/sm100/fmha_cutlass_fwd_sm100.cu | 81 + csrc/sm100/fmha_cutlass_fwd_sm100.cuh | 334 +++ .../kernel/fmha_causal_tile_scheduler.hpp | 197 ++ csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp | 153 ++ csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp | 161 ++ csrc/sm100/kernel/fmha_options.hpp | 85 + csrc/sm100/kernel/fmha_tile_scheduler.hpp | 162 ++ ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 1841 +++++++++++++++++ ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 1834 ++++++++++++++++ ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 619 ++++++ csrc/sm100/pybind.cu | 17 + csrc/{ => sm90}/flash_api.cpp | 0 csrc/{ => sm90}/kernels/config.h | 0 csrc/{ => sm90}/kernels/get_mla_metadata.cu | 0 csrc/{ => sm90}/kernels/get_mla_metadata.h | 0 csrc/{ => sm90}/kernels/mla_combine.cu | 0 csrc/{ => sm90}/kernels/mla_combine.h | 0 csrc/{ => sm90}/kernels/params.h | 0 csrc/{ => sm90}/kernels/splitkv_mla.cu | 0 csrc/{ => sm90}/kernels/splitkv_mla.h | 0 csrc/{ => sm90}/kernels/traits.h | 0 csrc/{ => sm90}/kernels/utils.h | 0 flash_mla/__init__.py | 3 + flash_mla/flash_mla_interface.py | 271 ++- setup.py | 61 +- ...st_flash_mla.py => test_flash_mla_sm90.py} | 0 tests/test_fmha_sm100.py | 199 ++ 45 files changed, 11489 insertions(+), 16 deletions(-) create mode 100644 csrc/sm100/collective/fmha_common.hpp create mode 100644 csrc/sm100/collective/fmha_fusion.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp create mode 100644 csrc/sm100/common/gather_tensor.hpp create mode 100644 csrc/sm100/common/helper.h create mode 100644 csrc/sm100/common/mask.cuh create mode 100644 csrc/sm100/common/pipeline_mla.hpp create mode 100644 csrc/sm100/common/pow_2.hpp create mode 100644 csrc/sm100/common/utils.hpp create mode 100644 csrc/sm100/device/fmha.hpp create mode 100644 csrc/sm100/device/fmha_device_bwd.hpp create mode 100644 csrc/sm100/fmha_cutlass_bwd_sm100.cu create mode 100644 csrc/sm100/fmha_cutlass_bwd_sm100.cuh create mode 100644 csrc/sm100/fmha_cutlass_fwd_sm100.cu create mode 100644 csrc/sm100/fmha_cutlass_fwd_sm100.cuh create mode 100644 csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp create mode 100644 csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp create mode 100644 csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp create mode 100644 csrc/sm100/kernel/fmha_options.hpp create mode 100644 csrc/sm100/kernel/fmha_tile_scheduler.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/pybind.cu rename csrc/{ => sm90}/flash_api.cpp (100%) rename csrc/{ => sm90}/kernels/config.h (100%) rename csrc/{ => sm90}/kernels/get_mla_metadata.cu (100%) rename csrc/{ => sm90}/kernels/get_mla_metadata.h (100%) rename csrc/{ => sm90}/kernels/mla_combine.cu (100%) rename csrc/{ => sm90}/kernels/mla_combine.h (100%) rename csrc/{ => sm90}/kernels/params.h (100%) rename csrc/{ => sm90}/kernels/splitkv_mla.cu (100%) rename csrc/{ => sm90}/kernels/splitkv_mla.h (100%) rename csrc/{ => sm90}/kernels/traits.h (100%) rename csrc/{ => sm90}/kernels/utils.h (100%) rename tests/{test_flash_mla.py => test_flash_mla_sm90.py} (100%) create mode 100644 tests/test_fmha_sm100.py diff --git a/README.md b/README.md index 5d66f55..07e021a 100644 --- a/README.md +++ b/README.md @@ -28,13 +28,21 @@ Currently released: ### Install ```bash -python setup.py install +pip install -v . ``` ### Benchmark +#### Testing MLA Decoding + +```bash +python tests/test_flash_mla_sm90.py +``` + +#### Testing MLA Forward/Backward + ```bash -python tests/test_flash_mla.py +python tests/test_fmha_sm100.py ``` It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. diff --git a/csrc/sm100/collective/fmha_common.hpp b/csrc/sm100/collective/fmha_common.hpp new file mode 100644 index 0000000..c60d9e9 --- /dev/null +++ b/csrc/sm100/collective/fmha_common.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +template +CUTLASS_DEVICE +void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } + else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/collective/fmha_fusion.hpp new file mode 100644 index 0000000..1486767 --- /dev/null +++ b/csrc/sm100/collective/fmha_fusion.hpp @@ -0,0 +1,396 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + return; + } +}; + +struct ResidualMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct ResidualMaskForBackward : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (! elem_less(pos, select<0,1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +// There are two ways to do causal if N_Q != N_K +// (1) The Q is at the beginning of the matrix +// (2) The Q is at the end of the matrix +template +struct CausalMask : NoMask { + + using Base = NoMask; + + static constexpr bool IsQBegin = kIsQBegin; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } else { + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + } + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is the default setting. + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to set kIsQBegin=false + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } + } +}; + +template +struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { + + using Base = CausalMask; + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + int offset_q = 0; + if constexpr (!kIsQBegin) { + offset_q = get<1>(problem_size) - get<0>(problem_size); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size); + if (masked) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +struct VariableLength { + int max_length; + int* cumulative_length = nullptr; + int total_length = -1; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template struct is_variable_length_impl : std::false_type {}; +template<> struct is_variable_length_impl : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length_impl>::value; + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } + else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length_offset(Shape const& shape, Coord const& coord) { + auto idx = back(back(coord)); + auto result_shape = transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); + auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx]; + } + else { + return _0{}; + } + }); + return cute::make_tuple(result_shape, result_offset); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template<> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000..616357c --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element, + class ElementAcc, + class TileShape, // Q, D, _ + class StrideO, // Q, D, B + class StrideLSE_, // Q, B + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + +// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>()); +// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + using StrideLSE = StrideLSE_; + using ElementOut = Element; + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_O; + StrideO dO; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_,_,_0{}) + )); + + + struct Params { + TMA_O tma_store_o; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + // FMHA and MLA have different input ProblemShapes; + // get problem_shape_O according to the input ProblemShape. + template + CUTLASS_DEVICE static constexpr + auto get_problem_shape_O ( + ProblemShape const& problem_shape) { + if constexpr (rank_v(ProblemShape{}))> == 2) { + return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape)); + } else { + return select<0,2,3>(problem_shape); + } + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + + auto ptr_O = args.ptr_O; + StrideO dO = args.dO; + + auto problem_shape_O = get_problem_shape_O(problem_shape); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dO) = get<0>(dO); + get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + // offset ptr by the amount we add back in later + ptr_O -= max_length_q * get<0>(dO); + } + } + + auto tma_store_o = make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(ptr_O, problem_shape_O, dO), + SmemLayoutO{}(_,_,_0{}) + ); + + return { + tma_store_o, + args.ptr_LSE, + args.dLSE + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + + BlkCoord blk_coord = blk_coord_in; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape)); + // offset mode 0 by (max_length - real_length) + // offset mode 3,1 by cumulative_length + real_length + // the ptr is already offset by - max_length + // so in total this achieves + int offs_0 = 0; + int offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + offs_0 = max_length_q - get<0>(problem_shape); + offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape); + get<2,1>(blk_coord) = 0; + } + } + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord)); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index)); + } + tma_store_arrive(); + + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000..f39fd75 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class TileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory. + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = TileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + // Reuse shared memory for V and O. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaUmmaAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000..1951056 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,316 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape +> +struct Sm100FmhaLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = problem_shape; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000..bf41af9 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1225 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "common/pipeline_mla.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ComposedTileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ComposedTileShape = ComposedTileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountK = 1; + static constexpr int StageCountV = 1; + static constexpr int StageCountKV = StageCountK + StageCountV; + // Support StageCountKV > 2 in the future. + static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); + static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{}); + static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{}); + static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope; + static constexpr auto HeadDimPV = HeadDimLatent; + + using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{})); + using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{}))); + using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent)); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); + + // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, + // we reuse shared memory for V and O to address this problem, + // and a barrier has been added to coordinate access to shared memory. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_o; // use as O0 + cute::array_aligned> smem_v; // use as V0 and O1 + }; + + struct TensorStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + + using TensorStorage = std::conditional_t; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaAsyncMla< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + using Load = Sm100MlaFwdLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + bool need_apply_mask, + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_mask) { + if(need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + constexpr int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int trip_idx = total_trip_count; + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + constexpr bool NeedMask = !std::is_same_v; + + CUTLASS_PRAGMA_NO_UNROLL + for (; trip_idx > 0; trip_idx -= 1) { + softmax_step( + trip_idx <= mask_trip_count, + row_max, row_sum, stage, + trip_idx == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp new file mode 100644 index 0000000..c2d3e2b --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape)); + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + // V1 + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch vi + cute::prefetch(params.tma_load_v, tVgV(_, k_index)); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch ki+1 + if(mask_tile_count > 1) { + cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1)); + } + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/common/gather_tensor.hpp b/csrc/sm100/common/gather_tensor.hpp new file mode 100644 index 0000000..46fb640 --- /dev/null +++ b/csrc/sm100/common/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/csrc/sm100/common/helper.h b/csrc/sm100/common/helper.h new file mode 100644 index 0000000..e957c4e --- /dev/null +++ b/csrc/sm100/common/helper.h @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once + + #include "cuda_runtime.h" + #include + + /** + * Panic wrapper for unwinding CUTLASS errors + */ + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + + /** + * Panic wrapper for unwinding CUDA runtime errors + */ + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +#define FLASH_MLA_ASSERT(cond) \ +do { \ + if (!(cond)) { \ + std::cerr << "FLASH_MLA_ASSERT: " << #cond << " failed at " << __FILE__ << ":" << __LINE__ << std::endl; \ + std::abort(); \ + } \ +} while (0) + + \ No newline at end of file diff --git a/csrc/sm100/common/mask.cuh b/csrc/sm100/common/mask.cuh new file mode 100644 index 0000000..d118aab --- /dev/null +++ b/csrc/sm100/common/mask.cuh @@ -0,0 +1,8 @@ +#pragma once + +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask +}; + diff --git a/csrc/sm100/common/pipeline_mla.hpp b/csrc/sm100/common/pipeline_mla.hpp new file mode 100644 index 0000000..5bbeed9 --- /dev/null +++ b/csrc/sm100/common/pipeline_mla.hpp @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Support the producer to acquire specific bytes of data. +*/ + +#pragma once + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; + +template < + int Stages_, + class ClusterShape = Shape, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaAsyncMla { + +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; + +private: + using Impl = PipelineTmaUmmaAsync; + +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? + cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas + cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + else { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + +public: + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); + if (barrier_token != BarrierStatus::WaitDone) { + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + full_barrier_ptr_[stage].arrive_and_expect_tx(bytes); + } + #ifndef NDEBUG + if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + + // Most likely you have elected more than one leader + if (params_.is_leader && (threadIdx.x % 32 != 0)) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + CUTLASS_DEVICE + void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index(), false); + } + +private: + Impl impl_; + Params params_; + EmptyBarrier *empty_barrier_ptr_; + FullBarrier *full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } + else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } + else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; + +} diff --git a/csrc/sm100/common/pow_2.hpp b/csrc/sm100/common/pow_2.hpp new file mode 100644 index 0000000..eca9325 --- /dev/null +++ b/csrc/sm100/common/pow_2.hpp @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp new file mode 100644 index 0000000..f43770d --- /dev/null +++ b/csrc/sm100/common/utils.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include "cutlass/numeric_types.h" +#include "helper.h" + +template +struct cutlass_dtype { + using type = T; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::half_t; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::bfloat16_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; + +template +using cutlass_dtype_t = typename cutlass_dtype::type; + +template +struct DeviceAllocation { + T* ptr_ = nullptr; + size_t offset_ = 0; + size_t size_ = 0; + + DeviceAllocation(DeviceAllocation const&) = delete; + DeviceAllocation& operator=(DeviceAllocation const&) = delete; + + DeviceAllocation() = default; + DeviceAllocation(size_t size) { reset(size); } + ~DeviceAllocation() { reset(); } + + void reset(size_t size, size_t offset=0) { + reset(); + auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); + assert(ret == cudaSuccess); + size_ = size; + offset_ = offset; + } + + T* get() { + return ptr_ + offset_; + } + + const T* get() const { + return ptr_ + offset_; + } + + void reset() { + if (ptr_ != nullptr) { + auto ret = cudaFree(ptr_); + assert(ret == cudaSuccess); + } + } + + size_t size() const { return size_; } + + size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } + + void copy_from_host(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } + + void copy_from_device(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } +}; \ No newline at end of file diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/device/fmha.hpp new file mode 100644 index 0000000..f8406d3 --- /dev/null +++ b/csrc/sm100/device/fmha.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/device/fmha_device_bwd.hpp new file mode 100644 index 0000000..d2463ac --- /dev/null +++ b/csrc/sm100/device/fmha_device_bwd.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class Element, + class ElementAccumulator, + class TileShape, + bool IsMla, + class Mask +> +class Sm100FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + // Q K D D_VO HB + ProblemShape problem_shape; + + const Element* ptr_Q; + cute::tuple> stride_Q; + const Element* ptr_K; + cute::tuple> stride_K; + const Element* ptr_V; + cute::tuple> stride_V; + + const Element* ptr_O; + cute::tuple> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple> stride_LSE; + + const Element* ptr_dO; + cute::tuple> stride_dO; + + Element* ptr_dQ; + cute::tuple> stride_dQ; + Element* ptr_dK; + cute::tuple> stride_dK; + Element* ptr_dV; + cute::tuple> stride_dV; + + ElementAccumulator softmax_scale; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using OperationMha= cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using OperationMla = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using Operation = std::conditional_t; + + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo = nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_shape, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + return typename OperationConvert::Arguments { + args.problem_shape, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { + + return typename Operation::Arguments{ + args.problem_shape, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + scaled_lse, stride_scaled_lse, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ, + args.softmax_scale }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // scaled LSE vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments( + args, sum_OdO, args_sum_OdO.stride_sum_OdO, + scaled_lse, args_sum_OdO.stride_scaled_lse, + dQ_acc, args_convert.stride_src_dQ + ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cu b/csrc/sm100/fmha_cutlass_bwd_sm100.cu new file mode 100644 index 0000000..4ff745d --- /dev/null +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cu @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include +#include "common/mask.cuh" +#include "common/utils.hpp" + +#include "fmha_cutlass_bwd_sm100.cuh" + +template +void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, + [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla, + at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { + static constexpr bool IsVarlen = std::is_same_v; + static constexpr bool IsMla = std::is_same_v; + using TileShape = std::conditional_t, Shape<_128, _128, _128, _128>>; + run_fmha_bwd(workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, total_seqlen_kv); +} + + +void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) { + + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + + int head_dim_qk = q.size(-1); + int head_dim_vo = v.size(-1); + MaskMode mask_mode = static_cast(mask_mode_code); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + + if(scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { + using Element = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + auto apply_config = [&](auto fn) { + if (mask_mode == MaskMode::kCausal) { + if(is_varlen) { + fn(CausalForBackwardMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(CausalForBackwardMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + else { + if(is_varlen) { + fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + }; + + apply_config([&](auto mask, auto varlen, auto in, auto out) { + if (head_dim_qk == 192 && head_dim_vo == 128) { + call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, max_seqlen_kv); + } else if (head_dim_qk == 128 && head_dim_vo == 128) { + call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, max_seqlen_kv); } + else { + std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; + } + }); + + } else { + FLASH_MLA_ASSERT(false); + } +} diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh new file mode 100644 index 0000000..2b19be2 --- /dev/null +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include "common/utils.hpp" +#include "collective/fmha_fusion.hpp" +#include "device/fmha_device_bwd.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; +using namespace cutlass; + + +template< + class DType, + bool kIsVarlen, + bool kIsMla, + class TileShape, + class ActiveMask +> +struct BwdRunner { + + using Element = DType; + using ElementAccumulator = float; + + // Q K D D_VO (H B) + using ProblemShape = std::conditional_t< + kIsVarlen, + cute::tuple>, + cute::tuple> + >; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; + + using TensorStride = Stride>; + using StrideQ = TensorStride; // Seq DQK (H B) + using StrideK = TensorStride; // Seq DQK (H B) + using StrideV = TensorStride; // Seq DVO (H B) + using StrideO = TensorStride; // Seq DVO (H B) + using StrideLSE = Stride<_1, Stride>; // Seq (H B) + + // Backwards specific + using StrideDQ = TensorStride; + using StrideDK = TensorStride; // Seq DQK (H B) + using StrideDV = TensorStride; // Seq DVO (H B) + using StrideDO = TensorStride; + + static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + ProblemShape problem_shape; + cute::tuple> tensor_shape; + + + int d = q.size(-1); + int d_vo = v.size(-1); + int batch_size = cumulative_seqlen_q.size(0) - 1; + int num_qo_heads = q.size(1); + int total_seqlen_q = q.size(0); + int total_seqlen_kv = k.size(0); + + //varlen: q: [Q, H, D] + //fixedlen: q: [B, H, Q, D] + if constexpr (kIsVarlen) { + problem_shape = cute::make_tuple( + VariableLength{max_seqlen_q, static_cast(cumulative_seqlen_q.data_ptr()), total_seqlen_q}, + VariableLength{max_seqlen_kv, static_cast(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv}, + d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); + tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1)); + } else { + int q_len = total_seqlen_q / batch_size; + int kv_len = total_seqlen_kv / batch_size; + problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); + tensor_shape = problem_shape; + } + + auto [Q, K, D, D_VO, HB] = tensor_shape; + auto [H, B] = HB; + + int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); + int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); + int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); + int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); + int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); + int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2); + int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2); + int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2); + int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2); + TORCH_CHECK(q_stride2 == 1); + TORCH_CHECK(k_stride2 == 1); + TORCH_CHECK(v_stride2 == 1); + TORCH_CHECK(o_stride2 == 1); + TORCH_CHECK(lse_stride0 == 1); + TORCH_CHECK(dq_stride2 == 1); + TORCH_CHECK(dk_stride2 == 1); + TORCH_CHECK(dv_stride2 == 1); + TORCH_CHECK(do_stride2 == 1); + + StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 0 : q_stride0*Q)); + StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 0 : k_stride0*K)); + StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K)); + StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 0 : o_stride0*Q)); + StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q)); + + StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q)); + StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K)); + StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K)); + StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q)); + + typename Operation::Arguments arguments{ + problem_shape, + (static_cast(q.data_ptr())), stride_Q, + (static_cast(k.data_ptr())), stride_K, + (static_cast(v.data_ptr())), stride_V, + (static_cast(o.data_ptr())), stride_O, + (static_cast(lse.data_ptr())), stride_LSE, + (static_cast(d_o.data_ptr())), stride_dO, + (static_cast(dq.data_ptr())), stride_dQ, + (static_cast(dk.data_ptr())), stride_dK, + (static_cast(dv.data_ptr())), stride_dV, + static_cast(softmax_scale), + hw_info + }; + + Operation op; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + uint8_t* workspace_ptr = workspace.get(); + + CUTLASS_CHECK(op.can_implement(arguments)); + CUTLASS_CHECK(op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); + } + +}; + + +template +void run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { + BwdRunner::run(workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, total_seqlen_kv); +} diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/fmha_cutlass_fwd_sm100.cu new file mode 100644 index 0000000..e322709 --- /dev/null +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cu @@ -0,0 +1,81 @@ +#include "common/mask.cuh" +#include "common/utils.hpp" +#include "fmha_cutlass_fwd_sm100.cuh" + +#include +#include +#include +#include +#include + +template +void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, + [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, + [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q, + at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, + float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + static constexpr bool IsVarlen = std::is_same_v; + static constexpr bool IsMla = std::is_same_v; + static constexpr bool IsCausalMask = std::is_same_v>; + using Option = std::conditional_t, + Option>; + + run_fmha_fwd( + workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, + softmax_scale, max_seqlen_q, max_seqlen_kv); +} + +void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, + int mask_mode_code, float sm_scale, int max_seqlen_q, + int max_seqlen_kv, bool is_varlen) { + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + CHECK(q.scalar_type() == k.scalar_type()); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + int head_dim_qk = q.size(-1); + int head_dim_vo = v.size(-1); + MaskMode mask_mode = static_cast(mask_mode_code); + + if (scalar_type_in == at::ScalarType::BFloat16 && + scalar_type_out == at::ScalarType::BFloat16) { + using Element = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + auto apply_config = [&](auto fn) { + if (mask_mode == MaskMode::kCausal) { + if (is_varlen) { + fn(CausalMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(CausalMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } else { + if (is_varlen) { + fn(ResidualMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + }; + + apply_config([&](auto mask, auto varlen, auto in, auto out) { + if (head_dim_qk == 192 && head_dim_vo == 128) { + call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v, + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, + max_seqlen_q, max_seqlen_kv); + } else if (head_dim_qk == 128 && head_dim_vo == 128) { + call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v, + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, + max_seqlen_q, max_seqlen_kv); + } else { + std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk + << " head_dim_vo=" << head_dim_vo << std::endl; + } + }); + + } else { + FLASH_MLA_ASSERT(false); + } +} diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh new file mode 100644 index 0000000..71831bb --- /dev/null +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh @@ -0,0 +1,334 @@ +#pragma once + +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" +#include "device/fmha.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp" + +#include +#include + +using namespace cute; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::device; + +struct FmhaOptions { + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int d = 128; +}; + +struct MlaOptions { + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int dl = 128; // headdim latent + int dr = 64; // headdim rope +}; + +template +struct FwdRunner { + + using Element = Element_; + using ElementAccumulatorQK = float; + using ElementAccumulatorPV = float; + using ElementOut = ElementOut_; + + using HeadDimLatent = _128; + using HeadDim = Shape; + using TileShapeMla = Shape<_256, _128, HeadDim>; + using TileShapeFmha = Shape<_256, _128, _128>; + using TileShape = std::conditional_t; + + using ProblemShapeRegular = std::conditional_t< + kIsMla, + cute::tuple, cute::tuple, int>>, + cute::tuple, int>>>; + + using ProblemShapeVarlen = + std::conditional_t, + cute::tuple, int>>, + cute::tuple, int>>>; + + using ProblemShapeType = + std::conditional_t; + + using StrideQ = cute::tuple, int>>; + using StrideK = cute::tuple, int>>; + using StrideV = StrideK; + using StrideO = StrideQ; + using StrideLSE = cute::tuple<_1, cute::tuple, int>>; + + static constexpr bool kIsPersistent = + find_option_t::value; + + using TileScheduler = std::conditional_t< + kIsPersistent, + std::conditional_t> || + std::is_same_v>, + cutlass::fmha::kernel::CausalPersistentTileScheduler, + cutlass::fmha::kernel::PersistentTileScheduler>, + std::conditional_t>; + + static constexpr bool IsOrderLoadEpilogue = + kIsPersistent && (sizeof(Element) == sizeof(ElementOut)); + using OrderLoadEpilogue = std::conditional_t; + + using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK, + StrideV, ActiveMask, Shape<_2, _1, _1>, OrderLoadEpilogue>; + + using OperationMla = + cutlass::fmha::device::FMHA, + TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>; + + using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK, + StrideV, ActiveMask>; + + using OperationFmha = + cutlass::fmha::device::FMHA, + TileScheduler>>; + + using Mainloop = std::conditional_t; + using Operation = std::conditional_t; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + + template + auto initialize_varlen(const ProblemShape &problem_size, int max_seqlen_q, int max_seqlen_kv, + int total_seqlen_q, int total_seqlen_kv) { + + int num_batches = get<3, 1>(problem_size); + + ProblemShape problem_size_for_init = problem_size; + get<3, 1>(problem_size_for_init) = 1; + get<0>(problem_size_for_init) = total_seqlen_q; + get<1>(problem_size_for_init) = total_seqlen_kv; + + ProblemShapeType problem_size_for_launch; + + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<2>(problem_size_for_launch) = get<2>(problem_size); + get<3>(problem_size_for_launch) = get<3>(problem_size); + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + template + static constexpr auto get_problem_shape(const Options &options) { + int h_r = options.h / options.h_k; + if constexpr (std::is_same_v) { + return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr), + cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + } else { + return cute::make_tuple(options.q, options.k, options.d, + cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + } + } + + template + ProblemShapeType initialize(const Options &options, int max_seqlen_q, int max_seqlen_kv, + int total_seqlen_q, int total_seqlen_kv, + void *cumulative_length_q, void *cumulative_length_kv) { + assert(options.h % options.h_k == 0); + auto problem_shape_in = get_problem_shape(options); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (kIsVarlen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen( + problem_shape_in, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + auto get_head_dimension = [&]() { + if constexpr (rank_v(problem_shape))> == 2) { + return cute::make_tuple(size<2, 0>(problem_shape) + size<2, 1>(problem_shape), + size<2, 0>(problem_shape)); + } else { + return cute::make_tuple(size<2>(problem_size), size<2>(problem_size)); + } + }; + + + if constexpr (kIsVarlen) { + get<0>(problem_shape).cumulative_length = static_cast(cumulative_length_q); + get<1>(problem_shape).cumulative_length = static_cast(cumulative_length_kv); + } + + return problem_shape; + } + + auto get_arguments(const ProblemShapeType &problem_shape, + const cutlass::KernelHardwareInfo &hw_info, float scale_softmax, + void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, + void *cumulative_length_q, void *cumulative_length_kv) { + auto problem_shape_ = problem_shape; + if constexpr (kIsVarlen) { + get<0>(problem_shape_).cumulative_length = static_cast(cumulative_length_q); + get<1>(problem_shape_).cumulative_length = static_cast(cumulative_length_kv); + } + + typename Operation::Arguments arguments{ + problem_shape_, + {static_cast(q_ptr), stride_Q, static_cast(k_ptr), stride_K, + static_cast(v_ptr), stride_V, scale_softmax}, + {static_cast(o_ptr), stride_O, + static_cast(lse_ptr), stride_LSE}, + hw_info}; + + return arguments; + } + + template + void run(const Options &options, const cutlass::KernelHardwareInfo &hw_info, at::Tensor q, + at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, float scale_softmax, + at::Tensor workspace, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) { + + int total_seqlen_q = q.size(0); + int total_seqlen_kv = k.size(0); + ProblemShapeType problem_shape = + initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, + cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); + + int SQ = size<0>(problem_shape); + int SK = size<1>(problem_shape); + int B = size<3, 1>(problem_shape); + int H = size<3, 0>(problem_shape); + int H_K = size<3, 0, 1>(problem_shape); + int H_Q = size<3, 0, 0>(problem_shape); + + int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); + int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); + int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); + int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); + int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); + TORCH_CHECK(q_stride2 == 1); + TORCH_CHECK(k_stride2 == 1); + TORCH_CHECK(v_stride2 == 1); + TORCH_CHECK(o_stride2 == 1); + TORCH_CHECK(lse_stride0 == 1); + + stride_Q = make_stride(q_stride0, _1{}, make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0)); + stride_O = make_stride(o_stride0, _1{}, make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0)); + stride_K = make_stride(k_stride0, _1{}, make_stride(make_stride(_0{}, k_stride1), SK * k_stride0)); + stride_V = make_stride(v_stride0, _1{}, make_stride(make_stride(_0{}, v_stride1), SK * v_stride0)); + stride_LSE = make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ)); + + if constexpr (kIsVarlen) { + get<2, 1>(stride_Q) = 0; + get<2, 1>(stride_K) = 0; + get<2, 1>(stride_V) = 0; + get<2, 1>(stride_O) = 0; + get<1, 1>(stride_LSE) = 0; + } + + typename Operation::Arguments arguments = + get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(), + v.data_ptr(), o.data_ptr(), lse.data_ptr(), + cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); + + Operation op; + + // size_t workspace_size = 0; + // workspace_size = Operation::get_workspace_size(arguments); + + // todo: if use workspace, need check workspace size first. + // we don't use workspace in current version. + + CUTLASS_CHECK(op.can_implement(arguments)); + CUTLASS_CHECK(op.initialize(arguments, nullptr)); + CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); + } +}; + +template +void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, + at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) { + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + auto get_options = [&]() { + if constexpr (kIsMla) { + MlaOptions options; + options.b = cumulative_seqlen_q.size(0) - 1; + options.h = q.size(1); + options.h_k = k.size(1); + options.q = q.size(0) / options.b; + options.k = k.size(0) / options.b; + options.dl = v.size(-1); + options.dr = q.size(-1) - v.size(-1); + return options; + } else { + FmhaOptions options; + options.b = cumulative_seqlen_q.size(0) - 1; + options.h = q.size(1); + options.h_k = k.size(1); + options.q = q.size(0) / options.b; + options.k = k.size(0) / options.b; + options.d = q.size(-1); + return options; + } + }; + + auto options = get_options(); + + if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && + (!std::is_same_v)) { + FwdRunner runner; + runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, + cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); + } else { + FwdRunner runner; + runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, + cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); + } +} diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp new file mode 100644 index 0000000..572e67f --- /dev/null +++ b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +// Swizzle Q tile and H tile to improve L2 cache hit rate, +// and launch the longest main loop first to keep most SMs busy. + +struct CausalIndividualTileScheduler { + + static constexpr int TileQ = 16; + static constexpr int TileH = 8; + static constexpr int TileSize = TileQ * TileH; + + struct Params { + dim3 grid; + int tile_max_q; + FastDivmod divmod_tile_col; + FastDivmod divmod_tile_size; + FastDivmod divmod_tile_head; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + CausalIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + + dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size)); + // gridDim.x must multiple of TileH + const int tile_col_count = grid.x / TileH; + const int tile_max_q = grid.y / TileQ * TileQ; + return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH}; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + const int block_idx = blockIdx.y * gridDim.x + blockIdx.x; + + int tile_idx, tile_tail; + params.divmod_tile_size(tile_idx, tile_tail, block_idx); + + int tile_row_idx, tile_col_idx; + params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx); + + int row_offset_in_tail, col_offset_in_tail; + params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail); + + const int row_idx = tile_row_idx * TileQ + row_offset_in_tail; + const int col_idx = tile_col_idx * TileH + col_offset_in_tail; + + // last q tile launch first + if(blockIdx.y >= params.tile_max_q) { + return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z))); + } + + return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z))); + } + + CUTLASS_DEVICE + CausalIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +// Launch order: H Q B +struct CausalPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_h; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_h(block_decode, bidh, block_decode); + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + CausalPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000..32e007c --- /dev/null +++ b/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + ProblemShape problem_shape; + + const ElementAcc* ptr_src_dQ; + tuple> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple> stride_src_dV; + + Element* ptr_dest_dQ; + tuple> stride_dest_dQ; + Element* ptr_dest_dK; + tuple> stride_dest_dK; + Element* ptr_dest_dV; + tuple> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { + auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + int seqlen = count; + if constexpr (is_variable_length_v) { + int offset = count.cumulative_length[blockIdx.y]; + ptr_dest_bh += offset * get<0>(stride_dest); + seqlen = count.cumulative_length[blockIdx.y + 1] - offset; + } + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= seqlen) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000..bdcf1cb --- /dev/null +++ b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + ProblemShape problem_shape; + + const Element* ptr_O; + cute::tuple> stride_O; + const Element* ptr_dO; + cute::tuple> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + auto problem_q = get<0>(params.problem_shape); + int seqlen_q = problem_q; + if constexpr (is_variable_length_v) { + int offset = problem_q.cumulative_length[blockIdx.z]; + ptr_O_bh += offset * get<0>(params.stride_O); + ptr_dO_bh += offset * get<0>(params.stride_dO); + ptr_lse_bh += offset * get<0>(params.stride_lse); + seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; + } + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= seqlen_q) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_options.hpp b/csrc/sm100/kernel/fmha_options.hpp new file mode 100644 index 0000000..d4faa8d --- /dev/null +++ b/csrc/sm100/kernel/fmha_options.hpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000..119f069 --- /dev/null +++ b/csrc/sm100/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_h; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..59b410b --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1841 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D_VO, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = []() { + if constexpr (sizeof(Element) == 1) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else { + return SM100_TMEM_STORE_32dp32b8x{}; + } + }(); + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + + } + }; + + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST_p = thread_r2t.partition_S(tDVcST); + auto tRT_cST = split_wg(tRT_cST_p); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..5a58157 --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1834 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + using TileShapeK = decltype(get<1>(TileShape{})); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<3>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + 65536 * 16; + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + + static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; + static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp"); + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + + // compute dP + using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDOV = typename CollectiveMmaDOV::TileShape; + using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); + using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + union{ + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array> smem_p_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaQK::to_underlying_arguments( + make_shape(Q, K, D, HB), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q, args.mainloop.stride_q, + args.mainloop.ptr_k, args.mainloop.stride_k, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( + make_shape(Q, K, D_VO, HB), + typename CollectiveMmaDOV::Arguments { + args.mainloop.ptr_do, args.mainloop.stride_do, + args.mainloop.ptr_v, args.mainloop.stride_v, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_b, + params_vdo.tma_load_b, + params_kq.tma_load_a, + params_vdo.tma_load_a, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); + auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_B(gK); + auto tSTgQ = cta_mma_kq.partition_A(gQ); + auto tDPTgV = cta_mma_vdo.partition_B(gV); + auto tDPTgDO = cta_mma_vdo.partition_A(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading kLoadPerThread * 32 values of 32b each + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); + Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ); + + Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); + Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaQK tiled_mma_qk; + TiledMmaDOV tiled_mma_dov; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); + Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } + }; + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) + (_, _, _, pipeline_compute_mma_p_producer_state.index()); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + auto sP_pi = as_position_independent_swizzle_tensor(sP); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); + auto sP_pi_slice = split_wg(sP_pi_slice_p); + copy_aligned(tRT_rST, sP_pi_slice); + }); + + // notify for P + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..8fe503b --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,619 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + + +struct Sm100MlaFwdCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 184; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule +> +struct Sm100FmhaFwdKernelTmaWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue); + static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad); + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + static constexpr bool IsMla = std::is_same_v; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + using UnionType = union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using StructType = struct { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + static constexpr bool IsPersistent = std::is_same_v || std::is_same_v; + using MainloopEpilogueStorage = std::conditional_t, + StructType>, + UnionType>; + + MainloopEpilogueStorage mainloop_epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + auto get_epilogue_storage = [&]() { + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + return reinterpret_cast(shared_storage.mainloop_epilogue.mainloop.smem_o.data()); + } else { + return &shared_storage.mainloop_epilogue.epilogue; + } + }; + typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage(); + + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue{params.epilogue}; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + mainloop.correction_empty( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/pybind.cu b/csrc/sm100/pybind.cu new file mode 100644 index 0000000..7d4744d --- /dev/null +++ b/csrc/sm100/pybind.cu @@ -0,0 +1,17 @@ +#include + +void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor o, at::Tensor lse, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + +void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &FMHACutlassSM100FwdRun); + m.def("bwd", &FMHACutlassSM100BwdRun); +} diff --git a/csrc/flash_api.cpp b/csrc/sm90/flash_api.cpp similarity index 100% rename from csrc/flash_api.cpp rename to csrc/sm90/flash_api.cpp diff --git a/csrc/kernels/config.h b/csrc/sm90/kernels/config.h similarity index 100% rename from csrc/kernels/config.h rename to csrc/sm90/kernels/config.h diff --git a/csrc/kernels/get_mla_metadata.cu b/csrc/sm90/kernels/get_mla_metadata.cu similarity index 100% rename from csrc/kernels/get_mla_metadata.cu rename to csrc/sm90/kernels/get_mla_metadata.cu diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/sm90/kernels/get_mla_metadata.h similarity index 100% rename from csrc/kernels/get_mla_metadata.h rename to csrc/sm90/kernels/get_mla_metadata.h diff --git a/csrc/kernels/mla_combine.cu b/csrc/sm90/kernels/mla_combine.cu similarity index 100% rename from csrc/kernels/mla_combine.cu rename to csrc/sm90/kernels/mla_combine.cu diff --git a/csrc/kernels/mla_combine.h b/csrc/sm90/kernels/mla_combine.h similarity index 100% rename from csrc/kernels/mla_combine.h rename to csrc/sm90/kernels/mla_combine.h diff --git a/csrc/kernels/params.h b/csrc/sm90/kernels/params.h similarity index 100% rename from csrc/kernels/params.h rename to csrc/sm90/kernels/params.h diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/sm90/kernels/splitkv_mla.cu similarity index 100% rename from csrc/kernels/splitkv_mla.cu rename to csrc/sm90/kernels/splitkv_mla.cu diff --git a/csrc/kernels/splitkv_mla.h b/csrc/sm90/kernels/splitkv_mla.h similarity index 100% rename from csrc/kernels/splitkv_mla.h rename to csrc/sm90/kernels/splitkv_mla.h diff --git a/csrc/kernels/traits.h b/csrc/sm90/kernels/traits.h similarity index 100% rename from csrc/kernels/traits.h rename to csrc/sm90/kernels/traits.h diff --git a/csrc/kernels/utils.h b/csrc/sm90/kernels/utils.h similarity index 100% rename from csrc/kernels/utils.h rename to csrc/sm90/kernels/utils.h diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index 51b8600..d0e6faf 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -3,4 +3,7 @@ from flash_mla.flash_mla_interface import ( get_mla_metadata, flash_mla_with_kvcache, + flash_attn_varlen_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, ) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 47637f8..9c669ba 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,7 +2,9 @@ import torch -import flash_mla_cuda +import flash_mla_sm90 +import flash_mla_sm100 + def get_mla_metadata( @@ -20,10 +22,10 @@ def get_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) -def flash_mla_with_kvcache( +def flash_mla_with_kvcache_sm90( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, @@ -52,7 +54,7 @@ def flash_mla_with_kvcache( """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla( q, k_cache, head_dim_v, @@ -64,3 +66,264 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_sm100.fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_sm100.bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ): + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_mla_with_kvcache_sm100( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + pass + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + capability = torch.cuda.get_device_capability(q.device.index) + if capability == (9, 0): + return flash_mla_with_kvcache_sm90( + q, k_cache, block_table, cache_seqlens, head_dim_v, + tile_scheduler_metadata, num_splits, + softmax_scale, causal, + ) + elif capability == (10, 0): + raise ValueError(f"Unsupported device capability: {capability}") + else: + raise ValueError(f"Unsupported device capability: {capability}") diff --git a/setup.py b/setup.py index 217f540..58cf7b2 100644 --- a/setup.py +++ b/setup.py @@ -27,9 +27,13 @@ def get_features_args(): subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_90a,code=sm_90a") +cc_flag_sm90 = [] +cc_flag_sm90.append("-gencode") +cc_flag_sm90.append("arch=compute_90a,code=sm_90a") + +cc_flag_sm100 = [] +cc_flag_sm100.append("-gencode") +cc_flag_sm100.append("arch=compute_100a,code=sm_100a") this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -41,12 +45,12 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla_cuda", + name="flash_mla_sm90", sources=[ - "csrc/flash_api.cpp", - "csrc/kernels/get_mla_metadata.cu", - "csrc/kernels/mla_combine.cu", - "csrc/kernels/splitkv_mla.cu", + "csrc/sm90/flash_api.cpp", + "csrc/sm90/kernels/get_mla_metadata.cu", + "csrc/sm90/kernels/mla_combine.cu", + "csrc/sm90/kernels/splitkv_mla.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), @@ -66,12 +70,49 @@ def get_features_args(): "--use_fast_math", "--ptxas-options=-v,--register-usage-level=10" ] - + cc_flag + + cc_flag_sm90 ) + get_features_args(), }, include_dirs=[ - Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "sm90", + Path(this_dir) / "csrc" / "cutlass" / "include", + ], + ) +) + +ext_modules.append( + CUDAExtension( + name="flash_mla_sm100", + sources=[ + "csrc/sm100/pybind.cu", + "csrc/sm100/fmha_cutlass_fwd_sm100.cu", + "csrc/sm100/fmha_cutlass_bwd_sm100.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-DNDEBUG", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", + ] + + cc_flag_sm100 + ), + }, + include_dirs=[ + Path(this_dir) / "csrc" / "sm100", Path(this_dir) / "csrc" / "cutlass" / "include", + Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", ], ) ) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla_sm90.py similarity index 100% rename from tests/test_flash_mla.py rename to tests/test_flash_mla_sm90.py diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py new file mode 100644 index 0000000..832c9fb --- /dev/null +++ b/tests/test_fmha_sm100.py @@ -0,0 +1,199 @@ +import random + +import torch +from torch.utils.checkpoint import checkpoint +import triton + +from flash_mla import flash_attn_varlen_func + + +def get_window_size(causal, window): + if window > 0: + window_size = (window - 1, 0) if causal else (window - 1, window - 1) + else: + window_size = (-1, -1) + return window_size + + +def get_attn_bias(s_q, s_k, causal, window): + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32) + if causal: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + if window > 0: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q - window) + attn_bias.masked_fill_(temp_mask, float("-inf")) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q + window - 1) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + return attn_bias + + +def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + RMSE = ((x - y) * (x - y)).mean().sqrt().item() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + amax_diff = (x - y).abs().max().item() + # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") + assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}" + + +def sdpa(query, key, value, attn_bias, softmax_scale=None): + key = key.repeat_interleave(h // h_k, dim=-3) + value = value.repeat_interleave(h // h_k, dim=-3) + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + attn_weight = query @ key.transpose(-2, -1) * softmax_scale + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight.to(query.dtype) @ value, lse + + +def sdpa_checkpoint(*args, **kwargs): + return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) + + +def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd): + print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}") + torch.manual_seed(0) + random.seed(0) + + seqlens_q = torch.full((b,), mean_sq, dtype=torch.int32) + seqlens_k = torch.full((b,), mean_sk, dtype=torch.int32) + + if varlen: + for i in range(b): + seqlens_q[i] = max(random.normalvariate(mean_sq, mean_sq / 2), 1) + for i in range(b): + seqlens_k[i] = max(random.normalvariate(mean_sk, mean_sk / 2), seqlens_q[i].item()) + cu_seqlens_q = torch.cumsum(torch.nn.functional.pad(seqlens_q, (1, 0)), 0, dtype=torch.int32) + cu_seqlens_k = torch.cumsum(torch.nn.functional.pad(seqlens_k, (1, 0)), 0, dtype=torch.int32) + total_q = seqlens_q.sum().item() + total_k = seqlens_k.sum().item() + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + total_attn_compute = sum([(get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), + causal, window) == 0).sum().item() for i in range(b)]) + # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") + + q = torch.randn(total_q, h, d) + k = torch.randn(total_k, h_k, d) + v = torch.randn(total_k, h_k, dv) + grad_out = torch.randn(total_q, h, dv) + softmax_scale = (d + 100) ** (-0.5) + + offst_q = total_q + offst_kv = total_k + + q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype) + k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype) + v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype) + q1_with_buffer[total_q:] = q + k1_with_buffer[offst_kv:] = k + v1_with_buffer[offst_kv:] = v + q1 = q1_with_buffer[offst_q:].requires_grad_() + k1 = k1_with_buffer[offst_kv:].requires_grad_() + v1 = v1_with_buffer[offst_kv:].requires_grad_() + + q2 = q.clone().requires_grad_() + k2 = k.clone().requires_grad_() + v2 = v.clone().requires_grad_() + + def flash_attn(): + q1.grad = k1.grad = v1.grad = None + kwargs = {} + if causal: + kwargs["causal"] = causal + if window != 0: + kwargs["window_size"] = get_window_size(causal, window) + return flash_attn_varlen_func(q1, k1, v1, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, + max_seqlen_k, softmax_scale=softmax_scale, is_varlen=varlen, **kwargs) + + def torch_attn(): + q2.grad = k2.grad = v2.grad = None + out = [] + lse = [] + for i in range(b): + OUT, LSE = sdpa_checkpoint( + q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2), + k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), + softmax_scale=softmax_scale, + ) + out.append(OUT.transpose(-3, -2)) + lse.append(LSE.transpose(-2, -1)) + out = torch.cat(out) + lse = torch.cat(lse) + return out, lse + + out_flash, lse_flash = flash_attn() + out_torch, lse_torch = torch_attn() + assert_close(out_flash, out_torch, "out") + assert_close(lse_flash, lse_torch, "lse") + + if has_bwd: + out_flash.backward(grad_out, retain_graph=True) + out_torch.backward(grad_out, retain_graph=True) + assert_close(q1.grad, q2.grad, "dq") + assert_close(k1.grad, k2.grad, "dk") + assert_close(v1.grad, v2.grad, "dv") + dq1 = q1.grad.clone() + dk1 = k1.grad.clone() + dv1 = v1.grad.clone() + + def forward(): + return flash_attn() + + def backward(): + q1.grad = k1.grad = v1.grad = None + out_flash.backward(grad_out, retain_graph=True) + + for _ in range(5): + out, lse = forward() + assert torch.equal(out, out_flash), "out deterministic check failed!" + assert torch.equal(lse, lse_flash), "lse deterministic check failed!" + if has_bwd: + backward() + # assert torch.equal(q1.grad, dq1), "dq deterministic check failed!" + assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" + assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + # forward() + # if has_bwd: + # backward() + # print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120)) + + def timer(func, name): + t = triton.testing.do_bench(func, warmup=2, rep=3) + FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOP/s, name: {name}") + return t + + timer(forward, "fwd") + if has_bwd: + timer(backward, "bwd") + + +if __name__ == "__main__": + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + + b = 4 + window = 0 + has_bwd = False + + for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: + for varlen in [False, True]: + for (h, h_k) in [(32, 32), (32, 4)]: + if h != h_k: + has_bwd = False + else: + has_bwd = True + for (d, dv) in [(128, 128), (192, 128)]: + for causal in [False, True]: + test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd)