From 01a27728e664377097d39b1ad9fc68cd6f570fc5 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Mon, 28 Apr 2025 18:53:04 +0800 Subject: [PATCH 01/20] Fix synchronization issues --- .gitignore | 1 + csrc/kernels/splitkv_mla.cu | 10 +++++++--- csrc/kernels/traits.h | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 982daef..9b500a0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ *perf.csv *.png /.vscode +compile_commands.json diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu index ff29305..5e1fded 100644 --- a/csrc/kernels/splitkv_mla.cu +++ b/csrc/kernels/splitkv_mla.cu @@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params cudaGridDependencySynchronize(); int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. + int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); // Copy the first Q launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); @@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // Issue P0 = Q @ K0^T, wait warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 + NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); cute::warpgroup_wait<0>(); #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ @@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params cute::tma_store_wait<0>(); } else { - int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + // Don't use __ldg because of PDL and instruction reordering + int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx; float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< diff --git a/csrc/kernels/traits.h b/csrc/kernels/traits.h index 31c1388..5f915a6 100644 --- a/csrc/kernels/traits.h +++ b/csrc/kernels/traits.h @@ -102,5 +102,6 @@ enum NamedBarriers : int { sScale0Ready = 0, sScale1Ready = 1, sP0Ready = 2, - rO1sP0sV0RIssued = 3 + rO1sP0sV0RIssued = 3, + sMInitialized = 4, }; From 9c5dfab6d1746b4a27af14f440e7afd5c01ece68 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 29 Apr 2025 12:02:57 +0800 Subject: [PATCH 02/20] update to cutlass 3.9 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index afa1772..e94e888 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit e94e888df3551224738bfa505787b515eae8352f From 9edee0c022cd0938148a18e334203b0aab43aa19 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 29 Apr 2025 12:03:15 +0800 Subject: [PATCH 03/20] update .gitignore --- .gitignore | 1 + setup.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 9b500a0..4535280 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.png /.vscode compile_commands.json +.cache diff --git a/setup.py b/setup.py index 131ceff..217f540 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,12 @@ IS_WINDOWS, ) + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] + def get_features_args(): features_args = [] DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] From 41b611f7d7561790a2f5040ff89212e08c7b0011 Mon Sep 17 00:00:00 2001 From: Zeyu WANG Date: Fri, 1 Aug 2025 17:21:27 +0800 Subject: [PATCH 04/20] Add more GPU architctures support (#76) * Add more GPU architctures support * Merge fmha and mla runner * add varlen & non varlen support, and add incontiguous tensor support * update readme * add varlen api --------- Co-authored-by: dianzhangc --- README.md | 12 +- csrc/sm100/collective/fmha_common.hpp | 127 ++ csrc/sm100/collective/fmha_fusion.hpp | 396 ++++ ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 234 +++ ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 1218 +++++++++++ .../sm100_fmha_load_tma_warpspecialized.hpp | 316 +++ ...a_mla_fwd_mainloop_tma_warpspecialized.hpp | 1225 +++++++++++ ...m100_fmha_mla_load_tma_warpspecialized.hpp | 340 +++ csrc/sm100/common/gather_tensor.hpp | 215 ++ csrc/sm100/common/helper.h | 72 + csrc/sm100/common/mask.cuh | 8 + csrc/sm100/common/pipeline_mla.hpp | 250 +++ csrc/sm100/common/pow_2.hpp | 92 + csrc/sm100/common/utils.hpp | 83 + csrc/sm100/device/fmha.hpp | 276 +++ csrc/sm100/device/fmha_device_bwd.hpp | 340 +++ csrc/sm100/fmha_cutlass_bwd_sm100.cu | 83 + csrc/sm100/fmha_cutlass_bwd_sm100.cuh | 200 ++ csrc/sm100/fmha_cutlass_fwd_sm100.cu | 81 + csrc/sm100/fmha_cutlass_fwd_sm100.cuh | 334 +++ .../kernel/fmha_causal_tile_scheduler.hpp | 197 ++ csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp | 153 ++ csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp | 161 ++ csrc/sm100/kernel/fmha_options.hpp | 85 + csrc/sm100/kernel/fmha_tile_scheduler.hpp | 162 ++ ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 1841 +++++++++++++++++ ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 1834 ++++++++++++++++ ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 619 ++++++ csrc/sm100/pybind.cu | 17 + csrc/{ => sm90}/flash_api.cpp | 0 csrc/{ => sm90}/kernels/config.h | 0 csrc/{ => sm90}/kernels/get_mla_metadata.cu | 0 csrc/{ => sm90}/kernels/get_mla_metadata.h | 0 csrc/{ => sm90}/kernels/mla_combine.cu | 0 csrc/{ => sm90}/kernels/mla_combine.h | 0 csrc/{ => sm90}/kernels/params.h | 0 csrc/{ => sm90}/kernels/splitkv_mla.cu | 0 csrc/{ => sm90}/kernels/splitkv_mla.h | 0 csrc/{ => sm90}/kernels/traits.h | 0 csrc/{ => sm90}/kernels/utils.h | 0 flash_mla/__init__.py | 3 + flash_mla/flash_mla_interface.py | 271 ++- setup.py | 61 +- ...st_flash_mla.py => test_flash_mla_sm90.py} | 0 tests/test_fmha_sm100.py | 199 ++ 45 files changed, 11489 insertions(+), 16 deletions(-) create mode 100644 csrc/sm100/collective/fmha_common.hpp create mode 100644 csrc/sm100/collective/fmha_fusion.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp create mode 100644 csrc/sm100/common/gather_tensor.hpp create mode 100644 csrc/sm100/common/helper.h create mode 100644 csrc/sm100/common/mask.cuh create mode 100644 csrc/sm100/common/pipeline_mla.hpp create mode 100644 csrc/sm100/common/pow_2.hpp create mode 100644 csrc/sm100/common/utils.hpp create mode 100644 csrc/sm100/device/fmha.hpp create mode 100644 csrc/sm100/device/fmha_device_bwd.hpp create mode 100644 csrc/sm100/fmha_cutlass_bwd_sm100.cu create mode 100644 csrc/sm100/fmha_cutlass_bwd_sm100.cuh create mode 100644 csrc/sm100/fmha_cutlass_fwd_sm100.cu create mode 100644 csrc/sm100/fmha_cutlass_fwd_sm100.cuh create mode 100644 csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp create mode 100644 csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp create mode 100644 csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp create mode 100644 csrc/sm100/kernel/fmha_options.hpp create mode 100644 csrc/sm100/kernel/fmha_tile_scheduler.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/pybind.cu rename csrc/{ => sm90}/flash_api.cpp (100%) rename csrc/{ => sm90}/kernels/config.h (100%) rename csrc/{ => sm90}/kernels/get_mla_metadata.cu (100%) rename csrc/{ => sm90}/kernels/get_mla_metadata.h (100%) rename csrc/{ => sm90}/kernels/mla_combine.cu (100%) rename csrc/{ => sm90}/kernels/mla_combine.h (100%) rename csrc/{ => sm90}/kernels/params.h (100%) rename csrc/{ => sm90}/kernels/splitkv_mla.cu (100%) rename csrc/{ => sm90}/kernels/splitkv_mla.h (100%) rename csrc/{ => sm90}/kernels/traits.h (100%) rename csrc/{ => sm90}/kernels/utils.h (100%) rename tests/{test_flash_mla.py => test_flash_mla_sm90.py} (100%) create mode 100644 tests/test_fmha_sm100.py diff --git a/README.md b/README.md index 5d66f55..07e021a 100644 --- a/README.md +++ b/README.md @@ -28,13 +28,21 @@ Currently released: ### Install ```bash -python setup.py install +pip install -v . ``` ### Benchmark +#### Testing MLA Decoding + +```bash +python tests/test_flash_mla_sm90.py +``` + +#### Testing MLA Forward/Backward + ```bash -python tests/test_flash_mla.py +python tests/test_fmha_sm100.py ``` It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. diff --git a/csrc/sm100/collective/fmha_common.hpp b/csrc/sm100/collective/fmha_common.hpp new file mode 100644 index 0000000..c60d9e9 --- /dev/null +++ b/csrc/sm100/collective/fmha_common.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +template +CUTLASS_DEVICE +void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } + else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/collective/fmha_fusion.hpp new file mode 100644 index 0000000..1486767 --- /dev/null +++ b/csrc/sm100/collective/fmha_fusion.hpp @@ -0,0 +1,396 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + return; + } +}; + +struct ResidualMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct ResidualMaskForBackward : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (! elem_less(pos, select<0,1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +// There are two ways to do causal if N_Q != N_K +// (1) The Q is at the beginning of the matrix +// (2) The Q is at the end of the matrix +template +struct CausalMask : NoMask { + + using Base = NoMask; + + static constexpr bool IsQBegin = kIsQBegin; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } else { + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + } + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is the default setting. + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to set kIsQBegin=false + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } + } +}; + +template +struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { + + using Base = CausalMask; + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + int offset_q = 0; + if constexpr (!kIsQBegin) { + offset_q = get<1>(problem_size) - get<0>(problem_size); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size); + if (masked) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +struct VariableLength { + int max_length; + int* cumulative_length = nullptr; + int total_length = -1; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template struct is_variable_length_impl : std::false_type {}; +template<> struct is_variable_length_impl : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length_impl>::value; + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } + else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length_offset(Shape const& shape, Coord const& coord) { + auto idx = back(back(coord)); + auto result_shape = transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); + auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx]; + } + else { + return _0{}; + } + }); + return cute::make_tuple(result_shape, result_offset); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template<> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000..616357c --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element, + class ElementAcc, + class TileShape, // Q, D, _ + class StrideO, // Q, D, B + class StrideLSE_, // Q, B + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + +// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>()); +// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + using StrideLSE = StrideLSE_; + using ElementOut = Element; + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_O; + StrideO dO; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_,_,_0{}) + )); + + + struct Params { + TMA_O tma_store_o; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + // FMHA and MLA have different input ProblemShapes; + // get problem_shape_O according to the input ProblemShape. + template + CUTLASS_DEVICE static constexpr + auto get_problem_shape_O ( + ProblemShape const& problem_shape) { + if constexpr (rank_v(ProblemShape{}))> == 2) { + return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape)); + } else { + return select<0,2,3>(problem_shape); + } + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + + auto ptr_O = args.ptr_O; + StrideO dO = args.dO; + + auto problem_shape_O = get_problem_shape_O(problem_shape); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dO) = get<0>(dO); + get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + // offset ptr by the amount we add back in later + ptr_O -= max_length_q * get<0>(dO); + } + } + + auto tma_store_o = make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(ptr_O, problem_shape_O, dO), + SmemLayoutO{}(_,_,_0{}) + ); + + return { + tma_store_o, + args.ptr_LSE, + args.dLSE + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + + BlkCoord blk_coord = blk_coord_in; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape)); + // offset mode 0 by (max_length - real_length) + // offset mode 3,1 by cumulative_length + real_length + // the ptr is already offset by - max_length + // so in total this achieves + int offs_0 = 0; + int offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + offs_0 = max_length_q - get<0>(problem_shape); + offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape); + get<2,1>(blk_coord) = 0; + } + } + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord)); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index)); + } + tma_store_arrive(); + + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000..f39fd75 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class TileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory. + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = TileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + // Reuse shared memory for V and O. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaUmmaAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000..1951056 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,316 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape +> +struct Sm100FmhaLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = problem_shape; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000..bf41af9 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1225 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "common/pipeline_mla.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ComposedTileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ComposedTileShape = ComposedTileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountK = 1; + static constexpr int StageCountV = 1; + static constexpr int StageCountKV = StageCountK + StageCountV; + // Support StageCountKV > 2 in the future. + static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); + static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{}); + static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{}); + static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope; + static constexpr auto HeadDimPV = HeadDimLatent; + + using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{})); + using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{}))); + using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent)); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); + + // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, + // we reuse shared memory for V and O to address this problem, + // and a barrier has been added to coordinate access to shared memory. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_o; // use as O0 + cute::array_aligned> smem_v; // use as V0 and O1 + }; + + struct TensorStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + + using TensorStorage = std::conditional_t; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaAsyncMla< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + using Load = Sm100MlaFwdLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + bool need_apply_mask, + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_mask) { + if(need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + constexpr int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int trip_idx = total_trip_count; + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + constexpr bool NeedMask = !std::is_same_v; + + CUTLASS_PRAGMA_NO_UNROLL + for (; trip_idx > 0; trip_idx -= 1) { + softmax_step( + trip_idx <= mask_trip_count, + row_max, row_sum, stage, + trip_idx == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp new file mode 100644 index 0000000..c2d3e2b --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape)); + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + // V1 + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch vi + cute::prefetch(params.tma_load_v, tVgV(_, k_index)); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch ki+1 + if(mask_tile_count > 1) { + cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1)); + } + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/common/gather_tensor.hpp b/csrc/sm100/common/gather_tensor.hpp new file mode 100644 index 0000000..46fb640 --- /dev/null +++ b/csrc/sm100/common/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/csrc/sm100/common/helper.h b/csrc/sm100/common/helper.h new file mode 100644 index 0000000..e957c4e --- /dev/null +++ b/csrc/sm100/common/helper.h @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once + + #include "cuda_runtime.h" + #include + + /** + * Panic wrapper for unwinding CUTLASS errors + */ + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + + /** + * Panic wrapper for unwinding CUDA runtime errors + */ + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +#define FLASH_MLA_ASSERT(cond) \ +do { \ + if (!(cond)) { \ + std::cerr << "FLASH_MLA_ASSERT: " << #cond << " failed at " << __FILE__ << ":" << __LINE__ << std::endl; \ + std::abort(); \ + } \ +} while (0) + + \ No newline at end of file diff --git a/csrc/sm100/common/mask.cuh b/csrc/sm100/common/mask.cuh new file mode 100644 index 0000000..d118aab --- /dev/null +++ b/csrc/sm100/common/mask.cuh @@ -0,0 +1,8 @@ +#pragma once + +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask +}; + diff --git a/csrc/sm100/common/pipeline_mla.hpp b/csrc/sm100/common/pipeline_mla.hpp new file mode 100644 index 0000000..5bbeed9 --- /dev/null +++ b/csrc/sm100/common/pipeline_mla.hpp @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Support the producer to acquire specific bytes of data. +*/ + +#pragma once + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; + +template < + int Stages_, + class ClusterShape = Shape, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaAsyncMla { + +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; + +private: + using Impl = PipelineTmaUmmaAsync; + +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? + cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas + cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + else { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + +public: + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); + if (barrier_token != BarrierStatus::WaitDone) { + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + full_barrier_ptr_[stage].arrive_and_expect_tx(bytes); + } + #ifndef NDEBUG + if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + + // Most likely you have elected more than one leader + if (params_.is_leader && (threadIdx.x % 32 != 0)) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + CUTLASS_DEVICE + void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index(), false); + } + +private: + Impl impl_; + Params params_; + EmptyBarrier *empty_barrier_ptr_; + FullBarrier *full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } + else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } + else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; + +} diff --git a/csrc/sm100/common/pow_2.hpp b/csrc/sm100/common/pow_2.hpp new file mode 100644 index 0000000..eca9325 --- /dev/null +++ b/csrc/sm100/common/pow_2.hpp @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp new file mode 100644 index 0000000..f43770d --- /dev/null +++ b/csrc/sm100/common/utils.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include "cutlass/numeric_types.h" +#include "helper.h" + +template +struct cutlass_dtype { + using type = T; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::half_t; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::bfloat16_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; + +template +using cutlass_dtype_t = typename cutlass_dtype::type; + +template +struct DeviceAllocation { + T* ptr_ = nullptr; + size_t offset_ = 0; + size_t size_ = 0; + + DeviceAllocation(DeviceAllocation const&) = delete; + DeviceAllocation& operator=(DeviceAllocation const&) = delete; + + DeviceAllocation() = default; + DeviceAllocation(size_t size) { reset(size); } + ~DeviceAllocation() { reset(); } + + void reset(size_t size, size_t offset=0) { + reset(); + auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); + assert(ret == cudaSuccess); + size_ = size; + offset_ = offset; + } + + T* get() { + return ptr_ + offset_; + } + + const T* get() const { + return ptr_ + offset_; + } + + void reset() { + if (ptr_ != nullptr) { + auto ret = cudaFree(ptr_); + assert(ret == cudaSuccess); + } + } + + size_t size() const { return size_; } + + size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } + + void copy_from_host(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } + + void copy_from_device(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } +}; \ No newline at end of file diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/device/fmha.hpp new file mode 100644 index 0000000..f8406d3 --- /dev/null +++ b/csrc/sm100/device/fmha.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/device/fmha_device_bwd.hpp new file mode 100644 index 0000000..d2463ac --- /dev/null +++ b/csrc/sm100/device/fmha_device_bwd.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class Element, + class ElementAccumulator, + class TileShape, + bool IsMla, + class Mask +> +class Sm100FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + // Q K D D_VO HB + ProblemShape problem_shape; + + const Element* ptr_Q; + cute::tuple> stride_Q; + const Element* ptr_K; + cute::tuple> stride_K; + const Element* ptr_V; + cute::tuple> stride_V; + + const Element* ptr_O; + cute::tuple> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple> stride_LSE; + + const Element* ptr_dO; + cute::tuple> stride_dO; + + Element* ptr_dQ; + cute::tuple> stride_dQ; + Element* ptr_dK; + cute::tuple> stride_dK; + Element* ptr_dV; + cute::tuple> stride_dV; + + ElementAccumulator softmax_scale; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using OperationMha= cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using OperationMla = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using Operation = std::conditional_t; + + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo = nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_shape, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + return typename OperationConvert::Arguments { + args.problem_shape, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { + + return typename Operation::Arguments{ + args.problem_shape, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + scaled_lse, stride_scaled_lse, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ, + args.softmax_scale }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // scaled LSE vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments( + args, sum_OdO, args_sum_OdO.stride_sum_OdO, + scaled_lse, args_sum_OdO.stride_scaled_lse, + dQ_acc, args_convert.stride_src_dQ + ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cu b/csrc/sm100/fmha_cutlass_bwd_sm100.cu new file mode 100644 index 0000000..4ff745d --- /dev/null +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cu @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include +#include "common/mask.cuh" +#include "common/utils.hpp" + +#include "fmha_cutlass_bwd_sm100.cuh" + +template +void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, + [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla, + at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { + static constexpr bool IsVarlen = std::is_same_v; + static constexpr bool IsMla = std::is_same_v; + using TileShape = std::conditional_t, Shape<_128, _128, _128, _128>>; + run_fmha_bwd(workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, total_seqlen_kv); +} + + +void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) { + + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + + int head_dim_qk = q.size(-1); + int head_dim_vo = v.size(-1); + MaskMode mask_mode = static_cast(mask_mode_code); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + + if(scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { + using Element = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + auto apply_config = [&](auto fn) { + if (mask_mode == MaskMode::kCausal) { + if(is_varlen) { + fn(CausalForBackwardMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(CausalForBackwardMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + else { + if(is_varlen) { + fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + }; + + apply_config([&](auto mask, auto varlen, auto in, auto out) { + if (head_dim_qk == 192 && head_dim_vo == 128) { + call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, max_seqlen_kv); + } else if (head_dim_qk == 128 && head_dim_vo == 128) { + call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, max_seqlen_kv); } + else { + std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; + } + }); + + } else { + FLASH_MLA_ASSERT(false); + } +} diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh new file mode 100644 index 0000000..2b19be2 --- /dev/null +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include "common/utils.hpp" +#include "collective/fmha_fusion.hpp" +#include "device/fmha_device_bwd.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; +using namespace cutlass; + + +template< + class DType, + bool kIsVarlen, + bool kIsMla, + class TileShape, + class ActiveMask +> +struct BwdRunner { + + using Element = DType; + using ElementAccumulator = float; + + // Q K D D_VO (H B) + using ProblemShape = std::conditional_t< + kIsVarlen, + cute::tuple>, + cute::tuple> + >; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; + + using TensorStride = Stride>; + using StrideQ = TensorStride; // Seq DQK (H B) + using StrideK = TensorStride; // Seq DQK (H B) + using StrideV = TensorStride; // Seq DVO (H B) + using StrideO = TensorStride; // Seq DVO (H B) + using StrideLSE = Stride<_1, Stride>; // Seq (H B) + + // Backwards specific + using StrideDQ = TensorStride; + using StrideDK = TensorStride; // Seq DQK (H B) + using StrideDV = TensorStride; // Seq DVO (H B) + using StrideDO = TensorStride; + + static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + ProblemShape problem_shape; + cute::tuple> tensor_shape; + + + int d = q.size(-1); + int d_vo = v.size(-1); + int batch_size = cumulative_seqlen_q.size(0) - 1; + int num_qo_heads = q.size(1); + int total_seqlen_q = q.size(0); + int total_seqlen_kv = k.size(0); + + //varlen: q: [Q, H, D] + //fixedlen: q: [B, H, Q, D] + if constexpr (kIsVarlen) { + problem_shape = cute::make_tuple( + VariableLength{max_seqlen_q, static_cast(cumulative_seqlen_q.data_ptr()), total_seqlen_q}, + VariableLength{max_seqlen_kv, static_cast(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv}, + d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); + tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1)); + } else { + int q_len = total_seqlen_q / batch_size; + int kv_len = total_seqlen_kv / batch_size; + problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); + tensor_shape = problem_shape; + } + + auto [Q, K, D, D_VO, HB] = tensor_shape; + auto [H, B] = HB; + + int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); + int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); + int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); + int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); + int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); + int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2); + int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2); + int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2); + int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2); + TORCH_CHECK(q_stride2 == 1); + TORCH_CHECK(k_stride2 == 1); + TORCH_CHECK(v_stride2 == 1); + TORCH_CHECK(o_stride2 == 1); + TORCH_CHECK(lse_stride0 == 1); + TORCH_CHECK(dq_stride2 == 1); + TORCH_CHECK(dk_stride2 == 1); + TORCH_CHECK(dv_stride2 == 1); + TORCH_CHECK(do_stride2 == 1); + + StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 0 : q_stride0*Q)); + StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 0 : k_stride0*K)); + StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K)); + StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 0 : o_stride0*Q)); + StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q)); + + StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q)); + StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K)); + StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K)); + StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q)); + + typename Operation::Arguments arguments{ + problem_shape, + (static_cast(q.data_ptr())), stride_Q, + (static_cast(k.data_ptr())), stride_K, + (static_cast(v.data_ptr())), stride_V, + (static_cast(o.data_ptr())), stride_O, + (static_cast(lse.data_ptr())), stride_LSE, + (static_cast(d_o.data_ptr())), stride_dO, + (static_cast(dq.data_ptr())), stride_dQ, + (static_cast(dk.data_ptr())), stride_dK, + (static_cast(dv.data_ptr())), stride_dV, + static_cast(softmax_scale), + hw_info + }; + + Operation op; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + DeviceAllocation workspace(workspace_size); + uint8_t* workspace_ptr = workspace.get(); + + CUTLASS_CHECK(op.can_implement(arguments)); + CUTLASS_CHECK(op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); + } + +}; + + +template +void run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { + BwdRunner::run(workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, total_seqlen_kv); +} diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/fmha_cutlass_fwd_sm100.cu new file mode 100644 index 0000000..e322709 --- /dev/null +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cu @@ -0,0 +1,81 @@ +#include "common/mask.cuh" +#include "common/utils.hpp" +#include "fmha_cutlass_fwd_sm100.cuh" + +#include +#include +#include +#include +#include + +template +void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, + [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, + [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q, + at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, + float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + static constexpr bool IsVarlen = std::is_same_v; + static constexpr bool IsMla = std::is_same_v; + static constexpr bool IsCausalMask = std::is_same_v>; + using Option = std::conditional_t, + Option>; + + run_fmha_fwd( + workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, + softmax_scale, max_seqlen_q, max_seqlen_kv); +} + +void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, + int mask_mode_code, float sm_scale, int max_seqlen_q, + int max_seqlen_kv, bool is_varlen) { + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + CHECK(q.scalar_type() == k.scalar_type()); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + int head_dim_qk = q.size(-1); + int head_dim_vo = v.size(-1); + MaskMode mask_mode = static_cast(mask_mode_code); + + if (scalar_type_in == at::ScalarType::BFloat16 && + scalar_type_out == at::ScalarType::BFloat16) { + using Element = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + auto apply_config = [&](auto fn) { + if (mask_mode == MaskMode::kCausal) { + if (is_varlen) { + fn(CausalMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(CausalMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } else { + if (is_varlen) { + fn(ResidualMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + }; + + apply_config([&](auto mask, auto varlen, auto in, auto out) { + if (head_dim_qk == 192 && head_dim_vo == 128) { + call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v, + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, + max_seqlen_q, max_seqlen_kv); + } else if (head_dim_qk == 128 && head_dim_vo == 128) { + call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v, + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, + max_seqlen_q, max_seqlen_kv); + } else { + std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk + << " head_dim_vo=" << head_dim_vo << std::endl; + } + }); + + } else { + FLASH_MLA_ASSERT(false); + } +} diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh new file mode 100644 index 0000000..71831bb --- /dev/null +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh @@ -0,0 +1,334 @@ +#pragma once + +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" +#include "device/fmha.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp" + +#include +#include + +using namespace cute; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::device; + +struct FmhaOptions { + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int d = 128; +}; + +struct MlaOptions { + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int dl = 128; // headdim latent + int dr = 64; // headdim rope +}; + +template +struct FwdRunner { + + using Element = Element_; + using ElementAccumulatorQK = float; + using ElementAccumulatorPV = float; + using ElementOut = ElementOut_; + + using HeadDimLatent = _128; + using HeadDim = Shape; + using TileShapeMla = Shape<_256, _128, HeadDim>; + using TileShapeFmha = Shape<_256, _128, _128>; + using TileShape = std::conditional_t; + + using ProblemShapeRegular = std::conditional_t< + kIsMla, + cute::tuple, cute::tuple, int>>, + cute::tuple, int>>>; + + using ProblemShapeVarlen = + std::conditional_t, + cute::tuple, int>>, + cute::tuple, int>>>; + + using ProblemShapeType = + std::conditional_t; + + using StrideQ = cute::tuple, int>>; + using StrideK = cute::tuple, int>>; + using StrideV = StrideK; + using StrideO = StrideQ; + using StrideLSE = cute::tuple<_1, cute::tuple, int>>; + + static constexpr bool kIsPersistent = + find_option_t::value; + + using TileScheduler = std::conditional_t< + kIsPersistent, + std::conditional_t> || + std::is_same_v>, + cutlass::fmha::kernel::CausalPersistentTileScheduler, + cutlass::fmha::kernel::PersistentTileScheduler>, + std::conditional_t>; + + static constexpr bool IsOrderLoadEpilogue = + kIsPersistent && (sizeof(Element) == sizeof(ElementOut)); + using OrderLoadEpilogue = std::conditional_t; + + using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK, + StrideV, ActiveMask, Shape<_2, _1, _1>, OrderLoadEpilogue>; + + using OperationMla = + cutlass::fmha::device::FMHA, + TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>; + + using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK, + StrideV, ActiveMask>; + + using OperationFmha = + cutlass::fmha::device::FMHA, + TileScheduler>>; + + using Mainloop = std::conditional_t; + using Operation = std::conditional_t; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + + template + auto initialize_varlen(const ProblemShape &problem_size, int max_seqlen_q, int max_seqlen_kv, + int total_seqlen_q, int total_seqlen_kv) { + + int num_batches = get<3, 1>(problem_size); + + ProblemShape problem_size_for_init = problem_size; + get<3, 1>(problem_size_for_init) = 1; + get<0>(problem_size_for_init) = total_seqlen_q; + get<1>(problem_size_for_init) = total_seqlen_kv; + + ProblemShapeType problem_size_for_launch; + + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<2>(problem_size_for_launch) = get<2>(problem_size); + get<3>(problem_size_for_launch) = get<3>(problem_size); + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + template + static constexpr auto get_problem_shape(const Options &options) { + int h_r = options.h / options.h_k; + if constexpr (std::is_same_v) { + return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr), + cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + } else { + return cute::make_tuple(options.q, options.k, options.d, + cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + } + } + + template + ProblemShapeType initialize(const Options &options, int max_seqlen_q, int max_seqlen_kv, + int total_seqlen_q, int total_seqlen_kv, + void *cumulative_length_q, void *cumulative_length_kv) { + assert(options.h % options.h_k == 0); + auto problem_shape_in = get_problem_shape(options); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (kIsVarlen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen( + problem_shape_in, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + auto get_head_dimension = [&]() { + if constexpr (rank_v(problem_shape))> == 2) { + return cute::make_tuple(size<2, 0>(problem_shape) + size<2, 1>(problem_shape), + size<2, 0>(problem_shape)); + } else { + return cute::make_tuple(size<2>(problem_size), size<2>(problem_size)); + } + }; + + + if constexpr (kIsVarlen) { + get<0>(problem_shape).cumulative_length = static_cast(cumulative_length_q); + get<1>(problem_shape).cumulative_length = static_cast(cumulative_length_kv); + } + + return problem_shape; + } + + auto get_arguments(const ProblemShapeType &problem_shape, + const cutlass::KernelHardwareInfo &hw_info, float scale_softmax, + void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, + void *cumulative_length_q, void *cumulative_length_kv) { + auto problem_shape_ = problem_shape; + if constexpr (kIsVarlen) { + get<0>(problem_shape_).cumulative_length = static_cast(cumulative_length_q); + get<1>(problem_shape_).cumulative_length = static_cast(cumulative_length_kv); + } + + typename Operation::Arguments arguments{ + problem_shape_, + {static_cast(q_ptr), stride_Q, static_cast(k_ptr), stride_K, + static_cast(v_ptr), stride_V, scale_softmax}, + {static_cast(o_ptr), stride_O, + static_cast(lse_ptr), stride_LSE}, + hw_info}; + + return arguments; + } + + template + void run(const Options &options, const cutlass::KernelHardwareInfo &hw_info, at::Tensor q, + at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, float scale_softmax, + at::Tensor workspace, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) { + + int total_seqlen_q = q.size(0); + int total_seqlen_kv = k.size(0); + ProblemShapeType problem_shape = + initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, + cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); + + int SQ = size<0>(problem_shape); + int SK = size<1>(problem_shape); + int B = size<3, 1>(problem_shape); + int H = size<3, 0>(problem_shape); + int H_K = size<3, 0, 1>(problem_shape); + int H_Q = size<3, 0, 0>(problem_shape); + + int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); + int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); + int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); + int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); + int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); + TORCH_CHECK(q_stride2 == 1); + TORCH_CHECK(k_stride2 == 1); + TORCH_CHECK(v_stride2 == 1); + TORCH_CHECK(o_stride2 == 1); + TORCH_CHECK(lse_stride0 == 1); + + stride_Q = make_stride(q_stride0, _1{}, make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0)); + stride_O = make_stride(o_stride0, _1{}, make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0)); + stride_K = make_stride(k_stride0, _1{}, make_stride(make_stride(_0{}, k_stride1), SK * k_stride0)); + stride_V = make_stride(v_stride0, _1{}, make_stride(make_stride(_0{}, v_stride1), SK * v_stride0)); + stride_LSE = make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ)); + + if constexpr (kIsVarlen) { + get<2, 1>(stride_Q) = 0; + get<2, 1>(stride_K) = 0; + get<2, 1>(stride_V) = 0; + get<2, 1>(stride_O) = 0; + get<1, 1>(stride_LSE) = 0; + } + + typename Operation::Arguments arguments = + get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(), + v.data_ptr(), o.data_ptr(), lse.data_ptr(), + cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); + + Operation op; + + // size_t workspace_size = 0; + // workspace_size = Operation::get_workspace_size(arguments); + + // todo: if use workspace, need check workspace size first. + // we don't use workspace in current version. + + CUTLASS_CHECK(op.can_implement(arguments)); + CUTLASS_CHECK(op.initialize(arguments, nullptr)); + CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); + } +}; + +template +void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, + at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) { + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + auto get_options = [&]() { + if constexpr (kIsMla) { + MlaOptions options; + options.b = cumulative_seqlen_q.size(0) - 1; + options.h = q.size(1); + options.h_k = k.size(1); + options.q = q.size(0) / options.b; + options.k = k.size(0) / options.b; + options.dl = v.size(-1); + options.dr = q.size(-1) - v.size(-1); + return options; + } else { + FmhaOptions options; + options.b = cumulative_seqlen_q.size(0) - 1; + options.h = q.size(1); + options.h_k = k.size(1); + options.q = q.size(0) / options.b; + options.k = k.size(0) / options.b; + options.d = q.size(-1); + return options; + } + }; + + auto options = get_options(); + + if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && + (!std::is_same_v)) { + FwdRunner runner; + runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, + cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); + } else { + FwdRunner runner; + runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, + cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); + } +} diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp new file mode 100644 index 0000000..572e67f --- /dev/null +++ b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +// Swizzle Q tile and H tile to improve L2 cache hit rate, +// and launch the longest main loop first to keep most SMs busy. + +struct CausalIndividualTileScheduler { + + static constexpr int TileQ = 16; + static constexpr int TileH = 8; + static constexpr int TileSize = TileQ * TileH; + + struct Params { + dim3 grid; + int tile_max_q; + FastDivmod divmod_tile_col; + FastDivmod divmod_tile_size; + FastDivmod divmod_tile_head; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + CausalIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + + dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size)); + // gridDim.x must multiple of TileH + const int tile_col_count = grid.x / TileH; + const int tile_max_q = grid.y / TileQ * TileQ; + return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH}; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + const int block_idx = blockIdx.y * gridDim.x + blockIdx.x; + + int tile_idx, tile_tail; + params.divmod_tile_size(tile_idx, tile_tail, block_idx); + + int tile_row_idx, tile_col_idx; + params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx); + + int row_offset_in_tail, col_offset_in_tail; + params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail); + + const int row_idx = tile_row_idx * TileQ + row_offset_in_tail; + const int col_idx = tile_col_idx * TileH + col_offset_in_tail; + + // last q tile launch first + if(blockIdx.y >= params.tile_max_q) { + return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z))); + } + + return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z))); + } + + CUTLASS_DEVICE + CausalIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +// Launch order: H Q B +struct CausalPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_h; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_h(block_decode, bidh, block_decode); + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + CausalPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000..32e007c --- /dev/null +++ b/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + ProblemShape problem_shape; + + const ElementAcc* ptr_src_dQ; + tuple> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple> stride_src_dV; + + Element* ptr_dest_dQ; + tuple> stride_dest_dQ; + Element* ptr_dest_dK; + tuple> stride_dest_dK; + Element* ptr_dest_dV; + tuple> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { + auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + int seqlen = count; + if constexpr (is_variable_length_v) { + int offset = count.cumulative_length[blockIdx.y]; + ptr_dest_bh += offset * get<0>(stride_dest); + seqlen = count.cumulative_length[blockIdx.y + 1] - offset; + } + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= seqlen) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000..bdcf1cb --- /dev/null +++ b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + ProblemShape problem_shape; + + const Element* ptr_O; + cute::tuple> stride_O; + const Element* ptr_dO; + cute::tuple> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + auto problem_q = get<0>(params.problem_shape); + int seqlen_q = problem_q; + if constexpr (is_variable_length_v) { + int offset = problem_q.cumulative_length[blockIdx.z]; + ptr_O_bh += offset * get<0>(params.stride_O); + ptr_dO_bh += offset * get<0>(params.stride_dO); + ptr_lse_bh += offset * get<0>(params.stride_lse); + seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; + } + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= seqlen_q) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_options.hpp b/csrc/sm100/kernel/fmha_options.hpp new file mode 100644 index 0000000..d4faa8d --- /dev/null +++ b/csrc/sm100/kernel/fmha_options.hpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000..119f069 --- /dev/null +++ b/csrc/sm100/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_h; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..59b410b --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1841 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D_VO, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = []() { + if constexpr (sizeof(Element) == 1) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else { + return SM100_TMEM_STORE_32dp32b8x{}; + } + }(); + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + + } + }; + + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST_p = thread_r2t.partition_S(tDVcST); + auto tRT_cST = split_wg(tRT_cST_p); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..5a58157 --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1834 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + using TileShapeK = decltype(get<1>(TileShape{})); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<3>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + 65536 * 16; + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + + static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; + static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp"); + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + + // compute dP + using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDOV = typename CollectiveMmaDOV::TileShape; + using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); + using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + union{ + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array> smem_p_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaQK::to_underlying_arguments( + make_shape(Q, K, D, HB), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q, args.mainloop.stride_q, + args.mainloop.ptr_k, args.mainloop.stride_k, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( + make_shape(Q, K, D_VO, HB), + typename CollectiveMmaDOV::Arguments { + args.mainloop.ptr_do, args.mainloop.stride_do, + args.mainloop.ptr_v, args.mainloop.stride_v, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_b, + params_vdo.tma_load_b, + params_kq.tma_load_a, + params_vdo.tma_load_a, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); + auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_B(gK); + auto tSTgQ = cta_mma_kq.partition_A(gQ); + auto tDPTgV = cta_mma_vdo.partition_B(gV); + auto tDPTgDO = cta_mma_vdo.partition_A(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading kLoadPerThread * 32 values of 32b each + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); + Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ); + + Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); + Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaQK tiled_mma_qk; + TiledMmaDOV tiled_mma_dov; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); + Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } + }; + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) + (_, _, _, pipeline_compute_mma_p_producer_state.index()); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + auto sP_pi = as_position_independent_swizzle_tensor(sP); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); + auto sP_pi_slice = split_wg(sP_pi_slice_p); + copy_aligned(tRT_rST, sP_pi_slice); + }); + + // notify for P + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..8fe503b --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,619 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + + +struct Sm100MlaFwdCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 184; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule +> +struct Sm100FmhaFwdKernelTmaWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue); + static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad); + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + static constexpr bool IsMla = std::is_same_v; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + using UnionType = union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using StructType = struct { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + static constexpr bool IsPersistent = std::is_same_v || std::is_same_v; + using MainloopEpilogueStorage = std::conditional_t, + StructType>, + UnionType>; + + MainloopEpilogueStorage mainloop_epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + auto get_epilogue_storage = [&]() { + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + return reinterpret_cast(shared_storage.mainloop_epilogue.mainloop.smem_o.data()); + } else { + return &shared_storage.mainloop_epilogue.epilogue; + } + }; + typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage(); + + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue{params.epilogue}; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + mainloop.correction_empty( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/pybind.cu b/csrc/sm100/pybind.cu new file mode 100644 index 0000000..7d4744d --- /dev/null +++ b/csrc/sm100/pybind.cu @@ -0,0 +1,17 @@ +#include + +void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor o, at::Tensor lse, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + +void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &FMHACutlassSM100FwdRun); + m.def("bwd", &FMHACutlassSM100BwdRun); +} diff --git a/csrc/flash_api.cpp b/csrc/sm90/flash_api.cpp similarity index 100% rename from csrc/flash_api.cpp rename to csrc/sm90/flash_api.cpp diff --git a/csrc/kernels/config.h b/csrc/sm90/kernels/config.h similarity index 100% rename from csrc/kernels/config.h rename to csrc/sm90/kernels/config.h diff --git a/csrc/kernels/get_mla_metadata.cu b/csrc/sm90/kernels/get_mla_metadata.cu similarity index 100% rename from csrc/kernels/get_mla_metadata.cu rename to csrc/sm90/kernels/get_mla_metadata.cu diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/sm90/kernels/get_mla_metadata.h similarity index 100% rename from csrc/kernels/get_mla_metadata.h rename to csrc/sm90/kernels/get_mla_metadata.h diff --git a/csrc/kernels/mla_combine.cu b/csrc/sm90/kernels/mla_combine.cu similarity index 100% rename from csrc/kernels/mla_combine.cu rename to csrc/sm90/kernels/mla_combine.cu diff --git a/csrc/kernels/mla_combine.h b/csrc/sm90/kernels/mla_combine.h similarity index 100% rename from csrc/kernels/mla_combine.h rename to csrc/sm90/kernels/mla_combine.h diff --git a/csrc/kernels/params.h b/csrc/sm90/kernels/params.h similarity index 100% rename from csrc/kernels/params.h rename to csrc/sm90/kernels/params.h diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/sm90/kernels/splitkv_mla.cu similarity index 100% rename from csrc/kernels/splitkv_mla.cu rename to csrc/sm90/kernels/splitkv_mla.cu diff --git a/csrc/kernels/splitkv_mla.h b/csrc/sm90/kernels/splitkv_mla.h similarity index 100% rename from csrc/kernels/splitkv_mla.h rename to csrc/sm90/kernels/splitkv_mla.h diff --git a/csrc/kernels/traits.h b/csrc/sm90/kernels/traits.h similarity index 100% rename from csrc/kernels/traits.h rename to csrc/sm90/kernels/traits.h diff --git a/csrc/kernels/utils.h b/csrc/sm90/kernels/utils.h similarity index 100% rename from csrc/kernels/utils.h rename to csrc/sm90/kernels/utils.h diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index 51b8600..d0e6faf 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -3,4 +3,7 @@ from flash_mla.flash_mla_interface import ( get_mla_metadata, flash_mla_with_kvcache, + flash_attn_varlen_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, ) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 47637f8..9c669ba 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,7 +2,9 @@ import torch -import flash_mla_cuda +import flash_mla_sm90 +import flash_mla_sm100 + def get_mla_metadata( @@ -20,10 +22,10 @@ def get_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) -def flash_mla_with_kvcache( +def flash_mla_with_kvcache_sm90( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, @@ -52,7 +54,7 @@ def flash_mla_with_kvcache( """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla( q, k_cache, head_dim_v, @@ -64,3 +66,264 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_sm100.fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_sm100.bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ): + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_mla_with_kvcache_sm100( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + pass + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + capability = torch.cuda.get_device_capability(q.device.index) + if capability == (9, 0): + return flash_mla_with_kvcache_sm90( + q, k_cache, block_table, cache_seqlens, head_dim_v, + tile_scheduler_metadata, num_splits, + softmax_scale, causal, + ) + elif capability == (10, 0): + raise ValueError(f"Unsupported device capability: {capability}") + else: + raise ValueError(f"Unsupported device capability: {capability}") diff --git a/setup.py b/setup.py index 217f540..58cf7b2 100644 --- a/setup.py +++ b/setup.py @@ -27,9 +27,13 @@ def get_features_args(): subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_90a,code=sm_90a") +cc_flag_sm90 = [] +cc_flag_sm90.append("-gencode") +cc_flag_sm90.append("arch=compute_90a,code=sm_90a") + +cc_flag_sm100 = [] +cc_flag_sm100.append("-gencode") +cc_flag_sm100.append("arch=compute_100a,code=sm_100a") this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -41,12 +45,12 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla_cuda", + name="flash_mla_sm90", sources=[ - "csrc/flash_api.cpp", - "csrc/kernels/get_mla_metadata.cu", - "csrc/kernels/mla_combine.cu", - "csrc/kernels/splitkv_mla.cu", + "csrc/sm90/flash_api.cpp", + "csrc/sm90/kernels/get_mla_metadata.cu", + "csrc/sm90/kernels/mla_combine.cu", + "csrc/sm90/kernels/splitkv_mla.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), @@ -66,12 +70,49 @@ def get_features_args(): "--use_fast_math", "--ptxas-options=-v,--register-usage-level=10" ] - + cc_flag + + cc_flag_sm90 ) + get_features_args(), }, include_dirs=[ - Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "sm90", + Path(this_dir) / "csrc" / "cutlass" / "include", + ], + ) +) + +ext_modules.append( + CUDAExtension( + name="flash_mla_sm100", + sources=[ + "csrc/sm100/pybind.cu", + "csrc/sm100/fmha_cutlass_fwd_sm100.cu", + "csrc/sm100/fmha_cutlass_bwd_sm100.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-DNDEBUG", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", + ] + + cc_flag_sm100 + ), + }, + include_dirs=[ + Path(this_dir) / "csrc" / "sm100", Path(this_dir) / "csrc" / "cutlass" / "include", + Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", ], ) ) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla_sm90.py similarity index 100% rename from tests/test_flash_mla.py rename to tests/test_flash_mla_sm90.py diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py new file mode 100644 index 0000000..832c9fb --- /dev/null +++ b/tests/test_fmha_sm100.py @@ -0,0 +1,199 @@ +import random + +import torch +from torch.utils.checkpoint import checkpoint +import triton + +from flash_mla import flash_attn_varlen_func + + +def get_window_size(causal, window): + if window > 0: + window_size = (window - 1, 0) if causal else (window - 1, window - 1) + else: + window_size = (-1, -1) + return window_size + + +def get_attn_bias(s_q, s_k, causal, window): + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32) + if causal: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + if window > 0: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q - window) + attn_bias.masked_fill_(temp_mask, float("-inf")) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q + window - 1) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + return attn_bias + + +def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + RMSE = ((x - y) * (x - y)).mean().sqrt().item() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + amax_diff = (x - y).abs().max().item() + # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") + assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}" + + +def sdpa(query, key, value, attn_bias, softmax_scale=None): + key = key.repeat_interleave(h // h_k, dim=-3) + value = value.repeat_interleave(h // h_k, dim=-3) + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + attn_weight = query @ key.transpose(-2, -1) * softmax_scale + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight.to(query.dtype) @ value, lse + + +def sdpa_checkpoint(*args, **kwargs): + return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) + + +def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd): + print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}") + torch.manual_seed(0) + random.seed(0) + + seqlens_q = torch.full((b,), mean_sq, dtype=torch.int32) + seqlens_k = torch.full((b,), mean_sk, dtype=torch.int32) + + if varlen: + for i in range(b): + seqlens_q[i] = max(random.normalvariate(mean_sq, mean_sq / 2), 1) + for i in range(b): + seqlens_k[i] = max(random.normalvariate(mean_sk, mean_sk / 2), seqlens_q[i].item()) + cu_seqlens_q = torch.cumsum(torch.nn.functional.pad(seqlens_q, (1, 0)), 0, dtype=torch.int32) + cu_seqlens_k = torch.cumsum(torch.nn.functional.pad(seqlens_k, (1, 0)), 0, dtype=torch.int32) + total_q = seqlens_q.sum().item() + total_k = seqlens_k.sum().item() + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + total_attn_compute = sum([(get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), + causal, window) == 0).sum().item() for i in range(b)]) + # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") + + q = torch.randn(total_q, h, d) + k = torch.randn(total_k, h_k, d) + v = torch.randn(total_k, h_k, dv) + grad_out = torch.randn(total_q, h, dv) + softmax_scale = (d + 100) ** (-0.5) + + offst_q = total_q + offst_kv = total_k + + q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype) + k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype) + v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype) + q1_with_buffer[total_q:] = q + k1_with_buffer[offst_kv:] = k + v1_with_buffer[offst_kv:] = v + q1 = q1_with_buffer[offst_q:].requires_grad_() + k1 = k1_with_buffer[offst_kv:].requires_grad_() + v1 = v1_with_buffer[offst_kv:].requires_grad_() + + q2 = q.clone().requires_grad_() + k2 = k.clone().requires_grad_() + v2 = v.clone().requires_grad_() + + def flash_attn(): + q1.grad = k1.grad = v1.grad = None + kwargs = {} + if causal: + kwargs["causal"] = causal + if window != 0: + kwargs["window_size"] = get_window_size(causal, window) + return flash_attn_varlen_func(q1, k1, v1, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, + max_seqlen_k, softmax_scale=softmax_scale, is_varlen=varlen, **kwargs) + + def torch_attn(): + q2.grad = k2.grad = v2.grad = None + out = [] + lse = [] + for i in range(b): + OUT, LSE = sdpa_checkpoint( + q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2), + k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), + softmax_scale=softmax_scale, + ) + out.append(OUT.transpose(-3, -2)) + lse.append(LSE.transpose(-2, -1)) + out = torch.cat(out) + lse = torch.cat(lse) + return out, lse + + out_flash, lse_flash = flash_attn() + out_torch, lse_torch = torch_attn() + assert_close(out_flash, out_torch, "out") + assert_close(lse_flash, lse_torch, "lse") + + if has_bwd: + out_flash.backward(grad_out, retain_graph=True) + out_torch.backward(grad_out, retain_graph=True) + assert_close(q1.grad, q2.grad, "dq") + assert_close(k1.grad, k2.grad, "dk") + assert_close(v1.grad, v2.grad, "dv") + dq1 = q1.grad.clone() + dk1 = k1.grad.clone() + dv1 = v1.grad.clone() + + def forward(): + return flash_attn() + + def backward(): + q1.grad = k1.grad = v1.grad = None + out_flash.backward(grad_out, retain_graph=True) + + for _ in range(5): + out, lse = forward() + assert torch.equal(out, out_flash), "out deterministic check failed!" + assert torch.equal(lse, lse_flash), "lse deterministic check failed!" + if has_bwd: + backward() + # assert torch.equal(q1.grad, dq1), "dq deterministic check failed!" + assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" + assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + # forward() + # if has_bwd: + # backward() + # print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120)) + + def timer(func, name): + t = triton.testing.do_bench(func, warmup=2, rep=3) + FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOP/s, name: {name}") + return t + + timer(forward, "fwd") + if has_bwd: + timer(backward, "bwd") + + +if __name__ == "__main__": + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + + b = 4 + window = 0 + has_bwd = False + + for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: + for varlen in [False, True]: + for (h, h_k) in [(32, 32), (32, 4)]: + if h != h_k: + has_bwd = False + else: + has_bwd = True + for (d, dv) in [(128, 128), (192, 128)]: + for causal in [False, True]: + test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd) From ef5b1a69fc56fc8ea9405509cf67d2984acd205d Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Thu, 14 Aug 2025 09:34:17 +0800 Subject: [PATCH 05/20] Drop support for CUDA <12.8 --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 07e021a..75bbb16 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,7 @@ Currently released: ## Requirements - Hopper GPUs -- CUDA 12.3 and above - - **But we highly recommend 12.8 or above for the best performance** +- CUDA 12.8 and above - PyTorch 2.0 and above ## Quick start From c7590278ce1bfb86440e514c697bc4190aecd19c Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Thu, 14 Aug 2025 09:37:44 +0800 Subject: [PATCH 06/20] Fix accuracy issue in sum_OdO kernel --- csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp index bdcf1cb..db6a9b4 100644 --- a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -140,7 +140,7 @@ struct FmhaKernelBwdSumOdO { *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); for (int v = 0; v < kElementsPerLoad; v++) { - acc += value_O[v] * value_dO[v]; + acc += ElementAcc(value_O[v]) * ElementAcc(value_dO[v]); } } From 2d291b0c31050ba259e87a4ae6fbf75a47824716 Mon Sep 17 00:00:00 2001 From: zhang Date: Mon, 25 Aug 2025 11:41:50 +0800 Subject: [PATCH 07/20] Remove tma padding for fwd inputs (#85) --- csrc/sm100/collective/fmha_fusion.hpp | 6 +-- .../sm100_fmha_load_tma_warpspecialized.hpp | 49 ++++++------------- ...m100_fmha_mla_load_tma_warpspecialized.hpp | 49 ++++++------------- csrc/sm100/fmha_cutlass_fwd_sm100.cu | 5 +- csrc/sm100/fmha_cutlass_fwd_sm100.cuh | 11 ++--- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 29 ++++++++--- tests/test_fmha_sm100.py | 15 ++---- 7 files changed, 68 insertions(+), 96 deletions(-) diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/collective/fmha_fusion.hpp index 1486767..8c09eaf 100644 --- a/csrc/sm100/collective/fmha_fusion.hpp +++ b/csrc/sm100/collective/fmha_fusion.hpp @@ -220,13 +220,13 @@ struct CausalMask : NoMask { BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); - return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ; + return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count); } } diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp index 1951056..3606dcc 100644 --- a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; - auto problem_shape_qk = problem_shape; + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; - if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dQ) = get<0>(dQ); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_Q -= max_length_q * get<0>(dQ); - } - } - - if constexpr (is_variable_length_v>) { - auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; - if (cumulative_length_kv != nullptr) { - int max_length_kv = get<1>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dK) = get<0>(dK); - get<2,1>(dV) = get<0>(dV); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_K -= max_length_kv * get<0>(dK); - ptr_V -= max_length_kv * get<0>(dV); + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); } + } else { + problem_shape_qk = problem_shape; } auto params_qk = CollectiveMmaQK::to_underlying_arguments( @@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); int q_offs_0 = 0; - int q_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(params_problem_shape).max_length; - q_offs_0 = max_length_q - get<0>(problem_shape); - q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } - Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); @@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); int kv_offs_0 = 0; - int kv_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { - int max_length = get<1>(params_problem_shape).max_length; - kv_offs_0 = max_length - get<1>(problem_shape); - kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } - Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); @@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized { ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); - Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c2d3e2b..c8fc13b 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; - auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; - if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dQ) = get<0>(dQ); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_Q -= max_length_q * get<0>(dQ); - } - } - - if constexpr (is_variable_length_v>) { - auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; - if (cumulative_length_kv != nullptr) { - int max_length_kv = get<1>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dK) = get<0>(dK); - get<2,1>(dV) = get<0>(dV); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_K -= max_length_kv * get<0>(dK); - ptr_V -= max_length_kv * get<0>(dV); + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); } + } else { + problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; } auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); @@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); int q_offs_0 = 0; - int q_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(params_problem_shape).max_length; - q_offs_0 = max_length_q - get<0>(problem_shape); - q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } - Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); @@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); int kv_offs_0 = 0; - int kv_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { - int max_length = get<1>(params_problem_shape).max_length; - kv_offs_0 = max_length - get<1>(problem_shape); - kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } - Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); @@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); - Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/fmha_cutlass_fwd_sm100.cu index e322709..997886e 100644 --- a/csrc/sm100/fmha_cutlass_fwd_sm100.cu +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cu @@ -18,8 +18,9 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va static constexpr bool IsVarlen = std::is_same_v; static constexpr bool IsMla = std::is_same_v; static constexpr bool IsCausalMask = std::is_same_v>; - using Option = std::conditional_t, - Option>; + using Option = + std::conditional_t, + Option>; run_fmha_fwd( workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh index 71831bb..987a5f7 100644 --- a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh @@ -143,8 +143,8 @@ struct FwdRunner { ProblemShapeType problem_size_for_launch; - get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; - get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv}; get<2>(problem_size_for_launch) = get<2>(problem_size); get<3>(problem_size_for_launch) = get<3>(problem_size); @@ -206,10 +206,6 @@ struct FwdRunner { void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, void *cumulative_length_q, void *cumulative_length_kv) { auto problem_shape_ = problem_shape; - if constexpr (kIsVarlen) { - get<0>(problem_shape_).cumulative_length = static_cast(cumulative_length_q); - get<1>(problem_shape_).cumulative_length = static_cast(cumulative_length_kv); - } typename Operation::Arguments arguments{ problem_shape_, @@ -230,6 +226,7 @@ struct FwdRunner { int total_seqlen_q = q.size(0); int total_seqlen_kv = k.size(0); + ProblemShapeType problem_shape = initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); @@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v auto options = get_options(); if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && - (!std::is_same_v)) { + (std::is_same_v> || std::is_same_v>)) { FwdRunner runner; runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index 8fe503b..43bb035 100644 --- a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { else if (role == WarpRole::Correction) { cutlass::arch::warpgroup_reg_dealloc(); + bool has_valid = false; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); @@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + has_valid = true; + if (get<1>(logical_problem_shape) == 0) { mainloop.correction_empty( blk_coord, @@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { if constexpr (NumWarpsEpilogue == 0) { static_assert(NumWarpsCorrection == 1); - uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + if (has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } } else if (role == WarpRole::MMA) { warpgroup_reg_set(); - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); - __syncwarp(); + bool allocated = false; CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { @@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + if (!allocated) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + allocated = true; + } + if (get<1>(logical_problem_shape) == 0) { continue; } @@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { else if (role == WarpRole::Epilogue) { warpgroup_reg_set(); + bool has_valid = false; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); @@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + has_valid = true; + epilogue.store( blk_coord, logical_problem_shape, params.epilogue, params.problem_shape, @@ -602,8 +617,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { static_assert(NumWarpsEpilogue <= 1); if constexpr (NumWarpsEpilogue == 1) { - uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + if(has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } } diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 832c9fb..7cb19a2 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win grad_out = torch.randn(total_q, h, dv) softmax_scale = (d + 100) ** (-0.5) - offst_q = total_q - offst_kv = total_k - - q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype) - k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype) - v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype) - q1_with_buffer[total_q:] = q - k1_with_buffer[offst_kv:] = k - v1_with_buffer[offst_kv:] = v - q1 = q1_with_buffer[offst_q:].requires_grad_() - k1 = k1_with_buffer[offst_kv:].requires_grad_() - v1 = v1_with_buffer[offst_kv:].requires_grad_() + q1 = q.clone().requires_grad_() + k1 = k.clone().requires_grad_() + v1 = v.clone().requires_grad_() q2 = q.clone().requires_grad_() k2 = k.clone().requires_grad_() From eb7583357f0a2ca44a00d528639e0fb374c4254a Mon Sep 17 00:00:00 2001 From: Li Xiang Date: Mon, 25 Aug 2025 13:44:30 +0800 Subject: [PATCH 08/20] Remove cudaMalloc and cudaFree in backward (#87) * get rid of cudaMalloc and cudaFree * minor fix --------- Co-authored-by: Jiashi Li --- csrc/sm100/common/utils.hpp | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp index f43770d..6815839 100644 --- a/csrc/sm100/common/utils.hpp +++ b/csrc/sm100/common/utils.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include "cutlass/numeric_types.h" #include "helper.h" @@ -36,18 +37,21 @@ struct DeviceAllocation { T* ptr_ = nullptr; size_t offset_ = 0; size_t size_ = 0; + torch::Tensor tensor; DeviceAllocation(DeviceAllocation const&) = delete; DeviceAllocation& operator=(DeviceAllocation const&) = delete; DeviceAllocation() = default; DeviceAllocation(size_t size) { reset(size); } - ~DeviceAllocation() { reset(); } + ~DeviceAllocation() {} void reset(size_t size, size_t offset=0) { - reset(); - auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); - assert(ret == cudaSuccess); + size_t num_element = sizeof(T) * (size + offset); + auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + + tensor = torch::empty(num_element, options); + ptr_ = tensor.data_ptr(); size_ = size; offset_ = offset; } @@ -60,24 +64,7 @@ struct DeviceAllocation { return ptr_ + offset_; } - void reset() { - if (ptr_ != nullptr) { - auto ret = cudaFree(ptr_); - assert(ret == cudaSuccess); - } - } - size_t size() const { return size_; } size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } - - void copy_from_host(const T* ptr, size_t sz) { - auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); - assert(ret == cudaSuccess); - } - - void copy_from_device(const T* ptr, size_t sz) { - auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); - assert(ret == cudaSuccess); - } -}; \ No newline at end of file +}; From 261330bb6dfacdff8ff4b67e126417863b31aa72 Mon Sep 17 00:00:00 2001 From: Zeyu WANG Date: Wed, 27 Aug 2025 19:59:57 +0800 Subject: [PATCH 09/20] fix calc space bug (#91) * fix calc space bug * use python code to allocate the buffer for backward kernel --- csrc/sm100/common/utils.hpp | 39 +-------------------------- csrc/sm100/device/fmha_device_bwd.hpp | 12 ++++----- csrc/sm100/fmha_cutlass_bwd_sm100.cuh | 7 ++--- flash_mla/flash_mla_interface.py | 2 +- 4 files changed, 10 insertions(+), 50 deletions(-) diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp index 6815839..fdaeff0 100644 --- a/csrc/sm100/common/utils.hpp +++ b/csrc/sm100/common/utils.hpp @@ -30,41 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> { }; template -using cutlass_dtype_t = typename cutlass_dtype::type; - -template -struct DeviceAllocation { - T* ptr_ = nullptr; - size_t offset_ = 0; - size_t size_ = 0; - torch::Tensor tensor; - - DeviceAllocation(DeviceAllocation const&) = delete; - DeviceAllocation& operator=(DeviceAllocation const&) = delete; - - DeviceAllocation() = default; - DeviceAllocation(size_t size) { reset(size); } - ~DeviceAllocation() {} - - void reset(size_t size, size_t offset=0) { - size_t num_element = sizeof(T) * (size + offset); - auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); - - tensor = torch::empty(num_element, options); - ptr_ = tensor.data_ptr(); - size_ = size; - offset_ = offset; - } - - T* get() { - return ptr_ + offset_; - } - - const T* get() const { - return ptr_ + offset_; - } - - size_t size() const { return size_; } - - size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } -}; +using cutlass_dtype_t = typename cutlass_dtype::type; \ No newline at end of file diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/device/fmha_device_bwd.hpp index d2463ac..76b7ed5 100644 --- a/csrc/sm100/device/fmha_device_bwd.hpp +++ b/csrc/sm100/device/fmha_device_bwd.hpp @@ -225,11 +225,11 @@ class Sm100FmhaBwd { int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment size_t workspace_bytes = 0; // OdO vector - workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + workspace_bytes += sizeof(ElementAccumulator) * B*H*Q; // scaled LSE vector - workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + workspace_bytes += sizeof(ElementAccumulator) * B*H*Q; // FP32 versions of outputs that are churned (start off with Q only) - workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D; return workspace_bytes; } @@ -247,7 +247,7 @@ class Sm100FmhaBwd { ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); params_.dQ_acc = dQ_acc; - params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D; auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); auto args_convert = to_convert_arguments(args, dQ_acc); params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); @@ -274,9 +274,9 @@ class Sm100FmhaBwd { int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += sizeof(ElementAccumulator) * B*H*Q; ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += sizeof(ElementAccumulator) * B*H*Q; ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); } diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh index 2b19be2..f4a1ce8 100644 --- a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh @@ -174,13 +174,10 @@ struct BwdRunner { Operation op; - size_t workspace_size = 0; - workspace_size = Operation::get_workspace_size(arguments); - DeviceAllocation workspace(workspace_size); - uint8_t* workspace_ptr = workspace.get(); + uint8_t* workspace_ptr = static_cast(workspace_buffer.data_ptr()); CUTLASS_CHECK(op.can_implement(arguments)); - CUTLASS_CHECK(op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(op.initialize(arguments, workspace_ptr)); CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); } diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 9c669ba..084117e 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -154,7 +154,7 @@ def _flash_attn_varlen_backward( max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 bs = cu_seqlens_qo.shape[0] - 1 workspace_bytes = 0 - workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse if num_qo_heads != num_kv_heads: workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc From ebf30641e27b777c22c38968b8c0aa38da1bac19 Mon Sep 17 00:00:00 2001 From: zhang Date: Mon, 22 Sep 2025 17:08:22 +0800 Subject: [PATCH 10/20] Refine handling for q/v sequence length equals zero. (#92) --- .../sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp | 5 ++++- .../sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp | 4 ---- .../sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp | 3 +++ .../sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp | 4 ---- .../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp | 3 +++ csrc/sm100/device/fmha.hpp | 5 +++++ csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp | 2 +- csrc/sm100/kernel/fmha_tile_scheduler.hpp | 2 +- tests/test_fmha_sm100.py | 3 +++ 9 files changed, 20 insertions(+), 11 deletions(-) diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp index 616357c..6f9bba3 100644 --- a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { int max_length_q = get<0>(problem_shape).max_length; + get<0>(problem_shape_O).max_length = max(1, max_length_q); // for variable sequence lenght, the batch is in units of row_stride get<2,1>(dO) = get<0>(dO); - get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O))); // offset ptr by the amount we add back in later ptr_O -= max_length_q * get<0>(dO); } + } else { + get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O)); } auto tma_store_o = make_tma_copy( diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index f39fd75..56f571a 100644 --- a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { float lse = -INFINITY; int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); -#define DSHOW(x) print(#x ": "); print(x); print("\n") - if (threadIdx.x % 128 == 0 && block0()) { - DSHOW(sO); - } #if 1 using ElementOut = typename CollectiveEpilogue::ElementOut; diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp index 3606dcc..86e3149 100644 --- a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { problem_shape_qk = problem_shape; } + get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk)); + get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk)); + auto params_qk = CollectiveMmaQK::to_underlying_arguments( problem_shape_qk, typename CollectiveMmaQK::Arguments { diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp index bf41af9..994bd4e 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { float lse = -INFINITY; int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); -#define DSHOW(x) print(#x ": "); print(x); print("\n") - if (threadIdx.x % 128 == 0 && block0()) { - DSHOW(sO); - } #if 1 using ElementOut = typename CollectiveEpilogue::ElementOut; diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c8fc13b..0b7d76f 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; } + get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk)); + get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk)); + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); auto params_qk = CollectiveMmaQK::to_underlying_arguments( diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/device/fmha.hpp index f8406d3..5dcb069 100644 --- a/csrc/sm100/device/fmha.hpp +++ b/csrc/sm100/device/fmha.hpp @@ -208,6 +208,11 @@ class FMHA { dim3 const block = Kernel::get_block_shape(); dim3 const grid = get_grid_shape(params); + // No need to launch the kernel + if(grid.x == 0 || grid.y == 0 || grid.z == 0) { + return Status::kSuccess; + } + // configure smem size and carveout int smem_size = Kernel::SharedStorageSize; diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp index 572e67f..c879fe6 100644 --- a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp +++ b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp @@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler { return Params { num_blocks, - { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + { size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) }, hw_info }; } diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_tile_scheduler.hpp index 119f069..97e7962 100644 --- a/csrc/sm100/kernel/fmha_tile_scheduler.hpp +++ b/csrc/sm100/kernel/fmha_tile_scheduler.hpp @@ -123,7 +123,7 @@ struct PersistentTileScheduler { return Params { num_blocks, - { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + { max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, hw_info }; } diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 7cb19a2..2ba8b46 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window): def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + close_tensor = torch.isclose(x.to(torch.float32), y.to(torch.float32), rtol=1e-5, atol=1e-5) + if close_tensor.all(): + return x, y = x.double(), y.double() RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) From c28eca99dbc664dd2716415ed03492afe5fefade Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Wed, 24 Sep 2025 14:22:05 +0800 Subject: [PATCH 11/20] Reorganize files and add sparse prefill/decoding kernels on hopper --- .gitignore | 1 + README.md | 169 ++++- csrc/{sm90/kernels => }/params.h | 34 +- csrc/pybind.cpp | 442 +++++++++++ .../dense}/collective/fmha_common.hpp | 0 .../dense}/collective/fmha_fusion.hpp | 0 ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 0 ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 6 +- .../sm100_fmha_load_tma_warpspecialized.hpp | 4 +- ...a_mla_fwd_mainloop_tma_warpspecialized.hpp | 8 +- ...m100_fmha_mla_load_tma_warpspecialized.hpp | 4 +- .../dense}/common/gather_tensor.hpp | 0 .../sm100/{ => prefill/dense}/common/helper.h | 0 .../sm100/{ => prefill/dense}/common/mask.cuh | 0 .../dense}/common/pipeline_mla.hpp | 0 .../{ => prefill/dense}/common/pow_2.hpp | 0 .../{ => prefill/dense}/common/utils.hpp | 0 .../sm100/{ => prefill/dense}/device/fmha.hpp | 0 .../dense}/device/fmha_device_bwd.hpp | 0 .../dense}/fmha_cutlass_bwd_sm100.cu | 4 +- .../dense}/fmha_cutlass_bwd_sm100.cuh | 0 .../dense}/fmha_cutlass_fwd_sm100.cu | 11 +- .../dense}/fmha_cutlass_fwd_sm100.cuh | 0 .../{pybind.cu => prefill/dense/interface.h} | 9 +- .../kernel/fmha_causal_tile_scheduler.hpp | 0 .../dense}/kernel/fmha_kernel_bwd_convert.hpp | 7 + .../dense}/kernel/fmha_kernel_bwd_sum_OdO.hpp | 7 + .../dense}/kernel/fmha_options.hpp | 0 .../dense}/kernel/fmha_tile_scheduler.hpp | 0 ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 9 +- ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 9 +- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 17 +- csrc/sm90/{kernels => decode/dense}/config.h | 2 - .../{kernels => decode/dense}/splitkv_mla.cu | 43 +- csrc/sm90/decode/dense/splitkv_mla.h | 10 + csrc/sm90/{kernels => decode/dense}/traits.h | 0 .../decode/sparse_fp8/components/config.h | 121 +++ .../decode/sparse_fp8/components/dequant.h | 88 +++ .../decode/sparse_fp8/components/epilogue.h | 87 +++ .../decode/sparse_fp8/components/helpers.h | 86 +++ .../sparse_fp8/components/named_barriers.h | 10 + csrc/sm90/decode/sparse_fp8/splitkv_mla.cu | 614 +++++++++++++++ csrc/sm90/decode/sparse_fp8/splitkv_mla.h | 9 + csrc/sm90/flash_api.cpp | 216 ------ csrc/sm90/kernels/get_mla_metadata.h | 5 - csrc/sm90/kernels/mla_combine.h | 6 - csrc/sm90/kernels/splitkv_mla.h | 6 - csrc/sm90/prefill/sparse/fwd.cu | 709 ++++++++++++++++++ csrc/sm90/prefill/sparse/fwd.h | 9 + csrc/sm90/prefill/sparse/helpers.h | 177 +++++ .../kernels => smxx}/get_mla_metadata.cu | 28 +- csrc/smxx/get_mla_metadata.h | 5 + csrc/{sm90/kernels => smxx}/mla_combine.cu | 13 +- csrc/smxx/mla_combine.h | 6 + csrc/{sm90/kernels => }/utils.h | 34 + flash_mla/__init__.py | 1 + flash_mla/flash_mla_interface.py | 104 +-- setup.py | 115 +-- tests/lib.py | 73 ++ tests/quant.py | 68 ++ tests/test_flash_mla_decoding.py | 343 +++++++++ tests/test_flash_mla_prefill.py | 197 +++++ tests/test_flash_mla_sm90.py | 153 ---- tests/test_fmha_sm100.py | 73 +- 64 files changed, 3510 insertions(+), 642 deletions(-) rename csrc/{sm90/kernels => }/params.h (59%) create mode 100644 csrc/pybind.cpp rename csrc/sm100/{ => prefill/dense}/collective/fmha_common.hpp (100%) rename csrc/sm100/{ => prefill/dense}/collective/fmha_fusion.hpp (100%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp (100%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_load_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/common/gather_tensor.hpp (100%) rename csrc/sm100/{ => prefill/dense}/common/helper.h (100%) rename csrc/sm100/{ => prefill/dense}/common/mask.cuh (100%) rename csrc/sm100/{ => prefill/dense}/common/pipeline_mla.hpp (100%) rename csrc/sm100/{ => prefill/dense}/common/pow_2.hpp (100%) rename csrc/sm100/{ => prefill/dense}/common/utils.hpp (100%) rename csrc/sm100/{ => prefill/dense}/device/fmha.hpp (100%) rename csrc/sm100/{ => prefill/dense}/device/fmha_device_bwd.hpp (100%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_bwd_sm100.cu (98%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_bwd_sm100.cuh (100%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_fwd_sm100.cu (98%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_fwd_sm100.cuh (100%) rename csrc/sm100/{pybind.cu => prefill/dense/interface.h} (84%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_causal_tile_scheduler.hpp (100%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_kernel_bwd_convert.hpp (97%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_kernel_bwd_sum_OdO.hpp (97%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_options.hpp (100%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_tile_scheduler.hpp (100%) rename csrc/sm100/{ => prefill/dense}/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp (98%) rename csrc/sm90/{kernels => decode/dense}/config.h (78%) rename csrc/sm90/{kernels => decode/dense}/splitkv_mla.cu (97%) create mode 100644 csrc/sm90/decode/dense/splitkv_mla.h rename csrc/sm90/{kernels => decode/dense}/traits.h (100%) create mode 100644 csrc/sm90/decode/sparse_fp8/components/config.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/dequant.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/epilogue.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/helpers.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/named_barriers.h create mode 100644 csrc/sm90/decode/sparse_fp8/splitkv_mla.cu create mode 100644 csrc/sm90/decode/sparse_fp8/splitkv_mla.h delete mode 100644 csrc/sm90/flash_api.cpp delete mode 100644 csrc/sm90/kernels/get_mla_metadata.h delete mode 100644 csrc/sm90/kernels/mla_combine.h delete mode 100644 csrc/sm90/kernels/splitkv_mla.h create mode 100644 csrc/sm90/prefill/sparse/fwd.cu create mode 100644 csrc/sm90/prefill/sparse/fwd.h create mode 100644 csrc/sm90/prefill/sparse/helpers.h rename csrc/{sm90/kernels => smxx}/get_mla_metadata.cu (64%) create mode 100644 csrc/smxx/get_mla_metadata.h rename csrc/{sm90/kernels => smxx}/mla_combine.cu (94%) create mode 100644 csrc/smxx/mla_combine.h rename csrc/{sm90/kernels => }/utils.h (71%) create mode 100644 tests/lib.py create mode 100644 tests/quant.py create mode 100644 tests/test_flash_mla_decoding.py create mode 100644 tests/test_flash_mla_prefill.py delete mode 100644 tests/test_flash_mla_sm90.py diff --git a/.gitignore b/.gitignore index 4535280..6b00da7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ dist/ /.vscode compile_commands.json .cache +/dev diff --git a/README.md b/README.md index 75bbb16..8cf01a3 100644 --- a/README.md +++ b/README.md @@ -1,69 +1,184 @@ # FlashMLA -## Performance Update (2025.04.22) +## Introduction -We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀 +FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](TODO) models. This repository contains the following implementations: -Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). +**Sparse Attention Kernels** -The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. +*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](TODO).* -## Introduction +- Token-level sparse attention for the prefill stage +- Token-level sparse attention for the decoding stage, with FP8 KV cache -FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. +**Dense Attention Kernels** -Currently released: -- BF16, FP16 -- Paged kvcache with block size of 64 +- Dense attention for the prefill stage +- Dense attention for the decoding stage -## Requirements +## News -- Hopper GPUs -- CUDA 12.8 and above -- PyTorch 2.0 and above +- **2025.09.26(TODO) Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](TODO), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. +- **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! +- **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). +- **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 -## Quick start +## Performance -### Install +#### Test & benchmark MLA decoding (Sparse & Dense): ```bash -pip install -v . +python tests/test_flash_mla_decoding.py ``` -### Benchmark +The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. -#### Testing MLA Decoding +#### Test & benchmark MHA prefill (Dense): ```bash -python tests/test_flash_mla_sm90.py +python tests/test_fmha_sm100.py ``` -#### Testing MLA Forward/Backward +It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation on B200, as reported by NVIDIA. + +#### Test & benchmark MLA prefill (Sparse): ```bash -python tests/test_fmha_sm100.py +python tests/test_flash_mla_prefill.py ``` -It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. +It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8. + +## Requirements + +- Hopper / Blackwell GPUs (See the support matrix below) +- CUDA 12.8 and above (CUDA 12.9+ is required for Blackwell kernels) +- PyTorch 2.0 and above -Note. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. +Support matrix: -### Usage +| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | +| :---: | :---: | :---: | :---: | +| Dense Decoding | Hopper | MQA | BF16 | +| Sparse Decoding | Hopper | MQA | FP8 [1] | +| Dense Prefill | Blackwell | MHA | | +| Sparse Prefill | Hopper | MQA | | + +[1]: For more details on using FP8 KV cache, see documents below. + +[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](TODO). + +## Installation + +```bash +git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla +cd flash-mla +git submodule update --init --recursive +pip install -v . +``` + +## Usage + +### MLA Decoding + +To use the MLA decoding kernels, call get_mla_metadata once before the decoding loop to get the tile scheduler metadata. Then, call flash_mla_with_kvcache in each decoding step. For example: ```python from flash_mla import get_mla_metadata, flash_mla_with_kvcache -tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) +tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, + s_q * h_q // h_kv, + h_kv, + h_q, + is_fp8, + topk, +) for i in range(num_layers): ... o_i, lse_i = flash_mla_with_kvcache( q_i, kvcache_i, block_table, cache_seqlens, dv, - tile_scheduler_metadata, num_splits, causal=True, + tile_scheduler_metadata, num_splits, + is_causal, is_fp8_kvcache, indices, ) ... ``` +Where + +- `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1. +- `h_kv` is the number of key-value heads. +- `h_q` is the number of query heads. + +**FP8 KV Cache:** +If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16. + +In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as: +- **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values. +- **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on. +- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy. + +See `tests/quant.py` for quantization and dequantization details. + +**Sparse Attention (`indices` tensor):** +The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens. + +- **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`. +- **Format:** `indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th token for the j-th query sequence in the i-th batch. Since the index of the page block has already been encoded into `indices_in_kvcache`, the kernel does not require the `block_table` parameter. +- **Invalid entries:** Set invalid indices to `-1`. + +**Return Values:** +The kernel returns `(out, lse)`, where: +- `out` is the attention result. +- `lse` is the log-sum-exp value of the attention scores for each query head. + +See `tests/test_flash_mla_decoding.py` for a complete example. + +### Sparse MLA Prefill + +For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters: +- `q`: Query tensor of shape `[s_q, h_q, d_qk]` +- `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]` +- `indices`: Indices tensor of shape `[s_q, h_kv, topk]` +- `sm_scale`: A scalar value + +**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing. + +**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`. + +**Return Values and Equivalent PyTorch Code:** +The kernel returns `(out, max_logits, lse)`. This is equivalent to the following PyTorch operations: + +```python +Q: [s_q, h_q, d_qk], bfloat16 +kv: [s_kv, h_kv, d_qk], bfloat16 +indices: [s_q, h_kv, topk], int32 + +kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1 +indices = indices.squeeze(1) # [s_q, topk] +focused_kv = kv[indices] # For the i-th sequence (s_q), the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk]. + +P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk] +max_logits = P.max(dim=-1) # [s_q, h_q] +lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q],"log2sumexp2" means that the exponentiation and logarithm are base-2 +S = exp2(P - lse) # [s_q, h_q, topk] +out = S @ focused_kv # [s_q, h_q, d_qk] + +return (out, max_logits, lse) +``` + +See `tests/test_flash_mla_prefill.py` for a complete example. + +### Dense MHA Prefill + +This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using: +- `flash_attn_varlen_func` +- `flash_attn_varlen_qkvpacked_func` +- `flash_attn_varlen_kvpacked_func` + +The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example. + ## Acknowledgement FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. @@ -109,7 +224,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c ```bibtex @misc{flashmla2025, - title={FlashMLA: Efficient MLA decoding kernels}, + title={FlashMLA: Efficient Multi-head Latent Attention Kernels}, author={Jiashi Li, Shengyu Liu}, year={2025}, publisher = {GitHub}, diff --git a/csrc/sm90/kernels/params.h b/csrc/params.h similarity index 59% rename from csrc/sm90/kernels/params.h rename to csrc/params.h index 3b4e254..baa2f7f 100644 --- a/csrc/sm90/kernels/params.h +++ b/csrc/params.h @@ -1,8 +1,8 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/bfloat16.h" -struct Flash_fwd_mla_params { +struct DecodingParams { using index_t = int64_t; int b; // batch size @@ -14,11 +14,13 @@ struct Flash_fwd_mla_params { int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; + int topk; void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ o_ptr; void *__restrict__ softmax_lse_ptr; + int *__restrict__ indices_ptr; index_t q_batch_stride; index_t k_batch_stride; @@ -29,6 +31,8 @@ struct Flash_fwd_mla_params { index_t q_head_stride; index_t k_head_stride; index_t o_head_stride; + index_t indices_batch_stride; + index_t indices_row_stride; int *__restrict__ block_table; index_t block_table_batch_stride; @@ -45,9 +49,9 @@ struct Flash_fwd_mla_params { }; static constexpr int TileSchedulerMetaDataSize = 8; -// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] +// [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _] -struct Mla_metadata_params { +struct GetDecodingMetadataParams { int *__restrict__ seqlens_k_ptr; int *__restrict__ tile_scheduler_metadata_ptr; int *__restrict__ num_splits_ptr; @@ -55,4 +59,26 @@ struct Mla_metadata_params { int block_size_n; int fixed_overhead_num_blocks; int num_sm_parts; + int topk; +}; + +struct SparsePrefillParams { + int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk; + float sm_scale, sm_scale_div_log2; + + // Input tensors + cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk] + cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk] + int* __restrict__ indices; // [s_q, h_kv, topk] + + int stride_q_s_q; int stride_q_h_q; + int stride_kv_s_kv; int stride_kv_h_kv; + int stride_indices_s_q; int stride_indices_h_kv; + + // Output tensors + cutlass::bfloat16_t* __restrict__ out; // [s_q, h_q, d_v] + float* __restrict__ max_logits; // [s_q, h_q] + float* __restrict__ lse; // [s_q, h_q] + + cudaStream_t stream; }; diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp new file mode 100644 index 0000000..b360c24 --- /dev/null +++ b/csrc/pybind.cpp @@ -0,0 +1,442 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include + +#include "params.h" +#include "smxx/get_mla_metadata.h" +#include "smxx/mla_combine.h" +#include "sm90/decode/dense/splitkv_mla.h" +#include "sm90/decode/sparse_fp8/splitkv_mla.h" +#include "sm90/prefill/sparse/fwd.h" +#include "sm100/prefill/dense/interface.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +struct Arch { + int major; + int minor; + + bool is_sm90() const { + return major == 9 && minor == 0; + } + + bool is_sm100() const { + return major == 10 && minor == 0; + } + + void assert_is_supported() const { + TORCH_CHECK(is_sm90() || is_sm100(), "Only SM90 and SM100 are supported"); + } +}; + +// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. Hopper Dense BF16, Hopper Sparse FP8, etc.) +struct DecodingAttnImplMeta { + int num_sm_parts; + int fixed_overhead_num_blocks; + int k_block_size; +}; + +DecodingAttnImplMeta get_attn_impl_meta( + Arch arch, + int sm_count, + int num_q_tokens_per_head_k, + int h_k, + std::optional h_q_, + bool is_fp8_kvcache, + bool is_sparse_attn +) { + if (arch.is_sm90()) { + if (is_sparse_attn) { + if (is_fp8_kvcache) { + TORCH_CHECK(h_q_.has_value()); + int h_q = h_q_.value(); + TORCH_CHECK(h_q % h_k == 0); + int s_q = num_q_tokens_per_head_k * h_k / h_q; + // FP8 + Sparse MLA + return { + std::max((sm_count/2) / h_k / (cutlass::ceil_div(h_q/h_k, 2*64) * s_q), 1), + 5, + 64 + }; + } else { + // Sparse BF16 MLA + TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90"); + } + } else { + if (is_fp8_kvcache) { + // Dense FP8 MLA + TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + } else { + // Dense BF16 MLA + return { + std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, 64), 1), + 5, + 64 + }; + } + } + } else if (arch.is_sm100()) { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } else { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } +} + + +std::vector +get_mla_decoding_metadata( + at::Tensor &seqlens_k, + const int num_q_tokens_per_head_k, + const int h_k, + const std::optional h_q, + const bool is_fp8_kvcache, + const std::optional topk +) { + bool is_sparse_attn = topk.has_value(); + CHECK_DEVICE(seqlens_k); + TORCH_CHECK(seqlens_k.is_contiguous()); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); + if (is_sparse_attn) + TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided"); + + int batch_size = seqlens_k.size(0); + int *seqlens_k_ptr = seqlens_k.data_ptr(); + auto options = seqlens_k.options(); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + int sm_count = dprops->multiProcessorCount; + Arch arch = {dprops->major, dprops->minor}; + arch.assert_is_supported(); + DecodingAttnImplMeta attn_impl_meta = get_attn_impl_meta(arch, sm_count, num_q_tokens_per_head_k, h_k, h_q, is_fp8_kvcache, is_sparse_attn); + + auto tile_scheduler_metadata = torch::empty({attn_impl_meta.num_sm_parts, TileSchedulerMetaDataSize}, options); + auto num_splits = torch::empty({batch_size + 1}, options); + int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + int *num_splits_ptr = num_splits.data_ptr(); + + at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + GetDecodingMetadataParams params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = attn_impl_meta.k_block_size; + params.fixed_overhead_num_blocks = attn_impl_meta.fixed_overhead_num_blocks; + params.num_sm_parts = attn_impl_meta.num_sm_parts; + params.topk = is_sparse_attn ? topk.value() : -1; + run_get_mla_metadata_kernel(params, stream); + + return {tile_scheduler_metadata, num_splits}; +} + +std::vector +fwd_kvcache_mla( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) + const int head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const float softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits, // batch_size + 1 + const bool &is_fp8, + const std::optional &indices // None, or batch_size x seqlen_q x topk +) { + bool is_sparse_attn = indices.has_value(); + int topk = is_sparse_attn ? indices->size(-1) : -1; + + // Check the architecture + auto dprops = at::cuda::getCurrentDeviceProperties(); + Arch arch = {dprops->major, dprops->minor}; + arch.assert_is_supported(); + + // Check data types + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); + + if (!is_fp8) { + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + } else { + TORCH_CHECK(kcache.dtype() == torch::kFloat8_e4m3fn || kcache.dtype() == torch::kInt8 || kcache.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn or int8 or uint8"); + } + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32"); + + // Check device + CHECK_DEVICE(q); + CHECK_DEVICE(kcache); + CHECK_DEVICE(seqlens_k); + CHECK_DEVICE(block_table); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_DEVICE(num_splits); + if (is_sparse_attn) CHECK_DEVICE(indices.value()); + + // Check layout + TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); + CHECK_CONTIGUOUS(seqlens_k); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + CHECK_CONTIGUOUS(num_splits); + TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); + + CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); + if (!is_fp8) { + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + } else { + int bytes_per_token = 512 + 64*2 + (512/128)*4; + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, bytes_per_token); + TORCH_CHECK(num_heads_k == 1, "Currently the number of k heads must be 1 when is_fp8_kvcache is True"); + TORCH_CHECK(kcache.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True"); + } + CHECK_SHAPE(seqlens_k, batch_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_SHAPE(num_splits, batch_size+1); + if (is_sparse_attn) CHECK_SHAPE(indices.value(), batch_size, seqlen_q_ori, topk); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse); + + DecodingParams params = {}; + // Set the sizes. + params.b = batch_size; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; + params.is_causal = is_causal; + params.d = head_size_k; + params.d_v = head_size_v; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + params.topk = topk; + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.indices_ptr = is_sparse_attn ? indices->data_ptr() : nullptr; + params.softmax_lse_ptr = softmax_lse.data_ptr(); + // All stride are in elements, not bytes. + params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(1); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(2); + params.o_head_stride = out.stride(-2); + params.indices_batch_stride = is_sparse_attn ? indices->stride(0) : 0; + params.indices_row_stride = is_sparse_attn ? indices->stride(1) : 0; + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + params.num_splits_ptr = num_splits.data_ptr(); + + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse_accum); + CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(head_size_k == 576); + + if (q_dtype == torch::kHalf) { +#ifdef FLASH_MLA_DISABLE_FP16 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA."); +#endif + } + + if (arch.is_sm90()) { + if (is_sparse_attn) { + if (is_fp8) { + TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90"); + sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); + } else { + TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + } + } else { + if (is_fp8) { + TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + } else { + if (q_dtype == torch::kBFloat16) { + sm90::run_flash_splitkv_mla_kernel(params, stream); + } else if (q_dtype == torch::kHalf) { +#ifndef FLASH_MLA_DISABLE_FP16 + sm90::run_flash_splitkv_mla_kernel(params, stream); +#endif + } else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } + } + } + } else if (arch.is_sm100()) { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } else { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } + + if (q_dtype == torch::kBFloat16) { + run_flash_mla_combine_kernel(params, stream); + } else if (q_dtype == torch::kHalf) { +#ifndef FLASH_MLA_DISABLE_FP16 + run_flash_mla_combine_kernel(params, stream); +#endif + } else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } + + out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); + + return {out, softmax_lse}; +} + + +inline int int64_stride_to_int(int64_t orig_stride) { + if (orig_stride > std::numeric_limits::max()) { + TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride); + } + return static_cast(orig_stride); +} + +std::vector sparse_prefill_fwd( + const at::Tensor &q, + const at::Tensor &kv, + const at::Tensor &indices, + float sm_scale, + int d_v +) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9; + TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures"); + + CHECK_DEVICE(q); + CHECK_DEVICE(kv); + CHECK_DEVICE(indices); + + TORCH_CHECK(q.dtype() == torch::kBFloat16); + TORCH_CHECK(kv.dtype() == torch::kBFloat16); + TORCH_CHECK(indices.dtype() == torch::kInt32); + + int s_q = q.size(0); + int s_kv = kv.size(0); + int h_q = q.size(1); + int h_kv = kv.size(1); + int d_qk = q.size(2); + int topk = indices.size(2); + + CHECK_SHAPE(q, s_q, h_q, d_qk); + CHECK_SHAPE(kv, s_kv, h_kv, d_qk); + CHECK_SHAPE(indices, s_q, h_kv, topk); + + TORCH_CHECK(q.stride(-1) == 1); + TORCH_CHECK(kv.stride(-1) == 1); + TORCH_CHECK(indices.stride(-1) == 1); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + auto opts = q.options(); + at::Tensor out = torch::empty({s_q, h_q, d_v}, opts); + CHECK_CONTIGUOUS(out); + + at::Tensor buf_attn_score, max_logits, lse, p_sum; + max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); + lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); + CHECK_CONTIGUOUS(max_logits); + CHECK_CONTIGUOUS(lse); + + SparsePrefillParams params = { + s_q, s_kv, h_q, h_kv, d_qk, d_v, topk, + sm_scale, sm_scale * 1.44269504f, + + (cutlass::bfloat16_t*)q.data_ptr(), + (cutlass::bfloat16_t*)kv.data_ptr(), + (int*)indices.data_ptr(), + + int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), + int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), + int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), + + (cutlass::bfloat16_t*)out.data_ptr(), + (float*)max_logits.data_ptr(), + (float*)lse.data_ptr(), + + at::cuda::getCurrentCUDAStream().stream() + }; + + if (is_sm90) { + sm90::run_fwd_kernel(params); + } else { + TORCH_CHECK(false, "Unknown architecture"); + } + + return {out, max_logits, lse}; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashMLA"; + m.def("get_mla_decoding_metadata", &get_mla_decoding_metadata); + m.def("fwd_kvcache_mla", &fwd_kvcache_mla); + m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun); + m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun); + m.def("sparse_prefill_fwd", &sparse_prefill_fwd); +} diff --git a/csrc/sm100/collective/fmha_common.hpp b/csrc/sm100/prefill/dense/collective/fmha_common.hpp similarity index 100% rename from csrc/sm100/collective/fmha_common.hpp rename to csrc/sm100/prefill/dense/collective/fmha_common.hpp diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/prefill/dense/collective/fmha_fusion.hpp similarity index 100% rename from csrc/sm100/collective/fmha_fusion.hpp rename to csrc/sm100/prefill/dense/collective/fmha_fusion.hpp diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp similarity index 100% rename from csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index f39fd75..4783a13 100644 --- a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -37,9 +37,9 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" -#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/sm100_fmha_load_tma_warpspecialized.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp index 3606dcc..987ac22 100644 --- a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -36,8 +36,8 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp index bf41af9..1e66d1a 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -37,10 +37,10 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" -#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" -#include "common/pipeline_mla.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "../common/pipeline_mla.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c8fc13b..d161a99 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -36,8 +36,8 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/common/gather_tensor.hpp b/csrc/sm100/prefill/dense/common/gather_tensor.hpp similarity index 100% rename from csrc/sm100/common/gather_tensor.hpp rename to csrc/sm100/prefill/dense/common/gather_tensor.hpp diff --git a/csrc/sm100/common/helper.h b/csrc/sm100/prefill/dense/common/helper.h similarity index 100% rename from csrc/sm100/common/helper.h rename to csrc/sm100/prefill/dense/common/helper.h diff --git a/csrc/sm100/common/mask.cuh b/csrc/sm100/prefill/dense/common/mask.cuh similarity index 100% rename from csrc/sm100/common/mask.cuh rename to csrc/sm100/prefill/dense/common/mask.cuh diff --git a/csrc/sm100/common/pipeline_mla.hpp b/csrc/sm100/prefill/dense/common/pipeline_mla.hpp similarity index 100% rename from csrc/sm100/common/pipeline_mla.hpp rename to csrc/sm100/prefill/dense/common/pipeline_mla.hpp diff --git a/csrc/sm100/common/pow_2.hpp b/csrc/sm100/prefill/dense/common/pow_2.hpp similarity index 100% rename from csrc/sm100/common/pow_2.hpp rename to csrc/sm100/prefill/dense/common/pow_2.hpp diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/prefill/dense/common/utils.hpp similarity index 100% rename from csrc/sm100/common/utils.hpp rename to csrc/sm100/prefill/dense/common/utils.hpp diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/prefill/dense/device/fmha.hpp similarity index 100% rename from csrc/sm100/device/fmha.hpp rename to csrc/sm100/prefill/dense/device/fmha.hpp diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp similarity index 100% rename from csrc/sm100/device/fmha_device_bwd.hpp rename to csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cu b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu similarity index 98% rename from csrc/sm100/fmha_cutlass_bwd_sm100.cu rename to csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu index 4ff745d..54d85db 100644 --- a/csrc/sm100/fmha_cutlass_bwd_sm100.cu +++ b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu @@ -1,7 +1,7 @@ -#include +#include "interface.h" + #include #include -#include #include #include "common/mask.cuh" #include "common/utils.hpp" diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh similarity index 100% rename from csrc/sm100/fmha_cutlass_bwd_sm100.cuh rename to csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu similarity index 98% rename from csrc/sm100/fmha_cutlass_fwd_sm100.cu rename to csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu index 997886e..ab66f0f 100644 --- a/csrc/sm100/fmha_cutlass_fwd_sm100.cu +++ b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu @@ -1,12 +1,13 @@ -#include "common/mask.cuh" -#include "common/utils.hpp" -#include "fmha_cutlass_fwd_sm100.cuh" +#include "interface.h" -#include #include #include #include -#include + +#include "common/mask.cuh" +#include "common/utils.hpp" + +#include "fmha_cutlass_fwd_sm100.cuh" template void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh similarity index 100% rename from csrc/sm100/fmha_cutlass_fwd_sm100.cuh rename to csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh diff --git a/csrc/sm100/pybind.cu b/csrc/sm100/prefill/dense/interface.h similarity index 84% rename from csrc/sm100/pybind.cu rename to csrc/sm100/prefill/dense/interface.h index 7d4744d..80ef2bc 100644 --- a/csrc/sm100/pybind.cu +++ b/csrc/sm100/prefill/dense/interface.h @@ -1,4 +1,6 @@ -#include +#pragma once + +#include void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, @@ -10,8 +12,3 @@ void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Ten at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fwd", &FMHACutlassSM100FwdRun); - m.def("bwd", &FMHACutlassSM100BwdRun); -} diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp similarity index 100% rename from csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp similarity index 97% rename from csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp index 32e007c..9a25ff3 100644 --- a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp +++ b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp @@ -34,6 +34,7 @@ #include "cutlass/cutlass.h" #include "cute/layout.hpp" +#include "utils.h" // for IS_SM100 namespace cutlass::fmha::kernel { @@ -138,6 +139,7 @@ struct FmhaKernelBwdConvert { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if IS_SM100 if (params.ptr_src_dQ != nullptr) { copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); } @@ -147,6 +149,11 @@ struct FmhaKernelBwdConvert { if (params.ptr_src_dV != nullptr) { copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } }; diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp similarity index 97% rename from csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp index db6a9b4..07ae4f2 100644 --- a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -34,6 +34,7 @@ #include "cutlass/cutlass.h" #include "cute/layout.hpp" +#include "utils.h" // for IS_SM100 namespace cutlass::fmha::kernel { @@ -104,6 +105,7 @@ struct FmhaKernelBwdSumOdO { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if IS_SM100 auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); @@ -155,6 +157,11 @@ struct FmhaKernelBwdSumOdO { } } } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } }; diff --git a/csrc/sm100/kernel/fmha_options.hpp b/csrc/sm100/prefill/dense/kernel/fmha_options.hpp similarity index 100% rename from csrc/sm100/kernel/fmha_options.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_options.hpp diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp similarity index 100% rename from csrc/sm100/kernel/fmha_tile_scheduler.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 59b410b..057b45e 100644 --- a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -41,7 +41,8 @@ #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "collective/fmha_common.hpp" +#include "utils.h" // for IS_SM100 +#include "../collective/fmha_common.hpp" #include @@ -1499,6 +1500,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if IS_SM100 int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); @@ -1823,6 +1825,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { /* no-op */ } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } static dim3 get_block_shape() { diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index 5a58157..0d4af85 100644 --- a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -41,7 +41,8 @@ #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "collective/fmha_common.hpp" +#include "utils.h" // for IS_SM100 +#include "../collective/fmha_common.hpp" #include @@ -1492,6 +1493,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if IS_SM100 int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); @@ -1816,6 +1818,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { /* no-op */ } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } static dim3 get_block_shape() { diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp similarity index 98% rename from csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index 43bb035..ef75280 100644 --- a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -37,11 +37,12 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cute/arch/tmem_allocator_sm100.hpp" -#include "kernel/fmha_options.hpp" -#include "kernel/fmha_tile_scheduler.hpp" -#include "kernel/fmha_causal_tile_scheduler.hpp" -#include "collective/fmha_fusion.hpp" -#include "collective/fmha_common.hpp" +#include "utils.h" // for IS_SM100 +#include "../kernel/fmha_options.hpp" +#include "../kernel/fmha_tile_scheduler.hpp" +#include "../kernel/fmha_causal_tile_scheduler.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/fmha_common.hpp" namespace cutlass::fmha::kernel { @@ -251,6 +252,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if IS_SM100 TileScheduler tile_scheduler{params.tile_scheduler}; @@ -629,6 +631,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { /* no-op, donate regs and exit */ } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } }; diff --git a/csrc/sm90/kernels/config.h b/csrc/sm90/decode/dense/config.h similarity index 78% rename from csrc/sm90/kernels/config.h rename to csrc/sm90/decode/dense/config.h index c9ce159..e97e0bc 100644 --- a/csrc/sm90/kernels/config.h +++ b/csrc/sm90/decode/dense/config.h @@ -8,6 +8,4 @@ static constexpr int PAGE_BLOCK_SIZE = 64; static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_V = 512; -static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5; - } diff --git a/csrc/sm90/kernels/splitkv_mla.cu b/csrc/sm90/decode/dense/splitkv_mla.cu similarity index 97% rename from csrc/sm90/kernels/splitkv_mla.cu rename to csrc/sm90/decode/dense/splitkv_mla.cu index 5e1fded..cb2e476 100644 --- a/csrc/sm90/kernels/splitkv_mla.cu +++ b/csrc/sm90/decode/dense/splitkv_mla.cu @@ -1,20 +1,22 @@ #include -#include "params.h" #include "utils.h" + +#include "params.h" #include "config.h" #include "traits.h" using namespace cute; using cutlass::arch::NamedBarrier; +namespace sm90 { + // Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking // The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) // so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM static constexpr float MAX_INIT_VAL_SM = -1e30f; static constexpr float MAX_INIT_VAL = -1e33f; - __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a @@ -756,7 +758,7 @@ __forceinline__ __device__ void wg0_subroutine( TMABarrier barriers_K1[9], bool &cur_phase_K0, const TMAParams &tma_params, - const Flash_fwd_mla_params ¶ms, + const DecodingParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, @@ -868,7 +870,7 @@ __forceinline__ __device__ void wg1_subroutine( TMABarrier barriers_K1[9], bool &cur_phase_K1, const TMAParams &tma_params, - const Flash_fwd_mla_params ¶ms, + const DecodingParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, @@ -943,7 +945,7 @@ __forceinline__ __device__ void wg1_subroutine( } // A helper function for determining the length of the causal mask for one q token -__forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params ¶ms, int m_block_idx, int local_seq_q_idx) { +__forceinline__ __device__ int get_mask_len(const DecodingParams ¶ms, int m_block_idx, int local_seq_q_idx) { int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; if (global_seq_q_idx < params.q_seq_per_hk) { int s_q_idx = global_seq_q_idx / params.q_head_per_hk; @@ -956,7 +958,7 @@ __forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params ¶ms, template __global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) { +flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { // grid shape: [ // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), // num_kv_heads, @@ -966,6 +968,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). +#if IS_SM90 const int m_block_idx = blockIdx.x; const int k_head_idx = blockIdx.y; const int partition_idx = blockIdx.z; @@ -1018,11 +1021,11 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. - int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); + int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; - int begin_seqlen = tile_scheduler_metadata.y; + int sched_begin_block_idx = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; - int end_seqlen = tile_scheduler_metadata.w; + int sched_end_block_idx = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); @@ -1034,9 +1037,9 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params constexpr int kBlockN = T::PAGE_BLOCK_SIZE; const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); - const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0; - int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); - const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN); + const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; + int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(seqlen_k, kBlockN); + const bool is_no_split = __ldg(params.num_splits_ptr + batch_idx + 1) - __ldg(params.num_splits_ptr + batch_idx) == 1; int rRightBorderForQSeq[2]; if (params.is_causal) { @@ -1057,7 +1060,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); - end_block_idx = batch_idx == end_idx ? cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + end_block_idx = batch_idx == end_idx ? min(sched_end_block_idx, last_block_in_seq) : last_block_in_seq; CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { @@ -1267,11 +1271,16 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (batch_idx != end_idx) __syncthreads(); } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif } template -void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { +void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream) { using T = Traits; auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); auto tma_Q = cute::make_tma_copy( @@ -1347,8 +1356,10 @@ void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t str CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); #ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); #endif + +} diff --git a/csrc/sm90/decode/dense/splitkv_mla.h b/csrc/sm90/decode/dense/splitkv_mla.h new file mode 100644 index 0000000..6d45cfa --- /dev/null +++ b/csrc/sm90/decode/dense/splitkv_mla.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm90 { + +template +void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); + +} diff --git a/csrc/sm90/kernels/traits.h b/csrc/sm90/decode/dense/traits.h similarity index 100% rename from csrc/sm90/kernels/traits.h rename to csrc/sm90/decode/dense/traits.h diff --git a/csrc/sm90/decode/sparse_fp8/components/config.h b/csrc/sm90/decode/sparse_fp8/components/config.h new file mode 100644 index 0000000..bdba0b8 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/config.h @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include + +using bf16 = cutlass::bfloat16_t; +using fp8 = cutlass::float_e4m3_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; + +using namespace cute; + +static constexpr int NUM_THREADS = 128*3; +static constexpr int BLOCK_M = 64; +static constexpr int TOPK_BLOCK_SIZE = 64; +static constexpr int PAGE_BLOCK_SIZE = 64; +static constexpr int QUANT_TILE_SIZE = 128; + +static constexpr int HEAD_DIM_K = 576; +static constexpr int HEAD_DIM_V = 512; +static constexpr int HEAD_DIM_NOPE = HEAD_DIM_V; +static constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V; +static constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE; +static constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16); + +static constexpr int NUM_K_BUFS = 2; + +using SmemLayoutQTile = decltype(tile_to_shape( + GMMA::Layout_SW128_Atom{}, + Shape, Int<64>>{} +)); + +template +using SmemLayoutQTiles = decltype(tile_to_shape( + SmemLayoutQTile{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +)); + +using SmemLayoutQ = SmemLayoutQTiles<9>; + +using SmemLayoutKTile = decltype(tile_to_shape( + GMMA::Layout_INTER_Atom{}, + Shape, _64>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTiles = decltype(tile_to_shape( + SmemLayoutKTile{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout, Int>, Stride, _1>>{} +)); + +using SmemLayoutOBuf = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +)); + +using SmemLayoutOAccumBuf = Layout< + Shape, Int>, + Stride, _1> // We use stride = 520 here to avoid bank conflict +>; + +using SmemLayoutK = SmemLayoutKTiles<9>; +using SmemLayoutV = SmemLayoutKTilesTransposed<8>; +using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>; + +using SmemLayoutS = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +)); + +struct SharedMemoryPlan { + array_aligned> q; + union { + array_aligned> k[NUM_K_BUFS]; + array_aligned> oBuf; + array_aligned> oAccumBuf; + } u; + array_aligned> s; + bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; + + float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M]; + transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; +}; + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_SS{}, + Layout>{} +)); + +using TiledMMA_QK_rQ = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_SS{}, + Layout>{} +)); diff --git a/csrc/sm90/decode/sparse_fp8/components/dequant.h b/csrc/sm90/decode/sparse_fp8/components/dequant.h new file mode 100644 index 0000000..c3efc05 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/dequant.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include + +struct fp8x8 { + __nv_fp8x4_e4m3 lo; + __nv_fp8x4_e4m3 hi; +}; + +struct fp8x16 { + fp8x8 lo; + fp8x8 hi; +}; + +struct bf16x8 { + __nv_bfloat162 a, b, c, d; +}; + +__device__ __forceinline__ +bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { + __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale); + + #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ + { \ + float4 fp32x4 = (float4)(FP8x4); \ + OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \ + OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \ + } + + bf16x8 result; + DEQUANT_FP8x4(result.a, result.b, inputs.lo); + DEQUANT_FP8x4(result.c, result.d, inputs.hi); + + return result; +} + +enum class L1CacheHint { + NO_ALLOCATE, + EVICT_FIRST, + EVICT_NORMAL, + EVICT_LAST +}; + +enum class L2PrefetchHint { + B64, + B128, + B256 +}; + +template< + typename T, + L1CacheHint l1_cache_hint, + L2PrefetchHint l2_prefetch_hint +> +__device__ __forceinline__ +T load_128b_from_gmem(const void* addr) { + static_assert(sizeof(T) == 128/8); + int4 ret; + + #define EXEC(L1_HINT_STR, L2_HINT_STR) { \ + asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \ + : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \ + : "l"(addr)); \ + } + + #define DISPATCH_L2(L1_HINT_STR) { \ + if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \ + EXEC(L1_HINT_STR, "64B") \ + else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \ + EXEC(L1_HINT_STR, "128B") \ + else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \ + EXEC(L1_HINT_STR, "256B") \ + } + + if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE) + DISPATCH_L2("no_allocate") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST) + DISPATCH_L2("evict_first") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL) + DISPATCH_L2("evict_normal") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST) + DISPATCH_L2("evict_last") + + #undef EXEC + #undef DISPATCH_L2 + return *reinterpret_cast(&ret); +} diff --git a/csrc/sm90/decode/sparse_fp8/components/epilogue.h b/csrc/sm90/decode/sparse_fp8/components/epilogue.h new file mode 100644 index 0000000..038cbfd --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/epilogue.h @@ -0,0 +1,87 @@ +#pragma once + +#include "named_barriers.h" + +// Store O / OAccum +template< + bool IS_NO_SPLIT, + typename TMAParams, + typename Tensor0, + typename Tensor1, + typename Tensor2, + typename Tensor3 +> +__forceinline__ __device__ void store_o( + Tensor0 &rO, // ((2, 2, 32), 1, 1) + Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) + Tensor2 &sOutputBuf, + Tensor3 &sOutputAccumBuf, + float rL[2], + TMAParams &tma_params, + int batch_idx, + int s_q_idx, + int head_block_idx, + int num_valid_seq_q, + int warpgroup_idx, + int idx_in_warpgroup +) { + using cutlass::arch::NamedBarrier; + if constexpr (IS_NO_SPLIT) { + // Should convert the output to bfloat16 / float16, and save it to O + Tensor rOb = make_tensor_like(rO); + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); ++idx) { + rOb(idx) = (bf16)(rO(idx) / rL[idx%4 >= 2]); + } + + Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); + TiledCopy r2s_tiled_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_PV_LocalP{} + ); + ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); + Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); + Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); + cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); + cutlass::arch::fence_view_async_shared(); + + NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); + + if (threadIdx.x == 0) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, head_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sOutputBuf), + thr_tma.partition_D(my_tma_gO) + ); + cute::tma_store_arrive(); + } + } else { + // Should save the result to OAccum + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); idx += 2) { + int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0); + int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; + *(float2*)(&(sOutputAccumBuf(row, col))) = float2 { + rO(idx) / rL[idx%4 >= 2], + rO(idx+1) / rL[idx%4 >= 2], + }; + } + cutlass::arch::fence_view_async_shared(); + + NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); + + if (elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) { + int row = local_row * (256/32) + (threadIdx.x / 32); + if (row < num_valid_seq_q) { + SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float)); + } + } + cute::tma_store_arrive(); + } + } +} diff --git a/csrc/sm90/decode/sparse_fp8/components/helpers.h b/csrc/sm90/decode/sparse_fp8/components/helpers.h new file mode 100644 index 0000000..8a336ea --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/helpers.h @@ -0,0 +1,86 @@ +#pragma once + +// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx +// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a +__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + const Tensor0 &src, + Tensor1 &dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL, + const uint16_t &multicast_mask = 0 +) { + auto thr_tma = tma_copy.get_slice(_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), multicast_mask, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +template +CUTE_DEVICE +static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} diff --git a/csrc/sm90/decode/sparse_fp8/components/named_barriers.h b/csrc/sm90/decode/sparse_fp8/components/named_barriers.h new file mode 100644 index 0000000..b91cb22 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/named_barriers.h @@ -0,0 +1,10 @@ +#pragma once + +enum NamedBarriers : uint32_t { + sScale_and_sS_ready = 0, + sScale_and_sS_free = 1, + oBuf_free_and_sL_ready = 2, + epilogue_r2s_ready = 3, + batch_loop_sync = 4, + warpgroup0_sync = 5 +}; diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu new file mode 100644 index 0000000..3283413 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu @@ -0,0 +1,614 @@ +#include "splitkv_mla.h" + +#include +#include +#include +#include + +#include "utils.h" +#include "components/config.h" +#include "components/epilogue.h" +#include "components/helpers.h" +#include "components/named_barriers.h" +#include "components/dequant.h" +using namespace cute; + +namespace sm90 { + +static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::NamedBarrier; + +// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +__forceinline__ __device__ void save_rPb_to_sP( + Tensor0 const &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_QK{} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +// Retrieve rPb (64x64, bfloat16) from sP using the ldmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +__forceinline__ __device__ void retrieve_rP_from_sP( + Tensor0 &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + TiledCopy s2r_copy = make_tiled_copy_A( + Copy_Atom{}, + TiledMMA_PV_LocalP{} + ); + ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_sP = thr_copy.partition_S(sP); + Tensor thr_copy_rPb = thr_copy.retile_D(rPb); + cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); +} + + +template< + typename Tensor0, + typename Tensor1, + typename Tensor2 +> +__forceinline__ __device__ void scale_softmax( + Tensor0 &rP, + Tensor1 &rS, + Tensor2 &rO, + float scale_softmax_log2, + float sScale[], + float rM[2], + float rL[2], + bool is_kv_valid[], + int block_idx, + int idx_in_warpgroup +) { + float scale_for_olds[2]; + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _)); + Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _)); + Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); + + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = 0; i < size(cur_rP); ++i) { + if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2]) + cur_rP(i) = -INFINITY; + cur_max = max(cur_max, cur_rP(i)); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + + cur_max *= scale_softmax_log2; + float old_max = rM[local_row_idx]; + rM[local_row_idx] = max(cur_max, old_max); + float scale_for_old = exp2f(old_max - rM[local_row_idx]); + scale_for_olds[local_row_idx] = scale_for_old; + + CUTE_UNROLL + for (int i = 0; i < size(cur_rO); ++i) { + cur_rO(i) *= scale_for_old; + } + + float cur_sum = 0; + CUTE_UNROLL + for (int i = 0; i < size(cur_rP); ++i) { + cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]); + cur_rS(i) = (bf16)cur_rP(i); + cur_sum += cur_rP(i); + } + rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; + } + if (idx_in_warpgroup%4 == 0) + *(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds); +} + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 2) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { +#if IS_SM90 + const int head_block_idx = blockIdx.x; + const int s_q_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int idx_in_cluster = head_block_idx % 2; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup = threadIdx.x % 128; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); + Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{}); + Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{}); + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + float* sM = plan.sM; + float* sL = plan.sL; + float* sScale = plan.sScale; + + // Prefetch TMA descriptors + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + // Initialize TMA barriers + if (warp_idx == 0 && elect_one_sync()) { + plan.bar_q.init(1); + CUTE_UNROLL + for (int i = 0; i < NUM_K_BUFS; ++i) { + plan.bar_k_local_ready[i].init(128); + plan.bar_k_remote_ready[i].init(1); + plan.bar_k_avail[i].init(4); + } + fence_view_async_shared(); + } + cute::cluster_arrive(); + + bool bar_phase_q = 0; + int bar_phase_k = 0; // Don't use array here to prevent using local memory + + // Programmatic Dependent Launch: Wait for the previous kernel to finish + // Don't use PDL because of compiler bugs! + // cudaGridDependencySynchronize(); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int sched_begin_block_idx = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int sched_end_block_idx = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + if (warp_idx == 0 && elect_one_sync()) { + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, begin_idx), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); + } + + cute::cluster_wait(); // Wait for barriers from the other CTA to be ready + + auto get_cur_req_info = [&](int batch_idx) -> std::tuple { + constexpr int kBlockN = TOPK_BLOCK_SIZE; + const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; + // NOTE TopK attention has nothing to do with causal mask and sliding window + int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(params.topk, kBlockN); + const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(params.topk, kBlockN); + return {start_block_idx, end_block_idx, is_no_split}; + }; + + if (warpgroup_idx == 0) { + cutlass::arch::warpgroup_reg_alloc<192>(); + + TiledMMA tiled_mma_QK = TiledMMA_QK{}; + ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup); + TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{}; + ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); + + float rL[2], rM[2]; + Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); + Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); + Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + + rL[0] = rL[1] = 0.0f; + rM[0] = rM[1] = MAX_INIT_VAL; + cute::fill(rO, 0.); + + // Wait for Q + plan.bar_q.wait(bar_phase_q); + bar_phase_q ^= 1; + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{}); + + // Wait, issue WGMMA + plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + + gemm( + tiled_mma_QK, + thr_mma_QK.partition_fragment_A(sQ), + thr_mma_QK.partition_fragment_B(sK), + rP + ); + + bar_phase_k ^= 1<(); + + // Calculate S = softmax(mask(scale(P))) + if (block_idx != start_block_idx) + NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free + + // Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking for the last 2 blocks + scale_softmax(rP, rS, rO, params.scale_softmax_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup); + + // Store S into shared, inform warpgroup 1 + save_rPb_to_sP(rS, sS, idx_in_warpgroup); + fence_view_async_shared(); + + // Issue O += S @ V + gemm( + tiled_mma_PV, + rS, + thr_mma_PV.partition_fragment_B(sV), + rO + ); + + NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready); + + cute::warpgroup_wait<0>(); + + plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); + plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); + } + + // Copy the next q + if (warp_idx == 0 && elect_one_sync()) { + if (batch_idx != end_idx) { + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); + } else { + cudaTriggerProgrammaticLaunchCompletion(); + } + } + + // Synchronize L and M across warpgroups + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + if (idx_in_warpgroup%4 == 0) { + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + sL[row] = rL[i]; + sM[row] = rM[i]; + } + } + + // This is a synchronization point for warpgroup 0/1. + // Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free + // Warpgroup 1 should wait wg 0 for sL to be ready + NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); + + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; + + int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M); + int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*BLOCK_M; + if (is_no_split) { + bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1) + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1) + + store_o(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL[i]; + gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E; + } + + cute::tma_store_wait<0>(); + } else { + int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + store_o(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL[i]; + gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i]; + } + + cute::tma_store_wait<0>(); + } + + cute::cluster_sync(); // Must use arrive_and_wait here to prevent overwritting sL while WG1 is writing back its result + } + } else if (warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_dealloc<160>(); + + TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{}; + ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); + Tensor rO = partition_fragment_C(tiled_mma_PV, Shape, Int>{}); + float rL[2]; + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + cute::fill(rO, 0.); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{}); + + // Wait for S and sScale + NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready); + + // Scale O + float cur_scales[2]; + *(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2); + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); + CUTE_UNROLL + for (int i = 0; i < size(cur_rO); ++i) { + cur_rO(i) *= cur_scales[local_row_idx]; + } + } + + // Issue O += S @ V, and wait + gemm( + tiled_mma_PV, + thr_mma_PV.partition_fragment_A(sS), + thr_mma_PV.partition_fragment_B(sV), + rO + ); + cute::warpgroup_wait<0>(); + + plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); + plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); + + if (block_idx != end_block_idx-1) + NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available + } + + NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + rL[i] = sL[row]; + } + + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; + + int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M); + int start_seq_idx = s_q_idx*params.q_head_per_hk+head_block_idx*BLOCK_M; + if (is_no_split) { + bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1) + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + + store_o(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + cute::tma_store_wait<0>(); + } else { + int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + store_o(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + cute::tma_store_wait<0>(); + } + + cute::cluster_sync(); // We must use arrive_and_wait instead of arrive here to create an order between "forall warp in WG1, warp has done written back O" and "warp 2 signals `bar_k_avail`" + } + } else { + // Producer warpgroup + cutlass::arch::warpgroup_reg_dealloc<152>(); + + int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); // NOTE TPBNO + int lane_idx = idx_in_warpgroup % 32; + int my_token_idx = warp_idx*8 + lane_idx%8; + + CUTE_NO_UNROLL + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1) + + #define GET_TOKEN_INDEX(block_idx) __ldg(gIndices + (block_idx)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx) + int nxt_token_index = GET_TOKEN_INDEX(start_block_idx); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; + + // Define shared and global tensors + bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE; + bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base); + + transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx])); + int token_index = nxt_token_index; + if (block_idx+1 != end_block_idx) + nxt_token_index = GET_TOKEN_INDEX(block_idx+1); + int block_index = token_index/PAGE_BLOCK_SIZE; + int rel_idx_in_block = (token_index+PAGE_BLOCK_SIZE) % PAGE_BLOCK_SIZE; // NOTE When token_index is -1, -1/PAGE_BLOCK_SIZE = 0 and (-1+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE = 63, so there will be no illegal-memory-access error + fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride; + float4 scales = load_128b_from_gmem((float*)(gK_base+HEAD_DIM_NOPE)); + + // Wait for the nope buffer to be available + plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1); + bar_phase_k ^= 1 << buf_idx; + + // Copy block #block_index + if (idx_in_warpgroup == 0) { + plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16)); + } + + // Collectively copy from global memory and dequant + // For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py + + fp8* gK_nope = gK_base + (lane_idx/8)*16; + if (token_index == -1) { + scales = {0.0f, 0.0f, 0.0f, 0.0f}; + } + CUTE_UNROLL + for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) { + fp8x16 cur_fp8x16 = load_128b_from_gmem(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B + float scale = dim_idx < 4 ? (dim_idx < 2 ? scales.x : scales.y) : (dim_idx < 6 ? scales.z : scales.w); + auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) { + int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE; + bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, scale); + *(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; + st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); + }; + if (token_index == -1) + *(uint128_t*)(&cur_fp8x16) = uint128_t(); + dequant_and_save_bf16x8(cur_fp8x16.lo, 0); + dequant_and_save_bf16x8(cur_fp8x16.hi, 8); + } + + bf16* gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8; + bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE; + bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base); + + CUTE_UNROLL + for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) { + bf16x8 cur_bf16x8 = load_128b_from_gmem(gK_rope + dim_idx*32); + if (token_index == -1) + *(uint128_t*)(&cur_bf16x8) = uint128_t(); + int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE; + *(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; + st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); + } + + fence_view_async_shared(); + + if (idx_in_warpgroup < 32) { + // We put this after fence_view_async_shared() since this won't be read by async proxy + int2 indices = __ldg((int2*)(gIndices + block_idx*TOPK_BLOCK_SIZE + lane_idx*2)); + *(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {indices.x != -1, indices.y != -1}; + } + + // Signal the barrier + plan.bar_k_local_ready[buf_idx].arrive(); + } + + cute::cluster_sync(); + } + } + + if (begin_idx > end_idx) { + cute::cluster_sync(); // Don't need a cluster_sync() when begin_idx <= end_idx, since the loop will execute at least once and the final statement is cluster_sync() + } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif + +} + + +void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.h_k == 1); + FLASH_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0); + + auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q_ptr), + make_layout( + shape_Q, + make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride) + ) + ), + SmemLayoutQ{} + ); + + auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((bf16*)params.o_ptr), + make_layout( + shape_O, + make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride) + ) + ), + SmemLayoutOBuf{} + ); + + TmaParams< + decltype(shape_Q), decltype(tma_Q), + decltype(shape_O), decltype(tma_O) + > tma_params = { + shape_Q, tma_Q, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + const int num_m_block = cute::ceil_div(params.q_head_per_hk, 2*BLOCK_M) * 2; + // NOTE Don't use PDL because of potential compiler bugs! + // cudaLaunchAttribute mla_kernel_attributes[1]; + // mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + // mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; + // cudaLaunchConfig_t mla_kernel_config = { + // dim3(num_m_block, params.h_k, params.num_sm_parts), + // dim3(NUM_THREADS, 1, 1), + // smem_size, + // stream, + // mla_kernel_attributes, + // 1 + // }; + // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); + cutlass::ClusterLaunchParams launch_params = { + dim3(num_m_block, params.s_q, params.num_sm_parts), + dim3(NUM_THREADS, 1, 1), + dim3(2, 1, 1), + smem_size, + stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)mla_kernel, params, tma_params + ); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.h b/csrc/sm90/decode/sparse_fp8/splitkv_mla.h new file mode 100644 index 0000000..daa21a3 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace sm90 { + +void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream); + +} diff --git a/csrc/sm90/flash_api.cpp b/csrc/sm90/flash_api.cpp deleted file mode 100644 index a87e1ab..0000000 --- a/csrc/sm90/flash_api.cpp +++ /dev/null @@ -1,216 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#include -#include -#include -#include - -#include - -#include "kernels/config.h" -#include "kernels/get_mla_metadata.h" -#include "kernels/mla_combine.h" -#include "kernels/params.h" -#include "kernels/splitkv_mla.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -std::vector -get_mla_metadata( - at::Tensor &seqlens_k, - const int num_heads_per_head_k, - const int num_heads_k -) { - CHECK_DEVICE(seqlens_k); - TORCH_CHECK(seqlens_k.is_contiguous()); - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); - - int batch_size = seqlens_k.size(0); - int *seqlens_k_ptr = seqlens_k.data_ptr(); - auto options = seqlens_k.options(); - - auto dprops = at::cuda::getCurrentDeviceProperties(); - int sm_count = dprops->multiProcessorCount; - int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M); - - auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); - auto num_splits = torch::empty({batch_size + 1}, options); - int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); - int *num_splits_ptr = num_splits.data_ptr(); - - at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - Mla_metadata_params params = {}; - params.seqlens_k_ptr = seqlens_k_ptr; - params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; - params.num_splits_ptr = num_splits_ptr; - params.batch_size = batch_size; - params.block_size_n = Config::PAGE_BLOCK_SIZE; - params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS; - params.num_sm_parts = num_sm_parts; - run_get_mla_metadata_kernel(params, stream); - - return {tile_scheduler_metadata, num_splits}; -} - -std::vector -mha_fwd_kvcache_mla( - at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size - const int head_size_v, - const at::Tensor &seqlens_k, // batch_size - const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq - const float softmax_scale, - bool is_causal, - const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize - const at::Tensor &num_splits // batch_size + 1 -) { - // Check the architecture - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90); - - // Check data types - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); - TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); - TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - - // Check device - CHECK_DEVICE(q); - CHECK_DEVICE(kcache); - CHECK_DEVICE(seqlens_k); - CHECK_DEVICE(block_table); - CHECK_DEVICE(tile_scheduler_metadata); - CHECK_DEVICE(num_splits); - - // Check layout - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(seqlens_k); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); - CHECK_CONTIGUOUS(tile_scheduler_metadata); - CHECK_CONTIGUOUS(num_splits); - - const auto sizes = q.sizes(); - const int batch_size = sizes[0]; - const int seqlen_q_ori = sizes[1]; - const int num_heads_q = sizes[2]; - const int head_size_k = sizes[3]; - TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); - TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); - - const int max_num_blocks_per_seq = block_table.size(1); - const int num_blocks = kcache.size(0); - const int page_block_size = kcache.size(1); - const int num_heads_k = kcache.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); - TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (seqlen_q_ori == 1) { is_causal = false; } - - const int num_q_heads_per_hk = num_heads_q / num_heads_k; - const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; - const int num_heads = num_heads_k; - q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) - .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); - - CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); - CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); - CHECK_SHAPE(seqlens_k, batch_size); - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); - CHECK_SHAPE(num_splits, batch_size+1); - - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); - at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); - CHECK_CONTIGUOUS(softmax_lse); - - Flash_fwd_mla_params params = {}; - // Set the sizes. - params.b = batch_size; - params.s_q = seqlen_q_ori; - params.q_seq_per_hk = q_seq_per_hk; - params.seqlens_k_ptr = seqlens_k.data_ptr(); - params.h_q = num_heads_q; - params.h_k = num_heads_k; - params.num_blocks = num_blocks; - params.q_head_per_hk = num_q_heads_per_hk; - params.is_causal = is_causal; - params.d = head_size_k; - params.d_v = head_size_v; - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = kcache.data_ptr(); - params.o_ptr = out.data_ptr(); - params.softmax_lse_ptr = softmax_lse.data_ptr(); - // All stride are in elements, not bytes. - params.q_batch_stride = q.stride(0); - params.k_batch_stride = kcache.stride(0); - params.o_batch_stride = out.stride(0); - params.q_row_stride = q.stride(-3); - params.k_row_stride = kcache.stride(-3); - params.o_row_stride = out.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = kcache.stride(-2); - params.o_head_stride = out.stride(-2); - - params.block_table = block_table.data_ptr(); - params.block_table_batch_stride = block_table.stride(0); - params.page_block_size = page_block_size; - - params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); - params.num_sm_parts = tile_scheduler_metadata.size(0); - params.num_splits_ptr = num_splits.data_ptr(); - - const int total_num_splits = batch_size + params.num_sm_parts; - at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); - at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); - CHECK_CONTIGUOUS(softmax_lse_accum); - CHECK_CONTIGUOUS(out_accum); - params.total_num_splits = total_num_splits; - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(head_size_k == 576); - if (q_dtype == torch::kBFloat16) { - run_flash_splitkv_mla_kernel(params, stream); - run_flash_mla_combine_kernel(params, stream); - } else if (q_dtype == torch::kHalf) { -#ifdef FLASH_MLA_DISABLE_FP16 - TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA."); -#else - run_flash_splitkv_mla_kernel(params, stream); - run_flash_mla_combine_kernel(params, stream); -#endif - } else { - TORCH_CHECK(false, "Unsupported tensor dtype for query"); - } - - out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) - .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); - softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) - .reshape({batch_size, num_heads_q, seqlen_q_ori}); - - return {out, softmax_lse}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashMLA"; - m.def("get_mla_metadata", &get_mla_metadata); - m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); -} diff --git a/csrc/sm90/kernels/get_mla_metadata.h b/csrc/sm90/kernels/get_mla_metadata.h deleted file mode 100644 index 5130581..0000000 --- a/csrc/sm90/kernels/get_mla_metadata.h +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "params.h" - -void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/sm90/kernels/mla_combine.h b/csrc/sm90/kernels/mla_combine.h deleted file mode 100644 index 69035e9..0000000 --- a/csrc/sm90/kernels/mla_combine.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include "params.h" - -template -void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/sm90/kernels/splitkv_mla.h b/csrc/sm90/kernels/splitkv_mla.h deleted file mode 100644 index 479fb50..0000000 --- a/csrc/sm90/kernels/splitkv_mla.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include "params.h" - -template -void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/sm90/prefill/sparse/fwd.cu b/csrc/sm90/prefill/sparse/fwd.cu new file mode 100644 index 0000000..084e0e2 --- /dev/null +++ b/csrc/sm90/prefill/sparse/fwd.cu @@ -0,0 +1,709 @@ +#include "fwd.h" + +#include +#include +#include +#include +#include +#include + +#include "utils.h" +#include "helpers.h" + +namespace sm90 { + +using namespace cute; + +constexpr int D_Q = 576; +constexpr int D_K = 576; +constexpr int D_V = 512; + +constexpr int B_H = 64; +constexpr int B_TOPK = 64; // TopK block size +constexpr int NUM_THREADS = 128*3; +static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout, Int>, Stride, _1>>{} +)); + +using SmemLayoutQ = SmemLayoutQTiles<9>; +using SmemLayoutO = SmemLayoutOTiles<8>; +using SmemLayoutK = SmemLayoutKTiles<9>; +using SmemLayoutV = SmemLayoutKTilesTransposed<8>; +using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>; + +using SmemLayoutS = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +), Shape<_1, _1>{})); + +struct SharedMemoryPlan { + union { + array_aligned> q; + array_aligned> o; + } q_o; + array_aligned> k[2]; + array_aligned> s; + + bool is_kv_valid[2][B_TOPK]; + float2 sM[32]; + float2 sL[64]; // For reduction across WG0/1 in epilogue + float final_max_logits[64], final_lse[64]; + transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_SS{}, + Layout>{} +)); + +using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_SS{}, + Layout>{} +)); + +template< + typename Shape_Q, typename TMA_Q +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + CUtensorMap tensor_map_O; +}; + +enum NamedBarriers : uint32_t { + wg0_bunch_0_ready = 0, + wg1_bunch_0_ready = 1, + wg0_s0_ready = 2, + wg1_s1_ready = 3, + sL_ready = 4, + warpgroup0_sync = 5, + warpgroup1_sync = 6 +}; + +// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +__forceinline__ __device__ void save_rS_to_sS( + Tensor0 const &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_QK{} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 1) +sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) { + // NOTE This kernel uses a similar schedule to Flash MLA - 0422. For a detailed explanation, please refer to https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md +#if IS_SM90 + const int q_h_idx = blockIdx.x % (params.h_q/B_H); + const int s_q_idx = blockIdx.x / (params.h_q/B_H); + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int idx_in_warpgroup = threadIdx.x % 128; + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{}); + Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{}); + Tensor sS0 = make_tensor(make_smem_ptr(plan.k[0].data()+64*512), SmemLayoutS{}); // Overlap with sK0's RoPE part + Tensor sS1 = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + + if (warp_idx == 0 && elect_one_sync()) { + // Prefetch TMA descriptors + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_O); + + // Initialize barriers + plan.bar_q.init(1); + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + plan.bar_k0_free[i].init(128); + plan.bar_k0_ready[i].init(128); + plan.bar_k1_free[i].init(128); + plan.bar_k1_ready[i].init(128); + } + plan.bar_is_kv_valid_ready.init(16); + fence_barrier_init(); + } + + __syncthreads(); + + const int num_topk_blocks = params.topk / B_TOPK; + if (warpgroup_idx == 0 || warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_alloc<216>(); + + if (warp_idx == 0 && elect_one_sync()) { + // Load Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), + Tile, Int>{} + )(_, _, q_h_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); + } + + float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation + float rL[2] = {0.0f, 0.0f}; + Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); + Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); + Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); + cute::fill(rO, 0.0f); + + // Wait for Q + plan.bar_q.wait(0); + + bool cur_bar_wait_phase = 0; + + struct Warpgroup0 {}; + struct Warpgroup1 {}; + + auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) { + constexpr bool IS_WG1 = std::is_same_v; + TiledMMA tiled_mma_QK = TiledMMA_QK{}; + Tensor sQ_tile = flat_divide(sQ, Tile, Int<64>>{})(_, _, _0{}, tile_idx); + Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{}); + gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup); + }; + + auto mask_rP = [&](auto warpgroup_idx) { + constexpr bool IS_WG1 = std::is_same_v; + plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); + CUTE_UNROLL + for (int row_idx = 0; row_idx < 2; ++row_idx) { + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + int col = 8*(i/4) + (idx_in_warpgroup%4)*2; + if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY; + if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY; + } + } + }; + + auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) { + plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); + constexpr bool IS_WG1 = std::is_same_v; + const float scale = params.sm_scale_div_log2; + float r_sM[2]; + if constexpr (IS_WG1) { + *(float2*)r_sM = plan.sM[idx_in_warpgroup/4]; + } + float new_maxs[2]; + CUTE_UNROLL + for (int row_idx = 0; row_idx < 2; ++row_idx) { + // Get rowwise max + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + cur_max = max(cur_max, max(rP(i), rP(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + cur_max *= scale; + + // Get new max and scale + // For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round) + new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max); + + // Scale O + float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]); + CUTE_UNROLL + for (int i = row_idx*2; i < size(rO); i += 4) { + rO(i) *= scale_for_o; + rO(i+1) *= scale_for_o; + } + + // Get rS + float cur_sum = 0; + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]); + rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]); + rS(i) = (bf16)rP(i); + rS(i+1) = (bf16)rP(i+1); + cur_sum += rP(i) + rP(i+1); + } + rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum; + } + __syncwarp(); + if (idx_in_warpgroup%4 == 0) { + plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs; + } + rM[0] = new_maxs[0]; + rM[1] = new_maxs[1]; + }; + + auto reduce_L = [&]() { + // Reduce L + // For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131 + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + if (idx_in_warpgroup%4 == 0) + plan.sL[threadIdx.x/4] = *(float2*)(rL); + NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready); + float2 peer_L = plan.sL[(threadIdx.x/4)^32]; + rL[0] += peer_L.x; + rL[1] += peer_L.y; + }; + + auto store_O = [&]() { + float scale_factors[2]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + scale_factors[i] = rL[i] == 0.0f ? 1.0f : 1.0f / rL[i]; + + Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{}); + bf16* stsm_addrs[4]; + int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16); + CUTE_UNROLL + for (int i = 0; i < 64/16; ++i) { + stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i); + } + bool s2g_pred = warp_idx%4 == 0 && elect_one_sync(); + + warpgroup_wait<0>(); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) { + // Convert + constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size + bf16 cur_rOb[NUM_ELEMS_EACH_TILE]; + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) { + cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]); + } + // R -> S + CUTE_UNROLL + for (int i = 0; i < 64/16; ++i) { + SM90_U32x4_STSM_N::copy( + *reinterpret_cast(cur_rOb + i*8 + 0), + *reinterpret_cast(cur_rOb + i*8 + 2), + *reinterpret_cast(cur_rOb + i*8 + 4), + *reinterpret_cast(cur_rOb + i*8 + 6), + *reinterpret_cast(stsm_addrs[i] + tile_idx*(B_H*64)) + ); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, warpgroup_idx ? NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync); + // S -> G + if (s2g_pred) { + int g_tile_idx = warpgroup_idx*4 + tile_idx; + SM90_TMA_STORE_3D::copy( + &tma_params.tensor_map_O, + plan.q_o.o.data() + g_tile_idx*(B_H*64), + g_tile_idx*64, + q_h_idx*B_H, + s_q_idx + ); + } + } + cute::tma_store_arrive(); + }; + + + if (warpgroup_idx == 0) { + // Warpgroup 0 + + auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) { + plan.bar_k0_ready[0].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup0{}, 0, true); + qkt_gemm_one_tile(Warpgroup0{}, 1, false); + qkt_gemm_one_tile(Warpgroup0{}, 2, false); + qkt_gemm_one_tile(Warpgroup0{}, 3, false); + warpgroup_commit_batch(); + }; + + auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) { + plan.bar_k0_ready[1].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup0{}, 4, false); + qkt_gemm_one_tile(Warpgroup0{}, 5, false); + qkt_gemm_one_tile(Warpgroup0{}, 6, false); + qkt_gemm_one_tile(Warpgroup0{}, 7, false); + qkt_gemm_one_tile(Warpgroup0{}, 8, false); + warpgroup_commit_batch(); + }; + + auto scale_rS = [&](float scales[2]) { + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rP); i += 4) { + rS(i) = (bf16)(rP(i) * scales[row]); + rS(i+1) = (bf16)(rP(i+1) * scales[row]); + } + } + }; + + auto rescale_rO = [&](float scales[2]) { + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rO); i += 4) { + rO(i) *= scales[row]; + rO(i+1) *= scales[row]; + } + rL[row] *= scales[row]; + } + }; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{}); + Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{}); + + if (block_idx == 0) { + // NOTE We put these code here to avoid register spilling + pipelined_wait_and_qkt_gemm_l(); + pipelined_wait_and_qkt_gemm_r(); + warpgroup_wait<0>(); + } + + // Online softmax, inform WG1 + mask_rP(Warpgroup0{}); + + online_softmax_and_rescale_o(Warpgroup0{}); + NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready); + + // Issue rO0 += rS0 @ sV0l + gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Mark V0L as free + warpgroup_wait<0>(); + plan.bar_k0_free[0].arrive(); + + // Wait for new sM, scale rS, save, inform WG1 + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready); + float new_rM[2], scale_factors[2]; + *(float2*)new_rM = plan.sM[idx_in_warpgroup/4]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + scale_factors[i] = exp2f(rM[i] - new_rM[i]); + rM[i] = new_rM[i]; + } + scale_rS(scale_factors); + save_rS_to_sS(rS, sS0, idx_in_warpgroup); + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready); + + // Wait for sS1 + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready); + + // Rescale rO0, Issue rO0 += sS1 @ sV1L + rescale_rO(scale_factors); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + cur_bar_wait_phase ^= 1; + + if (block_idx+2 < num_topk_blocks) { + // Launch the next QK^T GEMM + pipelined_wait_and_qkt_gemm_l(); + + // Mark V1L as free + warpgroup_wait<1>(); + plan.bar_k1_free[0].arrive(); + pipelined_wait_and_qkt_gemm_r(); + + // Wait for rP0 = sQ @ sK0 + warpgroup_wait<0>(); + } else { + // Mark V1L as free + warpgroup_wait<0>(); + plan.bar_k1_free[0].arrive(); + } + } + + reduce_L(); + store_O(); + } else { + // Warpgroup 1 + + auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) { + plan.bar_k1_ready[1].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup1{}, 4, true); + qkt_gemm_one_tile(Warpgroup1{}, 5, false); + qkt_gemm_one_tile(Warpgroup1{}, 6, false); + qkt_gemm_one_tile(Warpgroup1{}, 7, false); + qkt_gemm_one_tile(Warpgroup1{}, 8, false); + plan.bar_k1_ready[0].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup1{}, 0, false); + qkt_gemm_one_tile(Warpgroup1{}, 1, false); + qkt_gemm_one_tile(Warpgroup1{}, 2, false); + qkt_gemm_one_tile(Warpgroup1{}, 3, false); + warpgroup_commit_batch(); + }; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + + // Issue rP1 = sQ @ sK1, and wait + pipelined_wait_and_qkt_gemm(); + warpgroup_wait<0>(); + + mask_rP(Warpgroup1{}); + + // Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready) + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready); + online_softmax_and_rescale_o(Warpgroup1{}); + NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready); + + + // Issue rO1 += rS1 @ sV1R + gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R + save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Save rS1, inform WG0 + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready); + + // Wait for GEMM, and inform that sV1R is free + warpgroup_wait<1>(); + plan.bar_k1_free[1].arrive(); + + // Wait for GEMM, and inform that sV0R is free + warpgroup_wait<0>(); + plan.bar_k0_free[1].arrive(); + + cur_bar_wait_phase ^= 1; + } + + reduce_L(); + store_O(); + + // Save lse + if (idx_in_warpgroup%4 == 0) { + for (int row = 0; row < 2; ++row) { + int real_row = get_AorC_row_idx(row, idx_in_warpgroup); + bool is_no_valid_tokens = rL[row] == 0.0f; + plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]; + plan.final_lse[real_row] = is_no_valid_tokens ? -INFINITY : log2f(rL[row]) + rM[row]; + } + fence_view_async_shared(); + } + + NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync); + if (idx_in_warpgroup == 0) { + int g_offset = s_q_idx*params.h_q + q_h_idx*B_H; + SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float)); + SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float)); + cute::tma_store_arrive(); + } + } + } else { + // Producer warpgroup + cutlass::arch::warpgroup_reg_dealloc<72>(); + + constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE; + constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; + int idx_in_group = idx_in_warpgroup % GROUP_SIZE; + int group_idx = idx_in_warpgroup / GROUP_SIZE; + int* gIndices = params.indices + s_q_idx*params.topk; // [topk] + + bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8)); + bf16* my_gKV_base = params.kv + idx_in_group*8; + + int64_t token_indices[2][NUM_ROWS_PER_GROUP]; + bool is_token_valid[2][NUM_ROWS_PER_GROUP]; + auto load_token_indices = [&](int block_idx) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) { + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx; + int t = __ldg(gIndices + offs); + token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster + is_token_valid[buf_idx][local_row] = t >= 0 && t < params.s_kv; + } + } + }; + + int64_t cache_policy = createpolicy_evict_last(); + auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) { + // Copy some K/V tiles from global memory to shared memory + // A tile has a shape of 64 (B_TOPK) x 64 + // `buf_idx` is the index of the shared memory buffer, 0 or 1 + // `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8 + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + int64_t token_index = token_indices[buf_idx][local_row]; + CUTE_UNROLL + for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) { + cp_async_cacheglobal_l2_prefetch_256B( + my_gKV_base + token_index + tile_idx*64, + my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64), + is_token_valid[buf_idx][local_row], + cache_policy + ); + } + } + }; + + auto commit_to_mbar = [&](transac_bar_t &bar) { + cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar)); + }; + + int cur_bar_wait_phase = 1; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + load_token_indices(block_idx); + + // V0L + plan.bar_k0_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 0, 4); + commit_to_mbar(plan.bar_k0_ready[0]); + + // V1R + plan.bar_k1_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 4, 9); + commit_to_mbar(plan.bar_k1_ready[1]); + + // V0R + plan.bar_k0_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 4, 9); + commit_to_mbar(plan.bar_k0_ready[1]); + + // V1L + plan.bar_k1_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 0, 4); + commit_to_mbar(plan.bar_k1_ready[0]); + + // Valid mask + // NOTE V1R's finish implies maskings of the last round have finished + if (idx_in_group == 0) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) + plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; + plan.bar_is_kv_valid_ready.arrive(); + } + + cur_bar_wait_phase ^= 1; + } + } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif +} + + +void run_fwd_kernel(const SparsePrefillParams& params) { + FLASH_ASSERT(params.h_kv == 1); + FLASH_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings + FLASH_ASSERT(params.topk > 0); + FLASH_ASSERT(params.h_q % B_H == 0); + + auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQ{} + ); + + CUtensorMap tensor_map_O; + { + uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q}; + uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)}; + uint32_t box_size[3] = {64, B_H, 1}; + uint32_t elem_stride[3] = {1, 1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_O, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 3, + params.out, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + FLASH_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q), decltype(tma_Q) + > tma_params = { + shape_Q, tma_Q, + tensor_map_O + }; + auto kernel = &sparse_attn_fwd_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams launch_params = { + dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z) + dim3(NUM_THREADS, 1, 1), + dim3(1, 1, 1), + smem_size, + params.stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)kernel, params, tma_params + ); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} diff --git a/csrc/sm90/prefill/sparse/fwd.h b/csrc/sm90/prefill/sparse/fwd.h new file mode 100644 index 0000000..60cb624 --- /dev/null +++ b/csrc/sm90/prefill/sparse/fwd.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace sm90 { + +void run_fwd_kernel(const SparsePrefillParams& params); + +} diff --git a/csrc/sm90/prefill/sparse/helpers.h b/csrc/sm90/prefill/sparse/helpers.h new file mode 100644 index 0000000..fd68c36 --- /dev/null +++ b/csrc/sm90/prefill/sparse/helpers.h @@ -0,0 +1,177 @@ +#pragma once + +#include +#include +#include + +namespace sm90 { + +using bf16 = cutlass::bfloat16_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::fence_barrier_init; +using cutlass::arch::NamedBarrier; + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" + :: "r"(dst_addr), + "l"(src), + "n"(16)); +} + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n" + :: "r"(dst_addr), + "l"(src), + "r"(pred?16:0), + "l"(cache_policy)); +} + +__forceinline__ __device__ int64_t createpolicy_evict_last() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + +__forceinline__ __device__ int64_t createpolicy_evict_first() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + + +__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx + // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + +__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) { + int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1); + return col_idx; +} + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +// * Copyright (c) 2024, Tri Dao. +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + using namespace cute; + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +// A simpiler version of gemm +template +__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { + using namespace cute; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor sA_frag = thr_mma.partition_fragment_A(sA); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + + warpgroup_fence_operand(rC_frag); + warpgroup_arrive(); + tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_fence_operand(rC_frag); +} + +template +__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { + using namespace cute; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(rA_frag) == size<2>(sB_frag)); + + warpgroup_fence_operand(const_cast(rA_frag)); + warpgroup_fence_operand(rC_frag); + warpgroup_arrive(); + tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(rA_frag); ++k) { + cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_fence_operand(rC_frag); + warpgroup_fence_operand(const_cast(rA_frag)); +} + + +__forceinline__ __device__ uint32_t get_sm_id() { + uint32_t ret; + asm("mov.u32 %0, %smid;" : "=r"(ret)); + return ret; +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(cute::_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +} diff --git a/csrc/sm90/kernels/get_mla_metadata.cu b/csrc/smxx/get_mla_metadata.cu similarity index 64% rename from csrc/sm90/kernels/get_mla_metadata.cu rename to csrc/smxx/get_mla_metadata.cu index 6b78f9b..9b5be62 100644 --- a/csrc/sm90/kernels/get_mla_metadata.cu +++ b/csrc/smxx/get_mla_metadata.cu @@ -6,7 +6,7 @@ #include "utils.h" __global__ void __launch_bounds__(32, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { +get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params) { int *seqlens_k_ptr = params.seqlens_k_ptr; int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; int *num_splits_ptr = params.num_splits_ptr; @@ -18,12 +18,26 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { extern __shared__ int shared_mem[]; int* num_blocks_shared = shared_mem; // [batch_size] int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] + int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size] + int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size] + int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size] int total_num_blocks = 0; for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + int cur_s_k = params.topk == -1 ? __ldg(seqlens_k_ptr + i) : params.topk; + seqlens_k_shared[i] = cur_s_k; + int first_token_idx = 0; + int last_token_idx = max(cur_s_k-1, 0); + int cur_first_block_idx = first_token_idx / block_size_n; + int cur_last_block_idx = last_token_idx / block_size_n; + // NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx] + // NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds. + // NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel. + int num_blocks = cur_last_block_idx - cur_first_block_idx + 1; total_num_blocks += num_blocks + fixed_overhead_num_blocks; num_blocks_shared[i] = num_blocks; + first_block_idx_shared[i] = cur_first_block_idx; + last_block_idx_shared[i] = cur_last_block_idx; } for (int offset = 16; offset >= 1; offset /= 2) { total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); @@ -31,14 +45,14 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { __syncwarp(); if (threadIdx.x == 0) { - int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks); + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; num_splits_shared[0] = 0; for (int i = 0; i < num_sm_parts; ++i) { int tile_scheduler_metadata0[4], tile_scheduler_metadata1; tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx]; tile_scheduler_metadata1 = now_n_split_idx; int remain_payload = payload; while (now_idx < batch_size) { @@ -61,7 +75,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { } } tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1); *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; } @@ -74,8 +88,8 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { } } -void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream) { - int smem_size = sizeof(int) * (params.batch_size*2+1); +void run_get_mla_metadata_kernel(GetDecodingMetadataParams ¶ms, cudaStream_t stream) { + int smem_size = sizeof(int) * (params.batch_size*5+1); CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params); CHECK_CUDA_KERNEL_LAUNCH(); diff --git a/csrc/smxx/get_mla_metadata.h b/csrc/smxx/get_mla_metadata.h new file mode 100644 index 0000000..7a1d1c4 --- /dev/null +++ b/csrc/smxx/get_mla_metadata.h @@ -0,0 +1,5 @@ +#pragma once + +#include "params.h" + +void run_get_mla_metadata_kernel(GetDecodingMetadataParams ¶ms, cudaStream_t stream); diff --git a/csrc/sm90/kernels/mla_combine.cu b/csrc/smxx/mla_combine.cu similarity index 94% rename from csrc/sm90/kernels/mla_combine.cu rename to csrc/smxx/mla_combine.cu index b6ba8f8..ff609bf 100644 --- a/csrc/sm90/kernels/mla_combine.cu +++ b/csrc/smxx/mla_combine.cu @@ -7,13 +7,12 @@ #include "params.h" #include "utils.h" -#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V using namespace cute; template __global__ void __launch_bounds__(NUM_THREADS) -flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { +flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m @@ -176,12 +175,14 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params template -void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { +void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream) { + static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA + FLASH_ASSERT(params.d_v == HEAD_DIM_V); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { constexpr int BLOCK_SIZE_M = 8; constexpr int NUM_THREADS = BLOCK_SIZE_M*32; constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); - auto combine_kernel = &flash_fwd_mla_combine_kernel; + auto combine_kernel = &flash_fwd_mla_combine_kernel; CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) cudaLaunchAttribute attribute[1]; @@ -200,8 +201,8 @@ void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t str CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream); #ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream); #endif \ No newline at end of file diff --git a/csrc/smxx/mla_combine.h b/csrc/smxx/mla_combine.h new file mode 100644 index 0000000..eca7501 --- /dev/null +++ b/csrc/smxx/mla_combine.h @@ -0,0 +1,6 @@ +#pragma once + +#include "params.h" + +template +void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream); diff --git a/csrc/sm90/kernels/utils.h b/csrc/utils.h similarity index 71% rename from csrc/sm90/kernels/utils.h rename to csrc/utils.h index ae9d0fc..571412f 100644 --- a/csrc/sm90/kernels/utils.h +++ b/csrc/utils.h @@ -30,3 +30,37 @@ } while(0) #define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); } + +template +__inline__ __host__ __device__ T ceil_div(const T &a, const T &b) { + return (a + b - 1) / b; +} + +#ifndef TRAP_ONLY_DEVICE_ASSERT +#define TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +// For development, we define both IS_SM100 and IS_SM90 when using CLion or VSCode IDEs so code highlighting will be correct. +#if defined(__CLION_IDE__) || defined(__VSCODE_IDE__) +#define IS_SM100 1 +#define IS_SM90 1 +#else + +// We define the following macros to detect the CUDA architecture, so that we can enable/disable certains kernels that depends on specific architectures. +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000) +#define IS_SM100 1 +#else +#define IS_SM100 0 +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900) +#define IS_SM90 1 +#else +#define IS_SM90 0 +#endif + +#endif // defined(__CLION_IDE__) || defined(__VSCODE_IDE__) \ No newline at end of file diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index d0e6faf..66f1986 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -6,4 +6,5 @@ flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func, + flash_mla_sparse_fwd ) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 084117e..4d27621 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,30 +2,33 @@ import torch -import flash_mla_sm90 -import flash_mla_sm100 - - +import flash_mla.cuda as flash_mla_cuda def get_mla_metadata( cache_seqlens: torch.Tensor, - num_heads_per_head_k: int, + num_q_tokens_per_head_k: int, num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. - num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. - num_heads_k: num_heads_k. + num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + num_heads_k: The number of k heads. + num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled + is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk) -def flash_mla_with_kvcache_sm90( +def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, @@ -35,6 +38,8 @@ def flash_mla_with_kvcache_sm90( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -47,6 +52,8 @@ def flash_mla_with_kvcache_sm90( num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). @@ -54,7 +61,9 @@ def flash_mla_with_kvcache_sm90( """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla( + if indices is not None: + assert causal == False, "causal must be `false` if sparse attention is enabled." + out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( q, k_cache, head_dim_v, @@ -64,10 +73,42 @@ def flash_mla_with_kvcache_sm90( causal, tile_scheduler_metadata, num_splits, + is_fp8_kvcache, + indices ) return out, softmax_lse +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + + Returns: + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v + ) + return results + + def _flash_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, @@ -96,7 +137,7 @@ def _flash_attn_varlen_forward( lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) - flash_mla_sm100.fwd( + flash_mla_cuda.dense_prefill_fwd( workspace_buffer, q, k, @@ -159,7 +200,7 @@ def _flash_attn_varlen_backward( if num_qo_heads != num_kv_heads: workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) - flash_mla_sm100.bwd( + flash_mla_cuda.dense_prefill_bwd( workspace_buffer, do, q, @@ -195,7 +236,7 @@ def forward( causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: out, lse = _flash_attn_varlen_forward( q, k, v, cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, @@ -290,40 +331,3 @@ def flash_attn_varlen_kvpacked_func( cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal, softmax_scale, is_varlen, ) - - -def flash_mla_with_kvcache_sm100( - q: torch.Tensor, - k_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - head_dim_v: int, - softmax_scale: Optional[float] = None, - causal: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - pass - - -def flash_mla_with_kvcache( - q: torch.Tensor, - k_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - head_dim_v: int, - tile_scheduler_metadata: Optional[torch.Tensor] = None, - num_splits: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - capability = torch.cuda.get_device_capability(q.device.index) - if capability == (9, 0): - return flash_mla_with_kvcache_sm90( - q, k_cache, block_table, cache_seqlens, head_dim_v, - tile_scheduler_metadata, num_splits, - softmax_scale, causal, - ) - elif capability == (10, 0): - raise ValueError(f"Unsupported device capability: {capability}") - else: - raise ValueError(f"Unsupported device capability: {capability}") diff --git a/setup.py b/setup.py index 58cf7b2..338117f 100644 --- a/setup.py +++ b/setup.py @@ -12,29 +12,31 @@ ) -def append_nvcc_threads(nvcc_extra_args): - nvcc_threads = os.getenv("NVCC_THREADS") or "32" - return nvcc_extra_args + ["--threads", nvcc_threads] - +def is_flag_set(flag: str) -> bool: + return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"] def get_features_args(): features_args = [] - DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] - if DISABLE_FP16: + if is_flag_set("FLASH_MLA_DISABLE_FP16"): features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args +def get_arch_flags(): + DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") + DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") + arch_flags = [] + if not DISABLE_SM100: + arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) + if not DISABLE_SM90: + arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) + return arch_flags + +def get_nvcc_thread_args(): + nvcc_threads = os.getenv("NVCC_THREADS") or "32" + return ["--threads", nvcc_threads] subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) -cc_flag_sm90 = [] -cc_flag_sm90.append("-gencode") -cc_flag_sm90.append("arch=compute_90a,code=sm_90a") - -cc_flag_sm100 = [] -cc_flag_sm100.append("-gencode") -cc_flag_sm100.append("arch=compute_100a,code=sm_100a") - this_dir = os.path.dirname(os.path.abspath(__file__)) if IS_WINDOWS: @@ -45,79 +47,44 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla_sm90", + name="flash_mla.cuda", sources=[ - "csrc/sm90/flash_api.cpp", - "csrc/sm90/kernels/get_mla_metadata.cu", - "csrc/sm90/kernels/mla_combine.cu", - "csrc/sm90/kernels/splitkv_mla.cu", + "csrc/pybind.cpp", + "csrc/smxx/get_mla_metadata.cu", + "csrc/smxx/mla_combine.cu", + "csrc/sm90/decode/dense/splitkv_mla.cu", + "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu", + "csrc/sm90/prefill/sparse/fwd.cu", + "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", + "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-DNDEBUG", - "-D_USE_MATH_DEFINES", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v,--register-usage-level=10" - ] - + cc_flag_sm90 - ) + get_features_args(), + "nvcc": [ + "-O3", + "-std=c++17", + "-DNDEBUG", + "-D_USE_MATH_DEFINES", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v,--register-usage-level=10" + ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(), }, include_dirs=[ + Path(this_dir) / "csrc", Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "cutlass" / "include", - ], - ) -) - -ext_modules.append( - CUDAExtension( - name="flash_mla_sm100", - sources=[ - "csrc/sm100/pybind.cu", - "csrc/sm100/fmha_cutlass_fwd_sm100.cu", - "csrc/sm100/fmha_cutlass_bwd_sm100.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-DNDEBUG", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-lineinfo", - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", - ] - + cc_flag_sm100 - ), - }, - include_dirs=[ - Path(this_dir) / "csrc" / "sm100", - Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", ], ) ) - try: cmd = ['git', 'rev-parse', '--short', 'HEAD'] rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() diff --git a/tests/lib.py b/tests/lib.py new file mode 100644 index 0000000..f884721 --- /dev/null +++ b/tests/lib.py @@ -0,0 +1,73 @@ +from typing import List + +import torch + +def cdiv(x: int, y: int): + return (x+y-1) // y + +def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): + """ + Check if two tensors are close enough + """ + def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float: + """ + Calculate the cosine diff between two tensors + """ + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum().item() + if denominator == 0: + return 0 + sim = 2 * (x * y).sum().item() / denominator + return 1 - sim + assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" + + ans = ans.clone().to(torch.float) + ref = ref.clone().to(torch.float) + + # Deal with anomalies + def deal_with_anomalies(val: float): + ref_mask = (ref == val) if (val == val) else (ref != ref) + ans_mask = (ans == val) if (val == val) else (ans != ans) + ref[ref_mask] = 0.0 + ans[ans_mask] = 0.0 + if not torch.equal(ref_mask, ans_mask): + print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") + return False + return True + + anomalies_check_passed = True + anomalies_check_passed &= deal_with_anomalies(float("inf")) + anomalies_check_passed &= deal_with_anomalies(float("-inf")) + anomalies_check_passed &= deal_with_anomalies(float("nan")) + + if not anomalies_check_passed: + return False + + cos_diff = get_cos_diff(ans, ref) + raw_abs_err = torch.abs(ans-ref) + raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) + rel_err = raw_rel_err.masked_fill(raw_abs_err List[int]: + result = [] + for size in t.shape[::-1]: + result.append(pos % size) + pos = pos // size + assert pos == 0 + return result[::-1] + print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") + print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") + print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") + print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") + return False + else: + if abs(cos_diff) > cos_diff_tol: + print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") + return False + return True \ No newline at end of file diff --git a/tests/quant.py b/tests/quant.py new file mode 100644 index 0000000..afee4b2 --- /dev/null +++ b/tests/quant.py @@ -0,0 +1,68 @@ +import enum + +import torch + +def quantize_k_cache( + input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) + dv: int, + tile_size: int = 128, +) -> torch.Tensor: + """ + Quantize the k-cache + Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size() + For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, d = input_k_cache.shape + assert h_k == 1 + input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] + input_elem_size = input_k_cache.element_size() + + result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) + result_k_nope_part = result[..., :dv] + result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32) + result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype) + result_k_rope_part[:] = input_k_cache[..., dv:] + + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv + + cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] + cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) + result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope + + result = result.view(num_blocks, block_size, 1, -1) + return result + + +def dequantize_k_cache( + quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) + dv: int = 512, + tile_size: int = 128, + d: int = 576 +) -> torch.Tensor: + """ + De-quantize the k-cache + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, _ = quant_k_cache.shape + assert h_k == 1 + result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device) + + quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) + + input_nope = quant_k_cache[..., :dv] + input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32) + input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16) + result[..., dv:] = input_rope + + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) + cur_scales = input_scale[..., tile_idx].unsqueeze(-1) + result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales + + result = result.view(num_blocks, block_size, 1, d) + return result diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_decoding.py new file mode 100644 index 0000000..64ddf72 --- /dev/null +++ b/tests/test_flash_mla_decoding.py @@ -0,0 +1,343 @@ +import argparse +import math +import random +import dataclasses +from typing import Optional, Tuple, List + +import torch +import triton + +import quant +import flash_mla +from lib import cdiv, check_is_allclose + +@dataclasses.dataclass +class TestParam: + b: int # Batch size + s_q: int # Number of queries for one request + s_k: int # Seq len, or mean seq len if varlen == True + is_varlen: bool + is_causal: bool + is_fp8: bool + topk: Optional[int] = None + test_performance: bool = True + is_all_indices_invalid: bool = False + have_zero_seqlen_k: bool = False + block_size: int = 64 + h_q: int = 128 # Number of q heads + h_kv: int = 1 # Number of kv heads + d: int = 576 # Q/K head dim (= dv + RoPE dim) + dv: int = 512 # V head dim + seed: int = 0 + + +def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Generate test data from a given configuration + Return: [cache_seqlens, q, block_table, blocked_k] + Pay attention: This function changes the random seed + """ + random.seed(t.seed) + torch.manual_seed(t.seed) + torch.cuda.manual_seed(t.seed) + torch.backends.cudnn.deterministic = True + + assert t.h_q % t.h_kv == 0 + + cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu') + if t.is_varlen: + for i in range(t.b): + cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q) + + if t.have_zero_seqlen_k: + zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0 + cache_seqlens_cpu[zeros_mask] = 0 + + max_seqlen = cache_seqlens_cpu.max().item() + max_seqlen_pad = cdiv(max_seqlen, 256) * 256 + cache_seqlens = cache_seqlens_cpu.cuda() + + q = torch.randn(t.b, t.s_q, t.h_q, t.d) + q.clamp_(min=-1.0, max=1.0) + + block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size) + block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1) + blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10 + blocked_k.clamp_(min=-1.0, max=1.0) + + if t.topk is None: + for i in range(t.b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, t.block_size) + blocked_k[block_table[i][cur_num_blocks:]] = float("nan") + if cur_len % t.block_size != 0: + blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan") + block_table[i][cur_num_blocks:] = 2147480000 + return cache_seqlens, q, block_table, blocked_k, None, None + else: + block_table_cpu = block_table.cpu() + abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") + indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") + for i in range(t.b): + # Generate indices + for j in range(t.s_q): + cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] + cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size) + if len(cur_abs_indices) < t.topk: + pad_len = t.topk - len(cur_abs_indices) + cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) + cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) + + # Mask KV + perm = torch.randperm(t.topk, device='cpu') + cur_abs_indices = cur_abs_indices[perm] + cur_blocked_indices = cur_blocked_indices[perm] + + # Fill it with invalid indices if needed + if t.is_all_indices_invalid: + cur_abs_indices.fill_(-1) + cur_blocked_indices.fill_(-1) + + abs_indices[i, j, :] = cur_abs_indices + indices_in_kvcache[i, j, :] = cur_blocked_indices + + # Mask nonused KV as NaN + all_indices = indices_in_kvcache.flatten().tolist() + all_indices = list(set(all_indices)) + if -1 in all_indices: + all_indices.remove(-1) + all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') + + blocked_k = blocked_k.view(-1, t.h_kv, t.d) + nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') + nonused_indices_mask[all_indices] = False + blocked_k[nonused_indices_mask, :, :] = float("nan") + blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) + + abs_indices = abs_indices.to(q.device) + indices_in_kvcache = indices_in_kvcache.to(q.device) + + return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache + + +def reference_torch( + cache_seqlens: torch.Tensor, # [batch_size] + block_table: torch.Tensor, # [batch_size, ?] + q: torch.Tensor, # [batch_size, s_q, h_q, d] + blocked_k: torch.Tensor, # [?, block_size, h_kv, d] + dv: int, + is_causal: bool, + indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation in PyTorch + """ + def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): + mask = torch.zeros(s_q, s_k, dtype=torch.bool) + for i in range(s_q): + cur_indices = indices[i] + valid_indices = cur_indices[cur_indices != -1] + mask[i, valid_indices] = True + return mask + + def scaled_dot_product_attention( + batch_idx: int, + query: torch.Tensor, # [h_q, s_q, d] + kv: torch.Tensor, # [h_kv, s_k, d] + dv: int, + is_causal, + indices: Optional[torch.Tensor], # [s_q, topk] + ) -> Tuple[torch.Tensor, torch.Tensor]: + h_q = query.size(0) + h_kv = kv.size(0) + s_q = query.shape[-2] + s_k = kv.shape[-2] + query = query.float() + kv = kv.float() + if h_kv != 1: + kv = kv.repeat_interleave(h_q // h_kv, dim=0) + kv[kv != kv] = 0.0 + attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] + if (is_causal and query.size(1) > 1) or indices is not None: + mask = torch.ones(s_q, s_k, dtype=torch.bool) + if is_causal: + assert indices is None + mask = mask.tril(diagonal=s_k - s_q) + if indices is not None: + mask &= get_topk_attn_mask(s_q, s_k, indices) + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float) + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + attn_weight += attn_bias.to(q.dtype) + attn_weight /= math.sqrt(query.size(-1)) + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] + # Correct for q tokens which has no attendable k + lonely_q_mask = (lse == float("-inf")) + output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output, lse + + b, s_q, h_q, d = q.size() + block_size = blocked_k.size(1) + h_kv = blocked_k.size(2) + cache_seqlens_cpu = cache_seqlens.cpu() + out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, block_size) + cur_block_indices = block_table[i][0: cur_num_blocks] + cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] + cur_out, cur_lse = scaled_dot_product_attention( + i, + q[i].transpose(0, 1), + cur_kv.transpose(0, 1), + dv, + is_causal, + indices[i] if indices is not None else None + ) + out_ref[i] = cur_out.transpose(0, 1) + lse_ref[i] = cur_lse + out_ref = out_ref.to(torch.bfloat16) + return out_ref, lse_ref + + +@torch.inference_mode() +def test_flash_mla(t: TestParam): + print('-------------------------------') + print(f"Running on {t}...") + + # Generating test data + torch.cuda.synchronize() + cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t) + + if t.is_fp8: + # The quantization error may be too large to be distinguished from wrong kernels + # So we quantize and de-quantize kv-cache here to mitigate quantization error + blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128) + blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized) + blocked_k = blocked_k_dequantized + + # Get schedule metadata + torch.cuda.synchronize() + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( + cache_seqlens, + t.s_q * t.h_q // t.h_kv, + t.h_kv, + t.h_q, + t.is_fp8, + t.topk + ) + torch.cuda.synchronize() + + def run_flash_mla(): + return flash_mla.flash_mla_with_kvcache( + q, + blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore + block_table, + cache_seqlens, + t.dv, + tile_scheduler_metadata, + num_splits, + causal=t.is_causal, + is_fp8_kvcache=t.is_fp8, + indices=indices_in_kvcache + ) + + out_ans, lse_ans = run_flash_mla() + out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) + assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6) + assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) + + if t.test_performance: + time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore + mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk + compute_volume_flop = t.b*t.h_q*t.s_q*sum([ + 2*t.d*mean_attended_seqlens, # Q * K^T + 2*mean_attended_seqlens*t.dv, # attention * V + ]) + q_elem_size = torch.bfloat16.itemsize + kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize + memory_volume_B = t.b*sum([ + t.s_q*t.h_q*(t.d*q_elem_size), # Q + (t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V + t.s_q*t.h_q*(t.dv*q_elem_size), # Output + ]) + achieved_tflops = compute_volume_flop / time_usage / 1e12 + achieved_gBps = memory_volume_B / time_usage / 1e9 + + print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") + + +def main(torch_dtype): + device = torch.device("cuda:0") + torch.set_default_dtype(torch_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + + correctness_cases = [ + TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False) + for b in [1, 2, 6, 64] + for s_q in [1, 2, 4] + for s_k in [20, 140, 4096] + for is_varlen in [False, True] + for is_causal in [False, True] + for (is_fp8, topk) in [ + (False, None), + (True, 128), + (True, 2048) + ] + if not (is_causal and topk is not None) + ] + + corner_cases = [ + # Cases where all topk indices are invalid + TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True) + for topk in [128, 2048, 4096] + ] + [ + # Cases where some kv cache have zero length + TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True) + for (is_causal, is_fp8, topk) in [ + (False, False, None), + (True, False, None), + (False, True, 128), + (False, True, 2048), + ] + ] + + performance_cases = [ + TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True) + for (is_causal, is_fp8, topk) in [ + (False, False, None), + (True, False, None), + (False, True, 2048), + ] + for s_q in [1, 2] + for s_k in [4096, 8192, 16384, 32768] + ] + + testcases = correctness_cases + corner_cases + performance_cases + + for testcase in testcases: + test_flash_mla(testcase) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + choices=["bf16", "fp16"], + default="bf16", + help="Data type to use for testing (bf16 or fp16)", + ) + + args = parser.parse_args() + + torch_dtype = torch.bfloat16 + if args.dtype == "fp16": + torch_dtype = torch.float16 + + main(torch_dtype) diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py new file mode 100644 index 0000000..f85f2d6 --- /dev/null +++ b/tests/test_flash_mla_prefill.py @@ -0,0 +1,197 @@ +import math +import time +from typing import Tuple +import random +import dataclasses + +import torch +import triton + +from flash_mla import flash_mla_sparse_fwd +from lib import check_is_allclose + +@dataclasses.dataclass +class TestParam: + b: int + s_q: int + s_kv: int + topk: int + h_q: int = 128 + h_kv: int = 1 + d_qk: int = 576 + d_v: int = 512 + seed: int = 0 + check_correctness: bool = True + benchmark: bool = True + +@dataclasses.dataclass +class Testcase: + t: TestParam + q: torch.Tensor + kv: torch.Tensor + indices: torch.Tensor + +def generate_testcase(t: TestParam) -> Testcase: + torch.manual_seed(t.seed) + torch.cuda.manual_seed(t.seed) + random.seed(t.seed) + q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32) + for b in range(t.b): + for s in range(t.s_q): + for h in range(t.h_kv): + # TODO Comment + near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 + cur_indices = torch.randperm(t.s_kv)[:t.topk] + cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) + if len(cur_indices) < t.topk: + cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) + cur_indices = cur_indices[torch.randperm(t.topk)] + indices[b, s, h] = cur_indices + indices = indices.to(q.device) + + return Testcase( + t=t, + q=q, + kv=kv, + indices=indices + ) + +def get_flop(p: TestParam) -> float: + flop = 2 * sum([ + p.h_q * p.d_qk * p.topk, + p.h_q * p.d_v * p.topk + ]) * p.b * p.s_q + return flop + +def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: + return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) + + assert p.b == 1 + indices = t.indices[0, :, 0, :] # [s_q, topk] + invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) + qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] + kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] + + kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(p.s_q, p.topk, p.d_qk) # [s_q, topk, d_qk] + attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk] + attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf')) + attn_score *= sm_scale * math.log2(math.e) + max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q] + lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q] + attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk] + result = attn_score @ kvs[:, :, :p.d_v] + return (max_logits, lse, result) + +@torch.inference_mode() +def run_test(p: TestParam) -> bool: + print("================") + print(f"Running on {p}") + torch.cuda.empty_cache() + assert p.b == 1 + + t = generate_testcase(p) + sm_scale = 1 / math.sqrt(p.d_qk) + torch.cuda.synchronize() + + def run_ans(): + return flash_mla_sparse_fwd( + t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale + ) + + ans_out, ans_max_logits, ans_lse = run_ans() + torch.cuda.synchronize() + + if p.benchmark: + flop = get_flop(p) + prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore + prefill_flops = flop/prefill_ans_time/1e12 + print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops") + + if p.check_correctness: + torch.cuda.synchronize() + ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale) + torch.cuda.synchronize() + + is_correct = True + is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6) + is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) + is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) + + return is_correct + else: + return True + + +if __name__ == '__main__': + device = torch.device("cuda:0") + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + + correctness_cases = [ + # Regular shapes + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + # Regular shapes + (128, 128), + (256, 256), + (512, 512), + + # Irregular shapes + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + + # Irregular shapes with OOB TopK + (95, 128), + (153, 256), + (114, 384), + ] + for s_q in [ + 1, 62 + ] + ] + + corner_cases = [ + # In these cases, some blocks may not have any valid topk indices + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + (32, 2048), + (64, 8192) + ] + for s_q in [1, 1024] + ] + + performance_cases = [ + TestParam(1, s_q, s_kv, topk, h_q=128) + for s_q in [4096] + for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072] + for topk in [2048] + ] + + testcases = correctness_cases + corner_cases + performance_cases + + failed_cases = [] + for test in testcases: + if test.benchmark: + time.sleep(0.2) + is_correct = run_test(test) + if not is_correct: + failed_cases.append(test) + + if len(failed_cases) > 0: + print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") + for case in failed_cases: + print(f" {case}") + else: + print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") + diff --git a/tests/test_flash_mla_sm90.py b/tests/test_flash_mla_sm90.py deleted file mode 100644 index 67c9d93..0000000 --- a/tests/test_flash_mla_sm90.py +++ /dev/null @@ -1,153 +0,0 @@ -import argparse -import math -import random - -import torch -import triton - -from flash_mla import flash_mla_with_kvcache, get_mla_metadata - - -def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): - query = query.float() - key = key.float() - value = value.float() - key = key.repeat_interleave(h_q // h_kv, dim=0) - value = value.repeat_interleave(h_q // h_kv, dim=0) - attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) - if is_causal: - s_q = query.shape[-2] - s_k = key.shape[-2] - attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - attn_weight += attn_bias - lse = attn_weight.logsumexp(dim=-1) - attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) - return attn_weight @ value, lse - - -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: - x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() - cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 - - -@torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): - print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" - ) - - cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) - if varlen: - for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) - total_seqlens = cache_seqlens.sum().item() - mean_seqlens = cache_seqlens.float().mean().int().item() - max_seqlen = cache_seqlens.max().item() - max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") - - q = torch.randn(b, s_q, h_q, d) - block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32 - ).view(b, max_seqlen_pad // block_size) - blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( - float("nan") - ) - blocked_v = blocked_k[..., :dv] - - tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv - ) - - def flash_mla(): - return flash_mla_with_kvcache( - q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, - ) - - def ref_mla(): - out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) - lse = torch.empty(b, h_q, s_q, dtype=torch.float32) - for i in range(b): - begin = i * max_seqlen_pad - end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), - h_q=h_q, - h_kv=h_kv, - is_causal=causal, - ) - out[i] = O.transpose(0, 1) - lse[i] = LSE - return out, lse - - out_flash, lse_flash = flash_mla() - out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") - cal_diff(lse_flash, lse_torch, "lse") - - t = triton.testing.do_bench(flash_mla) - FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(q.dtype).bits // 8 - ) - print( - f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" - ) - - -def main(torch_dtype): - device = torch.device("cuda:0") - torch.set_default_dtype(torch_dtype) - torch.set_default_device(device) - torch.cuda.set_device(device) - torch.manual_seed(0) - random.seed(0) - - h_kv = 1 - d, dv = 576, 512 - causal = True - - for b in [128]: - for s in [4096, 8192, 16384]: - for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 - for s_q in [1, 2]: # MTP = 1, 2 - for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype", - type=str, - choices=["bf16", "fp16"], - default="bf16", - help="Data type to use for testing (bf16 or fp16)", - ) - - args = parser.parse_args() - - torch_dtype = torch.bfloat16 - if args.dtype == "fp16": - torch_dtype = torch.float16 - - main(torch_dtype) diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 7cb19a2..6b2ba45 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -6,6 +6,7 @@ from flash_mla import flash_attn_varlen_func +from lib import check_is_allclose def get_window_size(causal, window): if window > 0: @@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window): return attn_bias -def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: - x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() - cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}" - - def sdpa(query, key, value, attn_bias, softmax_scale=None): + query = query.float().transpose(-3, -2) + key = key.float().transpose(-3, -2) + value = value.float().transpose(-3, -2) key = key.repeat_interleave(h // h_k, dim=-3) value = value.repeat_interleave(h // h_k, dim=-3) if softmax_scale is None: softmax_scale = query.shape[-1] ** (-0.5) - attn_weight = query @ key.transpose(-2, -1) * softmax_scale + attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale attn_weight += attn_bias lse = attn_weight.logsumexp(dim=-1) attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) @@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs): return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) -def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd): - print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}") +def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, check_correctness: bool = True): + print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}, {has_bwd=}, {check_correctness=}") torch.manual_seed(0) random.seed(0) @@ -76,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win causal, window) == 0).sum().item() for i in range(b)]) # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") - q = torch.randn(total_q, h, d) - k = torch.randn(total_k, h_k, d) - v = torch.randn(total_k, h_k, dv) - grad_out = torch.randn(total_q, h, dv) + q = torch.randn(total_q, h, d)/10 + k = torch.randn(total_k, h_k, d)/10 + v = torch.randn(total_k, h_k, dv)/10 + grad_out = torch.randn(total_q, h, dv)/10 softmax_scale = (d + 100) ** (-0.5) q1 = q.clone().requires_grad_() k1 = k.clone().requires_grad_() v1 = v.clone().requires_grad_() - q2 = q.clone().requires_grad_() - k2 = k.clone().requires_grad_() - v2 = v.clone().requires_grad_() + if check_correctness: + q2 = q.clone().requires_grad_() + k2 = k.clone().requires_grad_() + v2 = v.clone().requires_grad_() def flash_attn(): q1.grad = k1.grad = v1.grad = None @@ -106,9 +102,9 @@ def torch_attn(): lse = [] for i in range(b): OUT, LSE = sdpa_checkpoint( - q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2), - k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), - v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()], + k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()], + v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()], attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), softmax_scale=softmax_scale, ) @@ -119,20 +115,23 @@ def torch_attn(): return out, lse out_flash, lse_flash = flash_attn() - out_torch, lse_torch = torch_attn() - assert_close(out_flash, out_torch, "out") - assert_close(lse_flash, lse_torch, "lse") - if has_bwd: out_flash.backward(grad_out, retain_graph=True) - out_torch.backward(grad_out, retain_graph=True) - assert_close(q1.grad, q2.grad, "dq") - assert_close(k1.grad, k2.grad, "dk") - assert_close(v1.grad, v2.grad, "dv") dq1 = q1.grad.clone() dk1 = k1.grad.clone() dv1 = v1.grad.clone() + if check_correctness: + out_torch, lse_torch = torch_attn() + assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536) + + if has_bwd: + out_torch.backward(grad_out, retain_graph=True) + assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + def forward(): return flash_attn() @@ -150,12 +149,6 @@ def backward(): assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" - # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: - # forward() - # if has_bwd: - # backward() - # print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120)) - def timer(func, name): t = triton.testing.do_bench(func, warmup=2, rep=3) FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) @@ -173,18 +166,20 @@ def timer(func, name): device = torch.device("cuda:0") torch.set_default_device(device) torch.cuda.set_device(device) + torch.set_float32_matmul_precision("high") - b = 4 + b = 2 window = 0 has_bwd = False for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: for varlen in [False, True]: - for (h, h_k) in [(32, 32), (32, 4)]: + for (h, h_k) in [(128, 128), (32, 4)]: if h != h_k: has_bwd = False else: has_bwd = True for (d, dv) in [(128, 128), (192, 128)]: for causal in [False, True]: - test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd) + skip_correctness_check = mean_sq == 8192 and mean_sk == 8192 and h == 128 + test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, not skip_correctness_check) From 87709cf4cce80392c67befe132fd338dd3049bc2 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Wed, 24 Sep 2025 14:13:06 +0800 Subject: [PATCH 12/20] Add a comment --- tests/test_flash_mla_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py index f85f2d6..19a6dbe 100644 --- a/tests/test_flash_mla_prefill.py +++ b/tests/test_flash_mla_prefill.py @@ -45,7 +45,7 @@ def generate_testcase(t: TestParam) -> Testcase: for b in range(t.b): for s in range(t.s_q): for h in range(t.h_kv): - # TODO Comment + # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 cur_indices = torch.randperm(t.s_kv)[:t.topk] cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) From 7232d69d5e269db902d69d7718a5a55efaab4be8 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Mon, 29 Sep 2025 15:11:24 +0800 Subject: [PATCH 13/20] Fill in link to DSv3.2 paper --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8cf01a3..2f6b8db 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ ## Introduction -FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](TODO) models. This repository contains the following implementations: +FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: **Sparse Attention Kernels** -*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](TODO).* +*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).* - Token-level sparse attention for the prefill stage - Token-level sparse attention for the decoding stage, with FP8 KV cache @@ -18,7 +18,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee ## News -- **2025.09.26(TODO) Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](TODO), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. +- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. - **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). - **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 @@ -66,7 +66,7 @@ Support matrix: [1]: For more details on using FP8 KV cache, see documents below. -[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](TODO). +[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp). ## Installation From fd249aacce56327affecd16f89e035b12691974f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 29 Sep 2025 02:21:37 -0700 Subject: [PATCH 14/20] Add Sparse Decoding Kernel and Sparse Prefill Kernel for Blackwell Signed-off-by: simon-mo --- README.md | 8 +- csrc/pybind.cpp | 38 +- csrc/sm100/decode/sparse_fp8/dequant.h | 61 ++ csrc/sm100/decode/sparse_fp8/splitkv_mla.cu | 592 +++++++++++++++ csrc/sm100/decode/sparse_fp8/splitkv_mla.h | 10 + csrc/sm100/defines.h | 30 + csrc/sm100/helpers.h | 97 +++ csrc/sm100/intrinsics.h | 461 ++++++++++++ csrc/sm100/prefill/sparse/fwd.cu | 785 ++++++++++++++++++++ csrc/sm100/prefill/sparse/fwd.h | 9 + csrc/sm100/prefill/sparse/helpers.h | 104 +++ csrc/sm100/prefill/sparse/intrinsics.h | 638 ++++++++++++++++ csrc/sm100/prefill/sparse/ws_gemm.h | 328 ++++++++ csrc/sm100/tma_cta_group2_nosplit.h | 281 +++++++ csrc/sm100/ws_gemm.h | 426 +++++++++++ setup.py | 16 + tests/test_flash_mla_decoding.py | 5 + 17 files changed, 3882 insertions(+), 7 deletions(-) create mode 100644 csrc/sm100/decode/sparse_fp8/dequant.h create mode 100644 csrc/sm100/decode/sparse_fp8/splitkv_mla.cu create mode 100644 csrc/sm100/decode/sparse_fp8/splitkv_mla.h create mode 100644 csrc/sm100/defines.h create mode 100644 csrc/sm100/helpers.h create mode 100644 csrc/sm100/intrinsics.h create mode 100644 csrc/sm100/prefill/sparse/fwd.cu create mode 100644 csrc/sm100/prefill/sparse/fwd.h create mode 100644 csrc/sm100/prefill/sparse/helpers.h create mode 100644 csrc/sm100/prefill/sparse/intrinsics.h create mode 100644 csrc/sm100/prefill/sparse/ws_gemm.h create mode 100644 csrc/sm100/tma_cta_group2_nosplit.h create mode 100644 csrc/sm100/ws_gemm.h diff --git a/README.md b/README.md index 2f6b8db..354cdde 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ python tests/test_flash_mla_decoding.py The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. +For Blackwell GPUs, the token-level sparse MLA decoding kernel can achieve up to 350 TFlops (on B200) which is not really optimized yet. + #### Test & benchmark MHA prefill (Dense): ```bash @@ -47,7 +49,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation python tests/test_flash_mla_prefill.py ``` -It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8. +It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. ## Requirements @@ -60,9 +62,9 @@ Support matrix: | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | | :---: | :---: | :---: | :---: | | Dense Decoding | Hopper | MQA | BF16 | -| Sparse Decoding | Hopper | MQA | FP8 [1] | +| Sparse Decoding | Hopper & Blackwell | MQA | FP8 [1] | | Dense Prefill | Blackwell | MHA | | -| Sparse Prefill | Hopper | MQA | | +| Sparse Prefill | Hopper & Blackwell | MQA | | [1]: For more details on using FP8 KV cache, see documents below. diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index b360c24..6ec3f21 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -16,7 +16,9 @@ #include "sm90/decode/dense/splitkv_mla.h" #include "sm90/decode/sparse_fp8/splitkv_mla.h" #include "sm90/prefill/sparse/fwd.h" +#include "sm100/decode/sparse_fp8/splitkv_mla.h" #include "sm100/prefill/dense/interface.h" +#include "sm100/prefill/sparse/fwd.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -31,7 +33,7 @@ struct Arch { } bool is_sm100() const { - return major == 10 && minor == 0; + return major == 10; } void assert_is_supported() const { @@ -86,7 +88,31 @@ DecodingAttnImplMeta get_attn_impl_meta( } } } else if (arch.is_sm100()) { - TORCH_CHECK(false, "Unsupported GPU architecture"); + if (is_sparse_attn) { + if (is_fp8_kvcache) { + TORCH_CHECK(h_q_.has_value()); + int h_q = h_q_.value(); + TORCH_CHECK(h_q % h_k == 0); + int s_q = num_q_tokens_per_head_k * h_k / h_q; + // FP8 + Sparse MLA + return { + std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1), + 5, + 64 + }; + } else { + // Sparse BF16 MLA + TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100"); + } + } else { + if (is_fp8_kvcache) { + // FP8 MLA + TORCH_CHECK(false, "FP8 Dence MLA is not supported on SM100"); + } else { + // Normal BF16 MLA + TORCH_CHECK(false, "BF16 Dence MLA is not supported on SM100"); + } + } } else { TORCH_CHECK(false, "Unsupported GPU architecture"); } @@ -326,7 +352,8 @@ fwd_kvcache_mla( } } } else if (arch.is_sm100()) { - TORCH_CHECK(false, "Unsupported GPU architecture"); + TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100"); + sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); } else { TORCH_CHECK(false, "Unsupported GPU architecture"); } @@ -366,7 +393,8 @@ std::vector sparse_prefill_fwd( ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9; - TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures"); + bool is_sm100 = dprops->major == 10; + TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures"); CHECK_DEVICE(q); CHECK_DEVICE(kv); @@ -423,6 +451,8 @@ std::vector sparse_prefill_fwd( if (is_sm90) { sm90::run_fwd_kernel(params); + } else if (is_sm100) { + sm100::run_fwd_kernel(params); } else { TORCH_CHECK(false, "Unknown architecture"); } diff --git a/csrc/sm100/decode/sparse_fp8/dequant.h b/csrc/sm100/decode/sparse_fp8/dequant.h new file mode 100644 index 0000000..3ed46e1 --- /dev/null +++ b/csrc/sm100/decode/sparse_fp8/dequant.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "sm100/defines.h" + +namespace sm100 { + +struct fp8x8 { + __nv_fp8x4_e4m3 lo; + __nv_fp8x4_e4m3 hi; +}; + +struct fp8x32 { + fp8x8 a0, a1, a2, a3; +}; + +struct fp8x16 { + fp8x8 a0, a1; +}; + +__device__ __forceinline__ +bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { + __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale); + + #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ + { \ + float4 fp32x4 = (float4)(FP8x4); \ + OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \ + OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \ + } + + bf16x8 result; + DEQUANT_FP8x4(result.a01, result.a23, inputs.lo); + DEQUANT_FP8x4(result.a45, result.a67, inputs.hi); + + return result; +} + +__device__ __forceinline__ +fp8x32 ldg_256_fp8x32(void* src_ptr) { + int32x8_t val; + asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" + : "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3), + "=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7) + : "l"(src_ptr) + ); + return *reinterpret_cast(&val); +} + +__device__ __forceinline__ +fp8x16 ldg_128_fp8x16(void* src_ptr) { + int4 ret; + asm volatile("ld.global.nc.L1::evict_first.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) + : "l"(src_ptr)); + return *reinterpret_cast(&ret); +} + +} diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu new file mode 100644 index 0000000..068e9fd --- /dev/null +++ b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu @@ -0,0 +1,592 @@ +#include "splitkv_mla.h" + +#include +#include +#include +#include +#include + +#include "utils.h" +#include "dequant.h" +#include "sm100/defines.h" +#include "sm100/helpers.h" +#include "sm100/intrinsics.h" +#include "sm100/ws_gemm.h" + +namespace sm100 { + +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::NamedBarrier; +using namespace cute; + +constexpr int B_H = 64; +constexpr int B_TOPK = 64; +constexpr int D_K = 576; +constexpr int D_V = 512; +constexpr int NUM_BUFS = 2; +constexpr int NUM_THREADS = 128*3; +constexpr int NUM_WORKING_THREADS = 128 + 128 + 32; +constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; +}; + +namespace tmem_addr { + constexpr int o = 0; // o: [0, 256] + constexpr int p = 256; // p: [256, 288] +}; + +using SmemLayoutQ = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutOBuf = decltype(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, // TODO This may lead to TMA double traffic + Shape, Int>{} +)); + +using SmemLayoutOAccumBuf = Layout< + Shape, Int>, + Stride, _1> // We use stride = 520 here to avoid bank conflict +>; + +using SmemLayoutS = decltype(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout< + Shape, Int>, + Stride, _1> + >{} +)); + +using SmemLayoutK = SmemLayoutKTiles<9>; +using SmemLayoutV = SmemLayoutKTilesTransposed<8>; + +struct SharedMemoryPlan { + array_aligned> q; + union { + array_aligned> o_buf; + array_aligned> o_accum_buf; + array_aligned> k[NUM_BUFS]; + } u; + array_aligned> s; + transac_bar_t bar_q; + transac_bar_t bar_k_ready[NUM_BUFS], bar_k_free[NUM_BUFS]; + transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS]; + float rowwise_max_buf[128], rowwise_li_buf[128]; + bool is_token_valid[NUM_BUFS][B_TOPK]; + array_aligned tmem_start_addr; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_SS_NOELECT{}, + Layout>{} +)); // TODO Use TS? + +using TiledMMA_SV = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_SS_NOELECT{}, + Layout>{}, + Tile, Int>{} +)); + +template +CUTE_DEVICE +void store_128b(void* smem_ptr, const T &data) { + static_assert(sizeof(T) == 16); + *(__int128*)smem_ptr = *(__int128*)&data; +} + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 1) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { +#if IS_SM100 + const int head_block_idx = blockIdx.x; + const int s_q_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup = threadIdx.x % 128; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); + + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + if (warp_idx == 0) { + if (elect_one_sync()) { + plan.bar_q.init(1); + for (int i = 0; i < NUM_BUFS; ++i) { + plan.bar_k_ready[i].init(128); + plan.bar_k_free[i].init(1); + plan.bar_qk_done[i].init(1); + plan.bar_so_ready[i].init(128); + } + cutlass::arch::fence_barrier_init(); + } + cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); + TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator1Sm().release_allocation_lock(); + } + __syncthreads(); + + int bar_phase_k = 0; + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int sched_begin_block_idx = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int sched_end_block_idx = tile_scheduler_metadata.w; + if (begin_idx >= params.b) { + if (warp_idx == 0) { + cute::TMEM::Allocator1Sm().free(0, 512); + } + return; + } + + auto get_cur_req_info = [&](int batch_idx) -> std::tuple { + int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; + int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : params.topk / B_TOPK; + bool is_no_split = start_block_idx == 0 && end_block_idx == params.topk / B_TOPK; + return {start_block_idx, end_block_idx, is_no_split}; + }; + + if (warpgroup_idx == 0) { + // Producer warpgroup + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1) + + constexpr int GROUP_SIZE = 4, NUM_GROUPS = 128 / GROUP_SIZE; + constexpr int ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; + int group_idx = idx_in_warpgroup / GROUP_SIZE; + int idx_in_group = idx_in_warpgroup % GROUP_SIZE; + + NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = block_idx % NUM_BUFS; + + // Wait for buffer to be available + plan.bar_k_free[buf_idx].wait(bar_phase_k>>buf_idx&1^1); + + // Load + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + + CUTE_UNROLL + for (int local_row = 0; local_row < ROWS_PER_GROUP; ++local_row) { + int smem_row = group_idx + local_row*NUM_GROUPS; + int token_index = __ldg(gIndices + block_idx*B_TOPK + smem_row); + bool is_token_invalid = token_index == -1; + if (idx_in_group == 0) + plan.is_token_valid[buf_idx][smem_row] = !is_token_invalid; + if (is_token_invalid) { + uint128_t zeros = uint128_t{}; + CUTE_UNROLL + for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) { + int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16; + store_128b(&sK(smem_row, col_base ), zeros); + store_128b(&sK(smem_row, col_base+8), zeros); + } + CUTE_UNROLL + for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) { + int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8; + store_128b(&sK(smem_row, D_V+col_base), zeros); + } + } else { + int block_index = token_index/B_TOPK; + int rel_idx_in_block = (token_index+B_TOPK) % B_TOPK; // NOTE When token_index is -1, -1/B_TOPK = 0 and (-1+B_TOPK)%B_TOPK = 63, so there will be no illegal-memory-access error. However, masking is necessary to prevent NaN (TODO Skip some rows instead?) TODO Masking + fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride; + float4 scales = __ldg((float4*)(gK_base + D_V)); + + CUTE_UNROLL + for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) { + int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16; + fp8x16 cur_fp8s = ldg_128_fp8x16(gK_base + col_base); + float cur_scale = local_col < (256/(GROUP_SIZE*16)) ? + (local_col < (128/(GROUP_SIZE*16)) ? scales.x : scales.y) : + (local_col < (384/(GROUP_SIZE*16)) ? scales.z : scales.w); + store_128b(&sK(smem_row, col_base ), cvt_fp8x8_bf16x8(cur_fp8s.a0, cur_scale)); + store_128b(&sK(smem_row, col_base+8), cvt_fp8x8_bf16x8(cur_fp8s.a1, cur_scale)); + } + + CUTE_UNROLL + for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) { + int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8; + fp8x16 cur_k_rope_fp8s = ldg_128_fp8x16(gK_base + D_V + 4*sizeof(float) + col_base*sizeof(bf16)); + bf16x8 cur_k_rope = *reinterpret_cast(&cur_k_rope_fp8s); + store_128b(&sK(smem_row, D_V+col_base), cur_k_rope); + } + } + } + + fence_view_async_shared(); + + // Signal + plan.bar_k_ready[buf_idx].arrive(); + + bar_phase_k ^= 1<(); + + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + + NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); + + float li = 0.0f; + float mi = MAX_INIT_VAL; + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = block_idx % NUM_BUFS; + + // Wait for P + plan.bar_qk_done[buf_idx].wait(bar_phase_k>>buf_idx&1); + tcgen05_after_thread_sync(); + + // Load P from TMEM + float p[B_TOPK/2]; + float2* p_float2 = reinterpret_cast(p); + tmem_ld_32dp32bNx(tmem_addr::p, p); + cutlass::arch::fence_view_async_tmem_load(); + + // Get rowwise max + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = 0; i < B_TOPK/2; ++i) { + if (!plan.is_token_valid[buf_idx][(idx_in_warpgroup/64)*(B_TOPK/2)+i]) p[i] = -INFINITY; + cur_max = max(cur_max, p[i]); + } + cur_max *= params.scale_softmax_log2; + + NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers + plan.rowwise_max_buf[idx_in_warpgroup] = cur_max; + NamedBarrier::arrive_and_wait(128, 0); + cur_max = max(cur_max, plan.rowwise_max_buf[idx_in_warpgroup ^ 64]); + + float new_max = max(mi, cur_max); + float scale_for_old = exp2f(mi - new_max); + float2 scale_for_old_float2 = {scale_for_old, scale_for_old}; + + // Get S + float2 scale_softmax_log2_float2 = {params.scale_softmax_log2, params.scale_softmax_log2}; + float2 neg_new_max_float2 = {-new_max, -new_max}; + bf16 s[B_TOPK/2]; + float2 cur_sum = {0.0f, 0.0f}; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; ++i) { + float2 t = float2_fma(p_float2[i], scale_softmax_log2_float2, neg_new_max_float2); + t.x = exp2(t.x); + t.y = exp2(t.y); + *(__nv_bfloat162*)&s[i*2] = __float22bfloat162_rn(t); + cur_sum = float2_add(cur_sum, t); + } + + // Save S + // NOTE We don't need a barrier here, since the current QK^T has finished implies that the previous SV has finished + bf16* sS_base = plan.s.data() + (idx_in_warpgroup/64)*(B_H*B_TOPK/2) + (idx_in_warpgroup%64) * 8; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/8; i += 1) { + store_128b(sS_base + i*8*B_H, *((bf16x8*)s + i)); + } + fence_view_async_shared(); + + // Rescale O + if (block_idx != start_block_idx) { + constexpr int B_SCALE_O = 64; + float2 o[B_SCALE_O/2]; + CUTE_UNROLL + for (int b = 0; b < (D_V/2)/B_SCALE_O; ++b) { + tmem_ld_32dp32bNx(tmem_addr::o + b*B_SCALE_O, o); + cutlass::arch::fence_view_async_tmem_load(); + CUTE_UNROLL + for (int i = 0; i < B_SCALE_O/2; ++i) + o[i] = float2_mul(o[i], scale_for_old_float2); + tmem_st_32dp32bNx(tmem_addr::o + b*B_SCALE_O, o); + cutlass::arch::fence_view_async_tmem_store(); + } + } + plan.bar_so_ready[buf_idx].arrive(); + + // Update mi and li + mi = new_max; + li = li * scale_for_old + cur_sum.x + cur_sum.y; + + bar_phase_k ^= 1<>((end_block_idx-1)%NUM_BUFS)&1^1); + tcgen05_after_thread_sync(); + + // Save O + float o_scale = li == 0.0f ? 0.0f : 1.0f / li; + float2 o_scale_float2 = {o_scale, o_scale}; + if (is_no_split) { + constexpr int B_EPI = 32; + float2 o[B_EPI/2]; + __nv_bfloat162 o_bf16[B_EPI/2]; + Tensor sO = make_tensor(make_smem_ptr(plan.u.o_buf.data()), SmemLayoutOBuf{}); + bf16* sO_base = plan.u.o_buf.data() + ((idx_in_warpgroup/64)*128)*B_H + (idx_in_warpgroup%64)*8; + CUTE_UNROLL + for (int i = 0; i < (D_V/2) / B_EPI; ++i) { + // Load + tmem_ld_32dp32bNx(tmem_addr::o + i*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + // Scale & Convert + CUTE_UNROLL + for (int j = 0; j < B_EPI/2; ++j) { + o[j] = float2_mul(o[j], o_scale_float2); + o_bf16[j] = __float22bfloat162_rn(o[j]); + } + // Store + int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4)); + CUTE_UNROLL + for (int j = 0; j < B_EPI / 8; ++j) + store_128b(sO_base + (col_base+j*8)*B_H, *reinterpret_cast(&o_bf16[j*4])); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, 0); + if (warp_idx == 4 && elect_one_sync()) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, head_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO), + thr_tma.partition_D(my_tma_gO) + ); + cute::tma_store_arrive(); + } + } else { + constexpr int B_EPI = 64; + float2 o[B_EPI/2]; + Tensor sO = make_tensor(make_smem_ptr(plan.u.o_accum_buf.data()), SmemLayoutOAccumBuf{}); + CUTE_UNROLL + for (int i = 0; i < (D_V/2) / B_EPI; ++i) { + // Load + tmem_ld_32dp32bNx(tmem_addr::o + i*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + // Scale & Convert + CUTE_UNROLL + for (int j = 0; j < B_EPI/2; ++j) + o[j] = float2_mul(o[j], o_scale_float2); + // Store + int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4)); + CUTE_UNROLL + for (int j = 0; j < B_EPI / 4; ++j) + store_128b(&sO(idx_in_warpgroup%64, col_base + j*4), *reinterpret_cast(&o[j*2])); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, 0); + if (elect_one_sync()) { + CUTE_UNROLL + for (int local_row = 0; local_row < B_H/4; ++local_row) { + int smem_row = local_row*4 + (warp_idx-4); + if (smem_row < num_valid_heads) { + SM90_BULK_COPY_S2G::copy( + &sO(smem_row, _0{}), + (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx + smem_row)*D_V, + D_V*sizeof(float) + ); + } + } + cute::tma_store_arrive(); + } + } + + cute::tma_store_wait<0>(); + } + + if (warp_idx == 4) { + cute::TMEM::Allocator1Sm().free(0, 512); + } + } else { + cutlass::arch::warpgroup_reg_dealloc<96>(); + if (warp_idx == 8) { + // UTCMMA warp + + bool bar_phase_q = 0; + TiledMMA tiled_mma_qk = TiledMMA_QK{}; + TiledMMA tiled_mma_sv = TiledMMA_SV{}; + Tensor tP = partition_fragment_C(tiled_mma_qk, Shape, Int>{}); + Tensor tO = partition_fragment_C(tiled_mma_sv, Shape, Int>{}); + tO.data().get() = tmem_addr::o; + tP.data().get() = tmem_addr::p; + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + + if (elect_one_sync()) { + // Copy Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16)); + } + + NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); + + if (elect_one_sync()) { + // Wait for Q + plan.bar_q.wait(bar_phase_q); + bar_phase_q ^= 1; + tcgen05_after_thread_sync(); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = block_idx % NUM_BUFS; + + // Wait for K + plan.bar_k_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + tcgen05_after_thread_sync(); + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + + // Issue P = Q @ K^T + utcmma_ss(tiled_mma_qk, sQ, sK, tP, true); + umma_arrive_noelect(plan.bar_qk_done[buf_idx]); + + // Wait for S + plan.bar_so_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + tcgen05_after_thread_sync(); + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutV{}); + + // Issue O += S @ V + utcmma_ss(tiled_mma_sv, sS, sV, tO, block_idx == start_block_idx); + umma_arrive_noelect(plan.bar_k_free[buf_idx]); + + bar_phase_k ^= 1< tma_params = { + shape_Q, tma_Q, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + const int num_m_blocks = cute::ceil_div(params.q_head_per_hk, B_H); + // NOTE Don't use PDL because of potential compiler bugs! + mla_kernel<<>>(params, tma_params); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} \ No newline at end of file diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.h b/csrc/sm100/decode/sparse_fp8/splitkv_mla.h new file mode 100644 index 0000000..cc8c6da --- /dev/null +++ b/csrc/sm100/decode/sparse_fp8/splitkv_mla.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm100 { + +void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream); + +} + diff --git a/csrc/sm100/defines.h b/csrc/sm100/defines.h new file mode 100644 index 0000000..0e779a3 --- /dev/null +++ b/csrc/sm100/defines.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +namespace sm100 { + +using bf16 = cutlass::bfloat16_t; +using fp8 = cutlass::float_e4m3_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::fence_barrier_init; +using cutlass::arch::NamedBarrier; + +struct int32x8_t { + int a0, a1, a2, a3, a4, a5, a6, a7; +}; + +struct float8 { + float2 a01, a23, a45, a67; +}; + +struct bf16x8 { + __nv_bfloat162 a01; + __nv_bfloat162 a23; + __nv_bfloat162 a45; + __nv_bfloat162 a67; +}; + +} diff --git a/csrc/sm100/helpers.h b/csrc/sm100/helpers.h new file mode 100644 index 0000000..9195b33 --- /dev/null +++ b/csrc/sm100/helpers.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include "defines.h" + +namespace sm100 { + +using namespace cute; + +using _72 = Int<72>; +using _576 = Int<576>; + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ss( + TiledMMA &tiled_mma, + TensorA sA, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sA_frag = thr_mma.partition_fragment_A(sA); + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + static_assert(size<1>(sA_frag) == size<1>(tC_frag)); + static_assert(size<1>(sB_frag) == size<2>(tC_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm( + tiled_mma, + sA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ts( + TiledMMA &tiled_mma, + TensorA tA_frag, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(tA_frag) == size<2>(sB_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tA_frag); ++k) { + cute::gemm( + tiled_mma, + tA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +} diff --git a/csrc/sm100/intrinsics.h b/csrc/sm100/intrinsics.h new file mode 100644 index 0000000..c2402ee --- /dev/null +++ b/csrc/sm100/intrinsics.h @@ -0,0 +1,461 @@ +#pragma once + +#include +#include + +#include "defines.h" + +namespace sm100 { + +using namespace cute; + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" + :: "r"(dst_addr), + "l"(src), + "n"(16)); +} + +CUTE_DEVICE +int64_t createpolicy_evict_last() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + +template +CUTE_DEVICE +static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { + static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b."); + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +CUTE_DEVICE +void umma_arrive_multicast_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_noelect(transac_bar_t &smem_ptr) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + : + :"r"(bar_intptr)); +} + +CUTE_DEVICE +void umma_arrive_2x1SM_noelect(transac_bar_t &smem_ptr) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];" + : + :"r"(bar_intptr)); +} + +CUTE_DEVICE +float2 float2_add(const float2 &a, const float2 &b) { + float2 res; + cute::add(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_mul(const float2 &a, const float2 &b) { + float2 res; + cute::mul(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { + // return a*b+c + float2 res; + cute::fma(res, a, b, c); + return res; +} + +CUTE_DEVICE +float2 float2_neg(const float2 &a) { + float2 t = {-1.0f, -1.0f}; + return float2_mul(a, t); +} + +template +CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + if constexpr (USE_CTA0_MBAR) { + mbar_addr &= Sm100MmaPeerBitMask; + } + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbar_addr), "l"(uint64_t(cache_hint)) + : "memory" + ); +} + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile ("trap"); + } +} + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* src_ptr = reinterpret_cast(src_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%1], {%0};\n" + : + : "r"(src_ptr[0]), + "r"(dst_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%2], {%0, %1};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), + "r"(dst_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%4], {%0, %1, %2, %3};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), + "r"(dst_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), + "r"(dst_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), + "r"(dst_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), + "r"(dst_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), + "r"(dst_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]), + "r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]), + "r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]), + "r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]), + "r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]), + "r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]), + "r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]), + "r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]), + "r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]), + "r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]), + "r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]), + "r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]), + "r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]), + "r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]), + "r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]), + "r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]), + "r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]), + "r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]), + "r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]), + "r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]), + "r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]), + "r"(src_ptr[126]), "r"(src_ptr[127]), + "r"(dst_addr)); + } else { + asm volatile ("trap"); + } +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + + +} diff --git a/csrc/sm100/prefill/sparse/fwd.cu b/csrc/sm100/prefill/sparse/fwd.cu new file mode 100644 index 0000000..963ac78 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd.cu @@ -0,0 +1,785 @@ +#include "fwd.h" + +#include +#include +#include +#include +#include +#include + +#include "params.h" +#include "utils.h" +#include "sm100/ws_gemm.h" +#include "sm100/helpers.h" +#include "sm100/intrinsics.h" +#include "sm100/tma_cta_group2_nosplit.h" + +namespace sm100 { + +using namespace cute; + +CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) { + int32x8_t val; + asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" + : "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3), + "=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7) + : "l"(src_ptr) + ); + return val; +} + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; + CUtensorMap tensor_map_kv; +}; + +struct float2x2 { + float2 lo, hi; +}; + +constexpr int D_Q = 576; +constexpr int D_K = 576; +constexpr int D_V = 512; +constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan + +constexpr int B_H = 128; // For 2 CTAs +constexpr int B_TOPK = 128; // For 2 CTAs +constexpr int NUM_BUFS = 2; +constexpr int NUM_THREADS = 256 + 128 + 128; // 128 TMA threads, 128 scale & exp threads, 32 UTCMMA threads + +constexpr int D_sQ = 256, NUM_sQ_TILES = D_sQ / 64; +constexpr int D_tQ = D_Q - D_sQ, NUM_tQ_TILES = D_tQ / 64; +static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q); + +// Tensor memory columns +namespace tmem_cols { + // 0 ~ 256: output + // 256 ~ 320: P + // 320 ~ 512: Q[192:576] + constexpr int o = 0; + constexpr int p = 256; + constexpr int q = 512 - D_tQ/2; + static_assert(p+64 <= q); +} + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutO = SmemLayoutOTiles<8>; + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutV = decltype(coalesce(tile_to_shape( + UMMA::Layout_MN_SW128_Atom{}, + Shape, Int>{}, + Step<_2, _1>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutSTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +struct SharedMemoryPlan { + union { + array_aligned>> q_full; + struct { + array_aligned>> sq; + array_aligned> v; + // NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q + array_aligned>> k; + } s; + array_aligned> o; + } u; + array_aligned>> s; + char is_k_valid[NUM_BUFS][B_TOPK/8]; + transac_bar_t bar_prologue_q, bar_prologue_utccp; + transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free) + transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. Vi free) + transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS]; + transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready + transac_bar_t bar_p_free[NUM_BUFS]; + transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready + transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS]; + array_aligned tmem_start_addr; + float rowwise_max_buf[128], rowwise_li_buf[128]; +}; + +using TiledMMA_P_tQ = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_TS_NOELECT{} +)); + +using TiledMMA_P_sQ = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{} +)); + +using TiledMMA_O = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{}, + Layout>{}, + Tile, Layout, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512] +)); + +/* +Pipeline Overview: + +| Copy | MMA | Scale & Exp | + +K0 +V0 + P0 = QK0^T +K1 S0 = exp(P0) + scale(O) w.r.t P0 + P1 = QK1^T +K2 S1 = exp(P1) + O += S0V0 +V1 scale(O) w.r.t P1 + P2 = QK2^T +K3 S2 = exp(P2) + O += S1V1 +V2 scale(O) w.r.t P2 + P3 = QK3^T +K4 S3 = exp(P3) + O += S2V2 +V3 scale(O) w.r.t P3 + +... + + O += S(n-3)V(n-3) +V(n-2) scale(O) w.r.t P(n-2) + P(n-1) = QK(n-1)^T + S(n-1) = exp(P(n-1)) + O += S(n-2)V(n-2) +V(n-1) scale(O) w.r.t P(n-1) + O += S(n-1)V(n-1) +*/ + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 2) +sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) { +#if IS_SM100 + const int cta_idx = blockIdx.x % 2; + const int s_q_idx = blockIdx.x / 2; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int lane_idx = threadIdx.x % 32; + const int num_k_blocks = params.topk / B_TOPK; + const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const int idx_in_warpgroup = threadIdx.x % 128; + + // Prefetch TMA descriptors + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv)); + } + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<9>{}); + + int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] + + // Allocate tmem tensors + TiledMMA tiled_mma_P_tQ = TiledMMA_P_tQ{}; + TiledMMA tiled_mma_P_sQ = TiledMMA_P_sQ{}; + TiledMMA tiled_mma_O = TiledMMA_O{}; + Tensor tP = partition_fragment_C(tiled_mma_P_tQ, Shape, Int>{}); + Tensor tQr = tiled_mma_P_tQ.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P_tQ, Shape, Int>{}) + ); + Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); + tP.data().get() = tmem_cols::p; + tQr.data().get() = tmem_cols::q; + tO.data().get() = tmem_cols::o; + + if (warp_idx == 0) { + if (elect_one_sync()) { + // Initialize barriers + plan.bar_prologue_q.init(1); + plan.bar_prologue_utccp.init(1); + CUTE_UNROLL + for (int i = 0; i < NUM_BUFS; ++i) { + plan.bar_qk_part_done[i].init(1); + plan.bar_qk_done[i].init(1); + plan.bar_sv_part_done[i].init(1); + plan.bar_sv_done[i].init(1); + plan.bar_k_part0_ready[i].init(1); + plan.bar_k_part1_ready[i].init(1); + plan.bar_v_part0_ready[i].init(1); + plan.bar_v_part1_ready[i].init(1); + plan.bar_p_free[i].init(128*2); + plan.bar_so_ready[i].init(128*2); + plan.bar_k_valid_ready[i].init(16); + plan.bar_k_valid_free[i].init(128); + } + fence_barrier_init(); + } + } + + cute::cluster_sync(); // We must add a cluster_sync() here, or TMA from CTA1 may launch before barrier initialization in CTA0 + + if (warp_idx == 0) { + if (elect_one_sync()) { + // Copy Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), + Tile>{} + )(_, cta_idx, _); + launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST); + } + + // Initialize TMEM + // We put this before cluster_arrive to make sure that the TMEM allocation is done before UTCCP + cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data()); + TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator2Sm().release_allocation_lock(); + __syncwarp(); + } + + if (warpgroup_idx == 0) { + cutlass::arch::warpgroup_reg_alloc<144>(); + // Scale & Exp warps + + // The following three numbers are + // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V) + // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi)) + // - real_mi: real max logits, i.e. real_mi := max(Pi*scale) + // where Pi is the i-th row of P, P := QK^T + // mi and real_mi are always consistent within the two threads that + // controls one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update + float mi = MAX_INIT_VAL; + float li = 0.0f; + float real_mi = -CUDART_INF_F; + + const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2}; + uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8); + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + // Wait for P + plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + // Load P + float2 p[(B_TOPK/2)/2]; + tmem_ld_32dp32bNx(tmem_cols::p, p); + cutlass::arch::fence_view_async_tmem_load(); + tcgen05_before_thread_sync(); + plan.bar_p_free[k%NUM_BUFS].arrive(0u); + + // Mask + plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1); + // The following code enables NVCC to use R2P instruction + // Although we perform 2x LDS.32 instructions here, don't worry, NVCC will + // convert them to one LDS.64 instruction. However, if we write LDS.64 + // here, NVCC won't use R2P. + uint32_t is_k_valid_lo = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0)); + uint32_t is_k_valid_hi = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0) + 4); + float* p_float = (float*)p; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + if (!(is_k_valid_lo >> i & 1)) + p_float[i] = -CUDART_INF_F; + } + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + if (!(is_k_valid_hi >> i & 1)) + p_float[i+(B_TOPK/2)/2] = -CUDART_INF_F; + } + + // Get rowwise max of Pi + float cur_pi_max = -CUDART_INF_F; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2); i += 1) { + cur_pi_max = max(cur_pi_max, p_float[i]); + } + cur_pi_max *= params.sm_scale_div_log2; + + plan.bar_k_valid_free[k%NUM_BUFS].arrive(); + + NamedBarrier::arrive_and_wait(128, 0); // Wait for rowwise_max_buf and sP to be ready + plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; + NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers + cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]); + real_mi = max(real_mi, cur_pi_max); + bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); + // By this point: + // - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...) + // - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127 + + // Calc scale factor, and scale li + float new_max, scale_for_old; + if (!should_scale_o) { + // Don't scale O + scale_for_old = 1.0f; + new_max = mi; + } else { + new_max = max(cur_pi_max, mi); + scale_for_old = exp2f(mi - new_max); + } + mi = new_max; // mi is still identical within each row + li *= scale_for_old; + + // Calculate S + __nv_bfloat162 s[(B_TOPK/2)/2]; + float2 neg_new_max = float2 {-new_max, -new_max}; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + float2 d = float2_fma(p[i], scale, neg_new_max); + d.x = exp2f(d.x); + d.y = exp2f(d.y); + li += d.x + d.y; // NOTE Theorically we can have use FFMA2 here but actually this is faster... + s[i] = __float22bfloat162_rn(d); + } + + // Wait for last SV gemm, write S + if (k > 0) { + plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + CUTE_UNROLL + for (int i = 0; i < B_TOPK/2/8; i += 1) { + sS_base[64*i] = *(uint128_t*)(s + i*4); + } + + // Scale O + if (k > 0 && should_scale_o) { + float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; + // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before + tcgen05_after_thread_sync(); + + static constexpr int CHUNK_SIZE = 32; + float2 o[CHUNK_SIZE/2]; + CUTE_UNROLL + for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { + // Load O + tmem_ld_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_load(); + + // Mult + for (int i = 0; i < CHUNK_SIZE/2; ++i) { + o[i] = float2_mul(o[i], scale_for_old_float2); + } + + // Store O + tmem_st_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_store(); + } + tcgen05_before_thread_sync(); + } + + fence_view_async_shared(); + plan.bar_so_ready[k%NUM_BUFS].arrive(0u); + } + + // Epilogue + + if (real_mi == -CUDART_INF_F) { + // real_mi == -CUDART_INF_F <=> No valid TopK indices + // We set li to 0 to fit the definition that li := exp(x[i] - mi) + li = 0.0f; + mi = -CUDART_INF_F; + } + + // Exchange li + plan.rowwise_li_buf[idx_in_warpgroup] = li; + NamedBarrier::arrive_and_wait(128, 0); + li += plan.rowwise_li_buf[idx_in_warpgroup^64]; + + // Store mi and li + if (idx_in_warpgroup < 64) { + int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup; + float cur_lse = log2f(li) + mi; + params.max_logits[global_index] = real_mi; + params.lse[global_index] = cur_lse; + } + + // Wait for the last GEMM + plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + // Store O + float output_scale = __fdividef(1.0f, li); + Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{}); + constexpr int B_EPI = 64; + Tensor tma_gO = flat_divide( + tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx), + Shape, Int>{} + )(_, _, cta_idx, _); + Tensor sO_divided = flat_divide( + sO, + Shape, Int>{} + )(_, _, _0{}, _); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + + float2 o[B_EPI/2]; + bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during tmem_ld + if (!have_valid_indices) { + // If there are no valid indices, we set o[i] to 0 and don't load from TMEM + CUTE_UNROLL + for (int i = 0; i < B_EPI/2; ++i) + o[i].x = o[i].y = 0.0f; + output_scale = 1.0f; + } + + float2 output_scale_float2 = make_float2(output_scale, output_scale); + + CUTE_UNROLL + for (int k = 0; k < (D_V/2)/B_EPI; ++k) { + // Load O from tO + if (have_valid_indices) { + tmem_ld_32dp32bNx(tmem_cols::o + k*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + } + + // Convert and store + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) { + __nv_bfloat162 o_bf16[4]; + CUTE_UNROLL + for (int j = 0; j < 4; ++j) { + float2 d = float2_mul(o[i*4+j], output_scale_float2); + o_bf16[j] = __float22bfloat162_rn(d); + } + int smem_row = idx_in_warpgroup % 64; + int smem_col = (idx_in_warpgroup/64)*(D_V/2) + k*B_EPI + i*8; + *(uint128_t*)(&sO(smem_row, smem_col)) = *(uint128_t*)(o_bf16); + } + + // Sync + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, 0); + + if (warp_idx == 0 && elect_one_sync()) { + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO_divided(_, _, k)), + thr_tma.partition_D(tma_gO(_, _, k)) + ); + } + if (warp_idx == 1 && elect_one_sync()) { + int k2 = k + (D_V/B_EPI/2); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO_divided(_, _, k2)), + thr_tma.partition_D(tma_gO(_, _, k2)) + ); + } + } + + if (warp_idx == 0) { + cute::TMEM::Allocator2Sm().free(0, 512); + } + } else if (warpgroup_idx == 1) { + // Producer warp for K + cutlass::arch::warpgroup_reg_dealloc<96>(); + int warp_idx = cutlass::canonical_warp_idx_sync() - 4; + constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/2)/4/NUM_WARPS; + if (elect_one_sync()) { + bf16* sK_base = plan.u.s.k.data() + warp_idx*4*64; + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + int4 indices[NUM_LOCAL_ROWS_PER_WARP]; + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) + indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx); + + auto load_part_ki = [&](transac_bar_t* bar, int local_col_start, int local_col_end) { + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { + CUTE_UNROLL + for (int local_col = local_col_start; local_col < local_col_end; ++local_col) + tma_gather4( + &(tma_params.tensor_map_kv), + bar, + sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64), + local_col*64, + indices[local_row], + TMA::CacheHintSm90::EVICT_LAST + ); + } + }; + + int cur_buf = k%NUM_BUFS; + if (k > 0) { + plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_ki(plan.bar_k_part0_ready+cur_buf, 0, D_sQ/64); + + if (k > 0) { + plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_ki(plan.bar_k_part1_ready+cur_buf, D_sQ/64, D_K/64); + } + } + } else if (warpgroup_idx == 2) { + // Producer warps for V + cutlass::arch::warpgroup_reg_dealloc<96>(); + int warp_idx = cutlass::canonical_warp_idx_sync() - 8; + constexpr int NUM_WARPS = 4; + + if (elect_one_sync()) { + // Wait for UTCCP + plan.bar_prologue_utccp.wait(0); + + bf16* sV_base = plan.u.s.v.data() + warp_idx*4*64; + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + auto load_part_vi = [&](transac_bar_t* bar, int local_row_start, int local_row_end) { + CUTE_UNROLL + for (int local_row = local_row_start; local_row < local_row_end; ++local_row) { + int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx); + CUTE_UNROLL + for (int local_col = 0; local_col < (D_V/2)/64; ++local_col) + tma_gather4( + &(tma_params.tensor_map_kv), + bar, + sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64), + local_col*64 + (cta_idx?256:0), + token_idxs, + TMA::CacheHintSm90::EVICT_LAST + ); + } + }; + + int cur_buf = k%NUM_BUFS; + if (k > 0) { + plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_vi(plan.bar_v_part0_ready+cur_buf, 0, (B_TOPK/2)/4/NUM_WARPS); + + if (k > 0) { + plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_vi(plan.bar_v_part1_ready+cur_buf, (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS); + } + } + } else { + cutlass::arch::warpgroup_reg_alloc<168>(); + + // MMA warp + if (cta_idx == 0 && warp_idx == 12 && elect_one_sync()) { + // S -> T copy for Q + UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(plan.u.q_full.data() + (B_H/2)*D_sQ), + tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64>>{} + ) + ) + ); + plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16)); + plan.bar_prologue_q.wait(0); + tcgen05_after_thread_sync(); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) { + // A tile is 64 rows * 64 cols (128B) + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 8; ++subtile_idx) { + // A subtile is 64 rows * 8 cols (128b) + SM100_UTCCP_2x64dp128bitlw0213_2cta::copy( + sQ_desc + tile_idx*((B_H/2)*128/16) + subtile_idx*(16/16), // Remember that 4 LSBs are not included + tmem_cols::q + tile_idx*32 + subtile_idx*4 + ); + } + } + umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2); + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks+1; ++k) { + if (k < num_k_blocks) { + // Pi = QKi^T + int cur_buf = k%NUM_BUFS; + Tensor sQl = make_tensor(make_smem_ptr(plan.u.s.sq.data()), SmemLayoutQTiles{}); + Tensor sKl = make_tensor(make_smem_ptr(plan.u.s.k.data()), SmemLayoutKTiles{}); + Tensor sKr = make_tensor(make_smem_ptr(plan.u.s.k.data()+64*D_sQ), SmemLayoutKTiles{}); + + // Wait for K (part0) + plan.bar_k_part0_ready[cur_buf].arrive_and_expect_tx(B_TOPK*D_sQ*sizeof(bf16)); + plan.bar_k_part0_ready[cur_buf].wait((k/NUM_BUFS)&1); + if (k > 0) { + plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + tcgen05_after_thread_sync(); + + utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true); + umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2); + + // Wait for K (part1) + plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16)); + plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false); + umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2); + } + if (k > 0) { + // O += S(i-1)V(i-1) + int cur_buf = (k-1)%NUM_BUFS; + + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutSTiles<2>{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.s.v.data()), SmemLayoutV{}); + Tensor sS_divided = flat_divide(sS, Tile, _64>{})(_, _, _0{}, _); // (B_H/2, 64, 2) + Tensor sV_divided = flat_divide(sV, Tile, _64>{})(_, _, _0{}, _); // (D_V/2, 64, 2) + + // Wait for S(i-1) and O to be scaled + plan.bar_so_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); + + // Wait for V (part0), and issue O += sS @ sV + plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16)); + plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1); + umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2); + + // Wait for V (part1), and issue O += sS @ sV + plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16)); + plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false); + umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2); + } + } + } else if (warp_idx == 13) { + // KV valid loading warp + static_assert(B_TOPK == 128); + if (lane_idx < 16) { + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + int cur_buf = k%NUM_BUFS; + int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8); + auto is_valid = [&](int index) -> char { + return index >= 0 && index < params.s_kv; + }; + char is_ks_valid_mask = \ + is_valid(indices.a7) << 7 | + is_valid(indices.a6) << 6 | + is_valid(indices.a5) << 5 | + is_valid(indices.a4) << 4 | + is_valid(indices.a3) << 3 | + is_valid(indices.a2) << 2 | + is_valid(indices.a1) << 1 | + is_valid(indices.a0) << 0; + + plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1); + plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask; + plan.bar_k_valid_ready[cur_buf].arrive(); + } + } + } + } + +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100"); + } +#endif +} + +void run_fwd_kernel(const SparsePrefillParams& params) { + FLASH_ASSERT(params.h_kv == 1); + FLASH_ASSERT(params.topk % B_TOPK == 0); // To save some boundry checkings + FLASH_ASSERT(params.h_q == B_H); // To save some calculation + + auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); + auto tma_Q = cute::make_tma_copy( + SM100_TMA_2SM_LOAD_NOSPLIT{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQTiles<9>{} + ); + + auto shape_O = make_shape(params.h_q, params.d_v, params.s_q); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((bf16*)params.out), + make_layout( + shape_O, + make_stride(params.d_v, _1{}, params.h_q*params.d_v) + ) + ), + SmemLayoutOTiles<1>{} + ); + + CUtensorMap tensor_map_kv; + { + uint64_t size[2] = {D_K, (unsigned long)params.s_kv}; + uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)}; + uint32_t box_size[2] = {64, 1}; + uint32_t elem_stride[2] = {1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_kv, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 2, + params.kv, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + FLASH_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q), decltype(tma_Q), + decltype(shape_O), decltype(tma_O) + > tma_params = { + shape_Q, tma_Q, + shape_O, tma_O, + tensor_map_kv + }; + auto kernel = &sparse_attn_fwd_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams launch_params = { + dim3(2*params.s_q, 1, 1), + dim3(NUM_THREADS, 1, 1), + dim3(2, 1, 1), + smem_size, + params.stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)kernel, params, tma_params + ); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} diff --git a/csrc/sm100/prefill/sparse/fwd.h b/csrc/sm100/prefill/sparse/fwd.h new file mode 100644 index 0000000..6558e80 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace sm100 { + +void run_fwd_kernel(const SparsePrefillParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/helpers.h b/csrc/sm100/prefill/sparse/helpers.h new file mode 100644 index 0000000..991b40d --- /dev/null +++ b/csrc/sm100/prefill/sparse/helpers.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include "sm100/defines.h" + +namespace sm100 { + +using namespace cute; + +using _72 = Int<72>; +using _576 = Int<576>; + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma( + TiledMMA &tiled_mma, + TensorA sA, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sA_frag = thr_mma.partition_fragment_A(sA); + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + static_assert(size<1>(sA_frag) == size<1>(tC_frag)); + static_assert(size<1>(sB_frag) == size<2>(tC_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm( + tiled_mma, + sA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ts( + TiledMMA &tiled_mma, + TensorA tA_frag, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(tA_frag) == size<2>(sB_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tA_frag); ++k) { + cute::gemm( + tiled_mma, + tA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +struct bf16x8 { + __nv_bfloat162 a01; + __nv_bfloat162 a23; + __nv_bfloat162 a45; + __nv_bfloat162 a67; +}; + +} diff --git a/csrc/sm100/prefill/sparse/intrinsics.h b/csrc/sm100/prefill/sparse/intrinsics.h new file mode 100644 index 0000000..85a8203 --- /dev/null +++ b/csrc/sm100/prefill/sparse/intrinsics.h @@ -0,0 +1,638 @@ +#pragma once + +#include +#include "defines.h" + +namespace sm100 { + +using namespace cute; + +struct int32x8_t { + int a0, a1, a2, a3, a4, a5, a6, a7; +}; + +struct float8 { + float2 a01, a23, a45, a67; +}; + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" + :: "r"(dst_addr), + "l"(src), + "n"(16)); +} + +template +CUTE_DEVICE +static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { + static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b."); + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + +CUTE_DEVICE +void umma_arrive_multicast_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_multicast_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) { + umma_arrive_multicast_noelect((uint64_t*)smem_ptr, cta_mask); +} + +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) { + umma_arrive_multicast_2x1SM_noelect((uint64_t*)smem_ptr, cta_mask); +} + +CUTE_DEVICE +int64_t createpolicy_evict_last() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + +CUTE_DEVICE +void atomicadd_f32x4_with_policy(void* global_addr, const float4 &data, int64_t cache_policy) { + asm volatile( + "red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t" + : + : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w), + "l"((int64_t)global_addr), "l"(cache_policy) + ); +} + +CUTE_DEVICE +void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" + : + : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +} + +CUTE_DEVICE +float2 float2_add(const float2 &a, const float2 &b) { + float2 res; + cute::add(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_mul(const float2 &a, const float2 &b) { + float2 res; + cute::mul(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { + // return a*b+c + float2 res; + cute::fma(res, a, b, c); + return res; +} + +CUTE_DEVICE +float2 float2_neg(const float2 &a) { + float2 t = {-1.0f, -1.0f}; + return float2_mul(a, t); +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +template +CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + if constexpr (USE_CTA0_MBAR) { + mbar_addr &= Sm100MmaPeerBitMask; + } + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbar_addr), "l"(uint64_t(cache_hint)) + : "memory" + ); +} + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile ("trap"); + } +} + +// 16 data path lanes, 256-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_ld_16dp256bNx(uint32_t const &src_addr, T* dst_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } +} + + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* src_ptr = reinterpret_cast(src_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%1], {%0};\n" + : + : "r"(src_ptr[0]), + "r"(dst_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%2], {%0, %1};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), + "r"(dst_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%4], {%0, %1, %2, %3};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), + "r"(dst_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), + "r"(dst_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), + "r"(dst_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), + "r"(dst_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), + "r"(dst_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]), + "r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]), + "r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]), + "r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]), + "r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]), + "r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]), + "r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]), + "r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]), + "r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]), + "r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]), + "r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]), + "r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]), + "r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]), + "r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]), + "r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]), + "r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]), + "r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]), + "r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]), + "r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]), + "r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]), + "r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]), + "r"(src_ptr[126]), "r"(src_ptr[127]), + "r"(dst_addr)); + } else { + asm volatile ("trap"); + } +} + + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + +} diff --git a/csrc/sm100/prefill/sparse/ws_gemm.h b/csrc/sm100/prefill/sparse/ws_gemm.h new file mode 100644 index 0000000..78c9005 --- /dev/null +++ b/csrc/sm100/prefill/sparse/ws_gemm.h @@ -0,0 +1,328 @@ +#pragma once + +#include + +namespace cute { + +// Extensions to CuTe +// CuTe don't support UTCMMA with .ws, so we add it here + +template +struct SM100_MMA_F16BF16_WS_SS_NOELECT +{ + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 32, 64 or 128"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + + +// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync() +template +struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +// template +// struct MMA_Traits> : MMA_Traits> {}; +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +} \ No newline at end of file diff --git a/csrc/sm100/tma_cta_group2_nosplit.h b/csrc/sm100/tma_cta_group2_nosplit.h new file mode 100644 index 0000000..12e65b5 --- /dev/null +++ b/csrc/sm100/tma_cta_group2_nosplit.h @@ -0,0 +1,281 @@ +#pragma once + +#include + +namespace cute { + +// Extensions to CuTe +// CuTe 自带的 SM100_TMA_2SM_LOAD_1D 系列要求参与的 thread 数量为 2(using ThrID = Layout<_2>;),还会对数据进行切分,用起来太恶心了,所以我们自己改一版。另外,为了和其他使用 SM90 TMA 的部分统一,这里我们让它接受 TMA::CacheHintSm90 而不是 TMA::CacheHintSm100 + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_1D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_2D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_3D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_4D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_5D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + + + +struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {}; + +// The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_NOSPLIT arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; + } + + // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_NOSPLIT arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} +}; + + +} diff --git a/csrc/sm100/ws_gemm.h b/csrc/sm100/ws_gemm.h new file mode 100644 index 0000000..54edd3d --- /dev/null +++ b/csrc/sm100/ws_gemm.h @@ -0,0 +1,426 @@ +#pragma once + +#include + +namespace cute { + +// Extensions to CuTe +// CuTe don't support UTCMMA with .ws, so we add it here + +template +struct SM100_MMA_F16BF16_WS_SS_NOELECT +{ + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 32, 64 or 128"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +using namespace cute; +template +struct SM100_MMA_F16BF16_WS_TS_NOELECT +{ + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 32, 64 or 128"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + + +// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync() +template +struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +// template +// struct MMA_Traits> : MMA_Traits> {}; +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +} // namespace cute diff --git a/setup.py b/setup.py index 338117f..15fa671 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ BuildExtension, CUDAExtension, IS_WINDOWS, + CUDA_HOME ) @@ -22,8 +23,21 @@ def get_features_args(): return features_args def get_arch_flags(): + # Check NVCC Version + # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py` + assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support" + nvcc_version = subprocess.check_output( + [os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT + ).decode('utf-8') + nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip() + major, minor = map(int, nvcc_version_number.split('.')) + print(f'Compiling using NVCC {major}.{minor}') + DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") + if major < 12 or (major == 12 and minor <= 8): + assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." + arch_flags = [] if not DISABLE_SM100: arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) @@ -55,8 +69,10 @@ def get_nvcc_thread_args(): "csrc/sm90/decode/dense/splitkv_mla.cu", "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu", "csrc/sm90/prefill/sparse/fwd.cu", + "csrc/sm100/decode/sparse_fp8/splitkv_mla.cu", "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", + "csrc/sm100/prefill/sparse/fwd.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_decoding.py index 64ddf72..d6c1f81 100644 --- a/tests/test_flash_mla_decoding.py +++ b/tests/test_flash_mla_decoding.py @@ -319,6 +319,11 @@ def main(torch_dtype): ] testcases = correctness_cases + corner_cases + performance_cases + + # Prune out unsupported cases + cc_major, cc_minor = torch.cuda.get_device_capability() + if cc_major == 10: + testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] for testcase in testcases: test_flash_mla(testcase) From 472477e875b746c731eb63669bd81b7def9679db Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Mon, 29 Sep 2025 18:23:20 +0800 Subject: [PATCH 15/20] Add Deep-Dive Blog for the New Sparse Decoding Kernel on Hopper (#100) --- csrc/sm100/tma_cta_group2_nosplit.h | 2 +- docs/250929-hopper-fp8-sparse-deep-dive.md | 52 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 docs/250929-hopper-fp8-sparse-deep-dive.md diff --git a/csrc/sm100/tma_cta_group2_nosplit.h b/csrc/sm100/tma_cta_group2_nosplit.h index 12e65b5..045456d 100644 --- a/csrc/sm100/tma_cta_group2_nosplit.h +++ b/csrc/sm100/tma_cta_group2_nosplit.h @@ -5,7 +5,7 @@ namespace cute { // Extensions to CuTe -// CuTe 自带的 SM100_TMA_2SM_LOAD_1D 系列要求参与的 thread 数量为 2(using ThrID = Layout<_2>;),还会对数据进行切分,用起来太恶心了,所以我们自己改一版。另外,为了和其他使用 SM90 TMA 的部分统一,这里我们让它接受 TMA::CacheHintSm90 而不是 TMA::CacheHintSm100 +// CuTe's SM100_TMA_2SM_LOAD_1D requires two threads to perform this operation cooperatively (using ThrID = Layout<_2>;), which doesn't fit our use case. //////////////////////////////////////////////////////////////////////////////////////////////////// /// TMA_LOAD : Initiates a TMA copy from global memory to shared memory diff --git a/docs/250929-hopper-fp8-sparse-deep-dive.md b/docs/250929-hopper-fp8-sparse-deep-dive.md new file mode 100644 index 0000000..cd71346 --- /dev/null +++ b/docs/250929-hopper-fp8-sparse-deep-dive.md @@ -0,0 +1,52 @@ +# A Deep Dive Into The Flash MLA FP8 Decoding Kernel on Hopper + +With the release of DeepSeek-V3.2, we have doubled the context length of our models from 64K tokens to 128K tokens. This puts significant pressure on GPU memory (a single request with 128K tokens requires a KVCache of size $576 \times 2 \times 62 \times 128 \times 1024 = 8.72\ \mathrm{GiB}$), which can lead to out-of-memory (OOM) errors or under-utilized GPUs due to small batch sizes. To address this, we introduced FP8 KVCache for DeepSeek-V3.2. + +However, writing a high-performance decoding kernel is challenging due to the need for dequantization and its sparse memory access patterns. In this blog, we share the story behind our new FP8 sparse decoding kernel for Hopper GPUs. We will first explain our FP8 KVCache format, then provide a theoretical analysis of clock cycles, and finally detail the techniques used in our new kernel. + +## The FP8 KVCache Format + +Recall that the decoding phase of the Multi-head Latent Attention (MLA) algorithm operates similarly to Multi-Query Attention (MQA), with 128 query heads and 1 key head, where `head_dim_k = 576` and `head_dim_v = 512` respectively. To reduce the size of the KVCache while maintaining accuracy, we use a fine-grained quantization method. Specifically, we apply tile-level quantization (with a tile size of $1 \times 128$) to the first 512 elements in each token's KV Cache. This results in 512 `float8_e4m3` values and 4 `float32` scale factors. For the remaining 64 elements (the RoPE part), we do not apply quantization as they are sensitive to precision loss. Therefore, in GPU memory, each token's KVCache occupies 656 bytes, consisting of 512 `float8_e4m3`s, 4 `float32`s, and 64 `bfloat16`s. + +Inside the kernel, we first dequantize the 512 `float8_e4m3` values into 512 `bfloat16`s. We then concatenate them with the 64 original `bfloat16` values from the RoPE part. Finally, we perform the MQA calculation using matrix multiplication-add (MMA) operations in `bfloat16` precision (i.e., the inputs to the MMAs are in `bfloat16` and the outputs are in `float32`. This applies to both the QK gemm and the attention-score-V gemm). + +## Theoretical Analysis of Clock Cycles + +The main challenge is that Tensor Cores (which handle MMA calculations) are extremely fast, while the dequantization process, performed on CUDA Cores, struggles to keep up. + +The basic unit on an NVIDIA GPU is the Stream Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. + +However, because the H800 cannot directly cast `float8_e4m3` to `bfloat16`, dequantizing the KVCache for one token requires the following steps: +1. Convert `float8_e4m3` to `half` +2. Convert `half` to `float32` +3. Convert `float32` to `bfloat16` +4. Multiply the converted `bfloat16` value by the `float32` scale factor + +According to [NVIDIA's documentation](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#throughput-of-native-arithmetic-instructions), we need at least $(\frac{1}{64} + \frac{1}{64} + \frac{1}{16} + \frac{1}{256}) \times 512 \approx 50$ cycles for dequantizing each token! This is significantly more than the 34 cycles required for the MMA operations, meaning the kernel is **dequantization-bound**. If left unaddressed, dequantization would become the performance bottleneck, leaving the powerful Tensor Cores underutilized. + +## Crossover + +Before we continue, it's important to note a key fact: every query head within the same query token attends to the same key heads, because this is Multi-Query Attention (MQA). + +Recall that each CTA processes 64 query heads, while DeepSeek-V3.2 has a total of 128 query heads. If we can find a way to "share" the dequantized K/V values between two CTAs that are processing different sets of query heads, then each CTA would only need to dequantize **half** of the KV cache – which is fantastic! We call this method "crossover", since the idea was actually inspired by [Chromosomal crossover](https://en.wikipedia.org/wiki/Chromosomal_crossover) during [Meiosis](https://en.wikipedia.org/wiki/Meiosis). + +The next question is, how do we implement this in CUDA? Before NVIDIA's Hopper architecture, the only options for data exchange between CTAs were global memory or the L2 cache, which are slow. However, the powerful Distributed Shared Memory gave us a new solution. + +## Distributed Shared Memory to the Rescue + +Distributed Shared Memory (DSM) is a new feature introduced with the Hopper architecture, alongside the CTA Cluster (thread block cluster). CTAs within the same cluster can directly access each other's shared memory. For more details, you can refer to [NVIDIA Hopper Architecture In-Depth](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). + +Here is how we use it: We launch CTAs in clusters of size 2. Each CTA within a cluster is responsible for 64 query heads from the same query token. Each CTA performs the following steps: +1. Loads *half* of the quantized K/V from global memory. We use a wide `__ldg` load with a width of 128 bits to improve performance. +2. Dequantizes its assigned half on the CUDA Cores. +3. Stores the dequantized K/V into its own shared memory. +4. Simultaneously uses `st.async` to write the dequantized K/V into the shared memory of the other CTA in the cluster. + +For synchronization between these operations, we rely on the [cluster transaction barrier](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/), another powerful programming primitive available in CTA Clusters. After the data exchange is complete, each CTA has the *full* set of dequantized K and V values available in its own shared memory, which it can then use to perform the MMA operations. + +## Performance +Using these techniques, we achieved 410 TFLOPS in a compute-bound configuration (batch_size=128, num_heads=128, s_q=2, topk=2048) on H800 SXM5 GPUs. This is a significant improvement over the 250 TFLOPS achieved by our previous FP8 sparse decoding kernel without the crossover technique. + +Although this number is still below the 640 TFLOPS peak of our previous bfloat16 dense decoding kernel, one reason is that it's a **sparse** kernel, and its topk is only 2048. With a smaller topk, the relative overhead of the kernel's prologue and epilogue becomes larger compared with dense decoding with long context length. If we set topk to a larger value, such as 32768, this kernel can achieve up to 460 TFLOPS. + +From another perspective, the execution time of this kernel in the configuration mentioned above is comparable to that of the dense decoding kernel when the sequence length is around 3000. When the sequence length exceeds 3000, the performance advantage of our new kernel becomes even more significant. This also highlights the effectiveness of our DeepSeek Sparse Attention algorithm. From 42f3c5789db65b5ff1eadea0fe4ce3805483a8e8 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Mon, 29 Sep 2025 18:29:18 +0800 Subject: [PATCH 16/20] Rename deep dive blog --- ...parse-deep-dive.md => 20250929-hopper-fp8-sparse-deep-dive.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{250929-hopper-fp8-sparse-deep-dive.md => 20250929-hopper-fp8-sparse-deep-dive.md} (100%) diff --git a/docs/250929-hopper-fp8-sparse-deep-dive.md b/docs/20250929-hopper-fp8-sparse-deep-dive.md similarity index 100% rename from docs/250929-hopper-fp8-sparse-deep-dive.md rename to docs/20250929-hopper-fp8-sparse-deep-dive.md From e9b67321b17e53b3743cbe0e180973a943c4b217 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Tue, 30 Sep 2025 18:21:54 +0800 Subject: [PATCH 17/20] Update blog and README --- README.md | 2 +- docs/20250929-hopper-fp8-sparse-deep-dive.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 354cdde..df021de 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee ## News -- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. +- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md). - **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). - **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 diff --git a/docs/20250929-hopper-fp8-sparse-deep-dive.md b/docs/20250929-hopper-fp8-sparse-deep-dive.md index cd71346..cf3c166 100644 --- a/docs/20250929-hopper-fp8-sparse-deep-dive.md +++ b/docs/20250929-hopper-fp8-sparse-deep-dive.md @@ -14,7 +14,7 @@ Inside the kernel, we first dequantize the 512 `float8_e4m3` values into 512 `bf The main challenge is that Tensor Cores (which handle MMA calculations) are extremely fast, while the dequantization process, performed on CUDA Cores, struggles to keep up. -The basic unit on an NVIDIA GPU is the Stream Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. +The basic unit on an NVIDIA GPU is the Stream Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). In our kernel, each CTA runs on one SM, and each SM is only mapped to one CTA. If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. However, because the H800 cannot directly cast `float8_e4m3` to `bfloat16`, dequantizing the KVCache for one token requires the following steps: 1. Convert `float8_e4m3` to `half` From 7f55c7151acfeaacfd610c022aaa26f836c9fac1 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Tue, 30 Sep 2025 23:29:18 +0800 Subject: [PATCH 18/20] Fix error message --- csrc/pybind.cpp | 6 +++--- .../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp | 2 +- .../sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 6ec3f21..13541d4 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -41,7 +41,7 @@ struct Arch { } }; -// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. Hopper Dense BF16, Hopper Sparse FP8, etc.) +// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. SM90 Dense BF16, SM90 Sparse FP8, etc.) struct DecodingAttnImplMeta { int num_sm_parts; int fixed_overhead_num_blocks; @@ -334,7 +334,7 @@ fwd_kvcache_mla( TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90"); sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); } else { - TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + TORCH_CHECK(false, "Only FP8 kvcahe is supported for sparse MLA on SM90"); } } else { if (is_fp8) { @@ -347,7 +347,7 @@ fwd_kvcache_mla( sm90::run_flash_splitkv_mla_kernel(params, stream); #endif } else { - TORCH_CHECK(false, "Unsupported tensor dtype for query"); + TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); } } } diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 057b45e..c34713b 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -949,7 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TensorC const& coord, TensorShape const& tensor_shape) { - //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index 0d4af85..c25d638 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -953,7 +953,8 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { TensorR const& regs, TensorC const& coord, TensorShape const& tensor_shape) { - //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + + // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( From 1858932afd3bd4cf2d3f91bfdaa9f8d96f2afe14 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Tue, 30 Sep 2025 23:33:43 +0800 Subject: [PATCH 19/20] Code format --- tests/quant.py | 26 ++++++------- tests/test_flash_mla_decoding.py | 64 ++++++++++++++++---------------- tests/test_flash_mla_prefill.py | 27 +++++++------- tests/test_fmha_sm100.py | 19 +++++----- 4 files changed, 66 insertions(+), 70 deletions(-) diff --git a/tests/quant.py b/tests/quant.py index afee4b2..0624759 100644 --- a/tests/quant.py +++ b/tests/quant.py @@ -1,5 +1,3 @@ -import enum - import torch def quantize_k_cache( @@ -19,20 +17,20 @@ def quantize_k_cache( input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_elem_size = input_k_cache.element_size() - result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) + result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) result_k_nope_part = result[..., :dv] - result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32) - result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype) + result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32) + result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype) result_k_rope_part[:] = input_k_cache[..., dv:] for tile_idx in range(0, num_tiles): - cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] + cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] - cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) - result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope - + cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) + result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope + result = result.view(num_blocks, block_size, 1, -1) return result @@ -55,14 +53,14 @@ def dequantize_k_cache( quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) input_nope = quant_k_cache[..., :dv] - input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32) - input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16) + input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32) + input_rope = quant_k_cache[..., dv + num_tiles * 4:].view(torch.bfloat16) result[..., dv:] = input_rope for tile_idx in range(0, num_tiles): - cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) + cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32) cur_scales = input_scale[..., tile_idx].unsqueeze(-1) - result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales - + result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales + result = result.view(num_blocks, block_size, 1, d) return result diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_decoding.py index d6c1f81..dc140d7 100644 --- a/tests/test_flash_mla_decoding.py +++ b/tests/test_flash_mla_decoding.py @@ -2,20 +2,20 @@ import math import random import dataclasses -from typing import Optional, Tuple, List +from typing import Optional, Tuple import torch import triton -import quant import flash_mla +import quant from lib import cdiv, check_is_allclose @dataclasses.dataclass class TestParam: - b: int # Batch size - s_q: int # Number of queries for one request - s_k: int # Seq len, or mean seq len if varlen == True + b: int # Batch size + s_q: int # Number of queries for one request + s_k: int # Seq len, or mean seq len if varlen == True is_varlen: bool is_causal: bool is_fp8: bool @@ -24,8 +24,8 @@ class TestParam: is_all_indices_invalid: bool = False have_zero_seqlen_k: bool = False block_size: int = 64 - h_q: int = 128 # Number of q heads - h_kv: int = 1 # Number of kv heads + h_q: int = 128 # Number of q heads + h_kv: int = 1 # Number of kv heads d: int = 576 # Q/K head dim (= dv + RoPE dim) dv: int = 512 # V head dim seed: int = 0 @@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. cur_num_blocks = cdiv(cur_len, t.block_size) blocked_k[block_table[i][cur_num_blocks:]] = float("nan") if cur_len % t.block_size != 0: - blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan") + blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") block_table[i][cur_num_blocks:] = 2147480000 return cache_seqlens, q, block_table, blocked_k, None, None else: @@ -82,12 +82,12 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. # Generate indices for j in range(t.s_q): cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] - cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size) + cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size) if len(cur_abs_indices) < t.topk: pad_len = t.topk - len(cur_abs_indices) cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) - + # Mask KV perm = torch.randperm(t.topk, device='cpu') cur_abs_indices = cur_abs_indices[perm] @@ -100,7 +100,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. abs_indices[i, j, :] = cur_abs_indices indices_in_kvcache[i, j, :] = cur_blocked_indices - + # Mask nonused KV as NaN all_indices = indices_in_kvcache.flatten().tolist() all_indices = list(set(all_indices)) @@ -109,11 +109,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') blocked_k = blocked_k.view(-1, t.h_kv, t.d) - nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') + nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu') nonused_indices_mask[all_indices] = False blocked_k[nonused_indices_mask, :, :] = float("nan") blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) - + abs_indices = abs_indices.to(q.device) indices_in_kvcache = indices_in_kvcache.to(q.device) @@ -139,7 +139,7 @@ def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): valid_indices = cur_indices[cur_indices != -1] mask[i, valid_indices] = True return mask - + def scaled_dot_product_attention( batch_idx: int, query: torch.Tensor, # [h_q, s_q, d] @@ -157,7 +157,7 @@ def scaled_dot_product_attention( if h_kv != 1: kv = kv.repeat_interleave(h_q // h_kv, dim=0) kv[kv != kv] = 0.0 - attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] + attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] if (is_causal and query.size(1) > 1) or indices is not None: mask = torch.ones(s_q, s_k, dtype=torch.bool) if is_causal: @@ -169,14 +169,14 @@ def scaled_dot_product_attention( attn_bias.masked_fill_(mask.logical_not(), float("-inf")) attn_weight += attn_bias.to(q.dtype) attn_weight /= math.sqrt(query.size(-1)) - lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] # Correct for q tokens which has no attendable k lonely_q_mask = (lse == float("-inf")) output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 lse[lonely_q_mask] = float("+inf") - + return output, lse b, s_q, h_q, d = q.size() @@ -202,7 +202,7 @@ def scaled_dot_product_attention( lse_ref[i] = cur_lse out_ref = out_ref.to(torch.bfloat16) return out_ref, lse_ref - + @torch.inference_mode() def test_flash_mla(t: TestParam): @@ -235,7 +235,7 @@ def test_flash_mla(t: TestParam): def run_flash_mla(): return flash_mla.flash_mla_with_kvcache( q, - blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore + blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore block_table, cache_seqlens, t.dv, @@ -248,27 +248,27 @@ def run_flash_mla(): out_ans, lse_ans = run_flash_mla() out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) - assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6) - assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) + assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) + assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) if t.test_performance: - time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore + time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk - compute_volume_flop = t.b*t.h_q*t.s_q*sum([ - 2*t.d*mean_attended_seqlens, # Q * K^T - 2*mean_attended_seqlens*t.dv, # attention * V + compute_volume_flop = t.b * t.h_q * t.s_q * sum([ + 2 * t.d * mean_attended_seqlens, # Q * K^T + 2 * mean_attended_seqlens * t.dv, # attention * V ]) q_elem_size = torch.bfloat16.itemsize - kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize - memory_volume_B = t.b*sum([ - t.s_q*t.h_q*(t.d*q_elem_size), # Q - (t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V - t.s_q*t.h_q*(t.dv*q_elem_size), # Output + kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize + memory_volume_B = t.b * sum([ + t.s_q * t.h_q * (t.d * q_elem_size), # Q + (t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V + t.s_q * t.h_q * (t.dv * q_elem_size), # Output ]) achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_gBps = memory_volume_B / time_usage / 1e9 - print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") + print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") def main(torch_dtype): @@ -324,7 +324,7 @@ def main(torch_dtype): cc_major, cc_minor = torch.cuda.get_device_capability() if cc_major == 10: testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] - + for testcase in testcases: test_flash_mla(testcase) diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py index 19a6dbe..d2f5b7e 100644 --- a/tests/test_flash_mla_prefill.py +++ b/tests/test_flash_mla_prefill.py @@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase: torch.manual_seed(t.seed) torch.cuda.manual_seed(t.seed) random.seed(t.seed) - q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 - kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10 + kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10 q.clamp_(-10, 10) kv.clamp_(-10, 10) @@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase: # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 cur_indices = torch.randperm(t.s_kv)[:t.topk] - cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) + cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),)) if len(cur_indices) < t.topk: cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) cur_indices = cur_indices[torch.randperm(t.topk)] @@ -72,9 +72,9 @@ def get_flop(p: TestParam) -> float: def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) - + assert p.b == 1 - indices = t.indices[0, :, 0, :] # [s_q, topk] + indices = t.indices[0, :, 0, :] # [s_q, topk] invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] @@ -104,15 +104,15 @@ def run_ans(): return flash_mla_sparse_fwd( t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale ) - + ans_out, ans_max_logits, ans_lse = run_ans() torch.cuda.synchronize() if p.benchmark: flop = get_flop(p) - prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore - prefill_flops = flop/prefill_ans_time/1e12 - print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops") + prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore + prefill_flops = flop / prefill_ans_time / 1e12 + print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops") if p.check_correctness: torch.cuda.synchronize() @@ -120,9 +120,9 @@ def run_ans(): torch.cuda.synchronize() is_correct = True - is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6) - is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) - is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) + is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6) + is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536) + is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536) return is_correct else: @@ -187,11 +187,10 @@ def run_ans(): is_correct = run_test(test) if not is_correct: failed_cases.append(test) - + if len(failed_cases) > 0: print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") for case in failed_cases: print(f" {case}") else: print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") - diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 6b2ba45..62e3344 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -5,7 +5,6 @@ import triton from flash_mla import flash_attn_varlen_func - from lib import check_is_allclose def get_window_size(causal, window): @@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win causal, window) == 0).sum().item() for i in range(b)]) # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") - q = torch.randn(total_q, h, d)/10 - k = torch.randn(total_k, h_k, d)/10 - v = torch.randn(total_k, h_k, dv)/10 - grad_out = torch.randn(total_q, h, dv)/10 + q = torch.randn(total_q, h, d) / 10 + k = torch.randn(total_k, h_k, d) / 10 + v = torch.randn(total_k, h_k, dv) / 10 + grad_out = torch.randn(total_q, h, dv) / 10 softmax_scale = (d + 100) ** (-0.5) q1 = q.clone().requires_grad_() @@ -123,14 +122,14 @@ def torch_attn(): if check_correctness: out_torch, lse_torch = torch_attn() - assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) - assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536) + assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536) if has_bwd: out_torch.backward(grad_out, retain_graph=True) - assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) - assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) - assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) def forward(): return flash_attn() From 1408756a88e52a25196b759eaf8db89d2b51b5a1 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Wed, 1 Oct 2025 00:04:36 +0800 Subject: [PATCH 20/20] Update README --- README.md | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index df021de..f08d888 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## Introduction -FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: +FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: **Sparse Attention Kernels** @@ -19,7 +19,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee ## News - **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md). -- **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! +- **2025.08.01 Kernels for MHA on SM100**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on SM100! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). - **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 @@ -31,9 +31,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee python tests/test_flash_mla_decoding.py ``` -The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. - -For Blackwell GPUs, the token-level sparse MLA decoding kernel can achieve up to 350 TFlops (on B200) which is not really optimized yet. +The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet). #### Test & benchmark MHA prefill (Dense): @@ -49,22 +47,22 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation python tests/test_flash_mla_prefill.py ``` -It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. +It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. ## Requirements -- Hopper / Blackwell GPUs (See the support matrix below) -- CUDA 12.8 and above (CUDA 12.9+ is required for Blackwell kernels) +- SM90 / SM100 (See the support matrix below) +- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels) - PyTorch 2.0 and above Support matrix: | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | | :---: | :---: | :---: | :---: | -| Dense Decoding | Hopper | MQA | BF16 | -| Sparse Decoding | Hopper & Blackwell | MQA | FP8 [1] | -| Dense Prefill | Blackwell | MHA | | -| Sparse Prefill | Hopper & Blackwell | MQA | | +| Dense Decoding | SM90 | MQA | BF16 | +| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] | +| Dense Prefill | SM100 | MHA | | +| Sparse Prefill | SM90 & SM100 | MQA | | [1]: For more details on using FP8 KV cache, see documents below.