From 133938411394f1192ac7d03ce80492c614be54ce Mon Sep 17 00:00:00 2001 From: Jonas Kuebler Date: Wed, 29 Oct 2025 12:49:54 +0000 Subject: [PATCH 1/5] fix sign-off Signed-off-by: Jonas Kuebler --- hopper/fp8_promotion_helper.hpp | 56 +++++++++++++++++ hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 77 +++++++++++++++++++++--- 2 files changed, 125 insertions(+), 8 deletions(-) create mode 100644 hopper/fp8_promotion_helper.hpp diff --git a/hopper/fp8_promotion_helper.hpp b/hopper/fp8_promotion_helper.hpp new file mode 100644 index 0000000000..4ea88a65e4 --- /dev/null +++ b/hopper/fp8_promotion_helper.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include + +namespace flash { + +/** + * GmmaPromotionHelper: Fixes FP8 tensor core accumulation precision loss + * + * Problem: FP8 tensor cores do not use full FP32 accumulators, + * causing precision loss particularly if the contraction dimension is large. + * + * Solution: Accumulate a few tiles in temporary fragment, then promote + * to main accumulator using true FP32 arithmetic every promotion_interval tiles. + * + * This significantly reduces the accuracy drop on very long context datasets. + */ +template +struct GmmaPromotionHelper { + AccumFragment& accum_main; + AccumFragment accum_temp; + int mma_count = 0; + static constexpr int promotion_interval = 2; // Tunable: balance precision vs performance + + __device__ GmmaPromotionHelper(AccumFragment& main_accum) + : accum_main(main_accum), accum_temp(cute::make_fragment_like(main_accum)) { + cute::clear(accum_temp); + } + + template + __device__ void step(MmaFunc&& mma_func) { + mma_func(accum_temp); + mma_count++; + if (mma_count >= promotion_interval) { + promote(); + } + } + + __device__ void promote() { + // Add temp accumulator to main accumulator (FP32 precision) + #pragma unroll + for (int i = 0; i < cute::size(accum_main); ++i) { + accum_main(i) += accum_temp(i); + } + cute::clear(accum_temp); + mma_count = 0; + } + + __device__ void flush() { + if (mma_count > 0) { + promote(); + } + } +}; + +} // namespace flash diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 6e0d8b768b..6d3c3e5e79 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -8,6 +8,7 @@ #include #include #include +#include "fp8_promotion_helper.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" @@ -1253,7 +1254,17 @@ struct CollectiveMainloopFwdSm90 { if constexpr(!HasQv) { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } } - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + + // FP8 promotion helper for P×V GEMM + if constexpr (Is_FP8) { + flash::GmmaPromotionHelper promoter(tOrO); + promoter.step([&](auto& accum_frag) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), accum_frag); + }); + promoter.flush(); + } else { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + } warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1325,7 +1336,17 @@ struct CollectiveMainloopFwdSm90 { cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + + // FP8 promotion helper for P×V GEMM + if constexpr (Is_FP8) { + flash::GmmaPromotionHelper promoter(tOrO); + promoter.step([&](auto& accum_frag) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), accum_frag); + }); + promoter.flush(); + } else { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; // cute::copy(softmax.finalize(v_descale), scores_scale); finalize_dispatch(scores_scale, v_descale); @@ -1382,11 +1403,31 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - if constexpr (!MmaPV_use_RS_WG1) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + + // FP8 promotion helper to prevent FP22 accumulation precision loss + if constexpr (Is_FP8) { + flash::GmmaPromotionHelper promoter(tOrO); + + if constexpr (!MmaPV_use_RS_WG1) { + promoter.step([&](auto& accum_frag) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), accum_frag); + }); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + promoter.step([&](auto& accum_frag) { + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), accum_frag); + }); + } + + promoter.flush(); } else { - TiledMmaPV_RS tiled_mma_pv_rs; - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + // Original BF16 path (no promotion needed) + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } } if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); @@ -1509,7 +1550,17 @@ struct CollectiveMainloopFwdSm90 { // If HasQv, then by the time P is ready, V must have been ready as well if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + + // FP8 promotion helper for P×V GEMM + if constexpr (Is_FP8) { + flash::GmmaPromotionHelper promoter(tOrO); + promoter.step([&](auto& accum_frag) { + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), accum_frag); + }); + promoter.flush(); + } else { + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1524,7 +1575,17 @@ struct CollectiveMainloopFwdSm90 { auto barrier_token = 
pipeline_v.consumer_try_wait(smem_pipe_read); pipeline_v.consumer_wait(smem_pipe_read, barrier_token); } - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + + // FP8 promotion helper for P×V GEMM + if constexpr (Is_FP8) { + flash::GmmaPromotionHelper promoter(tOrO); + promoter.step([&](auto& accum_frag) { + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), accum_frag); + }); + promoter.flush(); + } else { + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V }; From 2ba74fa35ae8b0b094d87f9323e757c7ef0318ea Mon Sep 17 00:00:00 2001 From: Jonas Kuebler Date: Thu, 30 Oct 2025 09:00:10 +0000 Subject: [PATCH 2/5] simplify to always add to fp32 accumulator Signed-off-by: Jonas Kuebler --- hopper/fp8_promotion_helper.hpp | 56 -------------- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 97 +++++++++++++++--------- 2 files changed, 60 insertions(+), 93 deletions(-) delete mode 100644 hopper/fp8_promotion_helper.hpp diff --git a/hopper/fp8_promotion_helper.hpp b/hopper/fp8_promotion_helper.hpp deleted file mode 100644 index 4ea88a65e4..0000000000 --- a/hopper/fp8_promotion_helper.hpp +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#include - -namespace flash { - -/** - * GmmaPromotionHelper: Fixes FP8 tensor core accumulation precision loss - * - * Problem: FP8 tensor cores do not use full FP32 accumulators, - * causing precision loss particularly if the contraction dimension is large. - * - * Solution: Accumulate a few tiles in temporary fragment, then promote - * to main accumulator using true FP32 arithmetic every promotion_interval tiles. - * - * This significantly reduces the accuracy drop on very long context datasets. 
- */ -template -struct GmmaPromotionHelper { - AccumFragment& accum_main; - AccumFragment accum_temp; - int mma_count = 0; - static constexpr int promotion_interval = 2; // Tunable: balance precision vs performance - - __device__ GmmaPromotionHelper(AccumFragment& main_accum) - : accum_main(main_accum), accum_temp(cute::make_fragment_like(main_accum)) { - cute::clear(accum_temp); - } - - template - __device__ void step(MmaFunc&& mma_func) { - mma_func(accum_temp); - mma_count++; - if (mma_count >= promotion_interval) { - promote(); - } - } - - __device__ void promote() { - // Add temp accumulator to main accumulator (FP32 precision) - #pragma unroll - for (int i = 0; i < cute::size(accum_main); ++i) { - accum_main(i) += accum_temp(i); - } - cute::clear(accum_temp); - mma_count = 0; - } - - __device__ void flush() { - if (mma_count > 0) { - promote(); - } - } -}; - -} // namespace flash diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 6d3c3e5e79..93087b13fa 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -8,7 +8,6 @@ #include #include #include -#include "fp8_promotion_helper.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" @@ -1255,13 +1254,17 @@ struct CollectiveMainloopFwdSm90 { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } } - // FP8 promotion helper for P×V GEMM + // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { - flash::GmmaPromotionHelper promoter(tOrO); - promoter.step([&](auto& accum_frag) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), accum_frag); - }); - promoter.flush(); + auto temp_frag = cute::make_fragment_like(tOrO); + cute::clear(temp_frag); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), temp_frag); + + // Add to main accumulator in true FP32 + #pragma unroll + for (int i = 0; i < cute::size(tOrO); ++i) { + tOrO(i) += temp_frag(i); + } } else { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); } @@ -1337,13 +1340,17 @@ struct CollectiveMainloopFwdSm90 { if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } - // FP8 promotion helper for P×V GEMM + // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { - flash::GmmaPromotionHelper promoter(tOrO); - promoter.step([&](auto& accum_frag) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), accum_frag); - }); - promoter.flush(); + auto temp_frag = cute::make_fragment_like(tOrO); + cute::clear(temp_frag); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), temp_frag); + + // Add to main accumulator in true FP32 + #pragma unroll + for (int i = 0; i < cute::size(tOrO); ++i) { + tOrO(i) += temp_frag(i); + } } else { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } @@ -1404,24 +1411,32 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - // FP8 promotion helper to prevent FP22 accumulation precision loss + // FP8: Use temp fragment to avoid FP22 accumulation in tOrO 
if constexpr (Is_FP8) { - flash::GmmaPromotionHelper promoter(tOrO); - if constexpr (!MmaPV_use_RS_WG1) { - promoter.step([&](auto& accum_frag) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), accum_frag); - }); + auto temp_frag = cute::make_fragment_like(tOrO); + cute::clear(temp_frag); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), temp_frag); + + // Add to main accumulator in true FP32 + #pragma unroll + for (int i = 0; i < cute::size(tOrO); ++i) { + tOrO(i) += temp_frag(i); + } } else { TiledMmaPV_RS tiled_mma_pv_rs; - promoter.step([&](auto& accum_frag) { - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), accum_frag); - }); + auto temp_frag = cute::make_fragment_like(tOrO); + cute::clear(temp_frag); + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); + + // Add to main accumulator in true FP32 + #pragma unroll + for (int i = 0; i < cute::size(tOrO); ++i) { + tOrO(i) += temp_frag(i); + } } - - promoter.flush(); } else { - // Original BF16 path (no promotion needed) + // Original BF16 path if constexpr (!MmaPV_use_RS_WG1) { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } else { @@ -1551,13 +1566,17 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - // FP8 promotion helper for P×V GEMM + // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { - flash::GmmaPromotionHelper promoter(tOrO); - promoter.step([&](auto& accum_frag) { - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), accum_frag); - }); - promoter.flush(); + auto temp_frag = cute::make_fragment_like(tOrO); + cute::clear(temp_frag); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); + + // Add to main accumulator in true FP32 + #pragma unroll + for (int i = 0; i < cute::size(tOrO); ++i) { + tOrO(i) += temp_frag(i); + } } else { flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } @@ -1576,13 +1595,17 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.consumer_wait(smem_pipe_read, barrier_token); } - // FP8 promotion helper for P×V GEMM + // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { - flash::GmmaPromotionHelper promoter(tOrO); - promoter.step([&](auto& accum_frag) { - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), accum_frag); - }); - promoter.flush(); + auto temp_frag = cute::make_fragment_like(tOrO); + cute::clear(temp_frag); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); + + // Add to main accumulator in true FP32 + #pragma unroll + for (int i = 0; i < cute::size(tOrO); ++i) { + tOrO(i) += temp_frag(i); + } } else { flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } From b9e937f668711918cb5358133fb6f0b19651c4b3 Mon Sep 17 00:00:00 2001 From: Jonas Kuebler Date: Thu, 30 Oct 2025 19:39:05 +0000 Subject: [PATCH 3/5] add compile-time check to disable two-level accumulation Signed-off-by: Jonas Kuebler --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 30 ++++++++++++++++++++++++ hopper/setup.py | 2 ++ 2 files changed, 32 insertions(+) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp 
b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 93087b13fa..718102749b 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1254,6 +1254,7 @@ struct CollectiveMainloopFwdSm90 { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } } +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { auto temp_frag = cute::make_fragment_like(tOrO); @@ -1268,6 +1269,10 @@ struct CollectiveMainloopFwdSm90 { } else { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); } +#else + // Original path for all types + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); +#endif warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1340,6 +1345,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { auto temp_frag = cute::make_fragment_like(tOrO); @@ -1354,6 +1360,10 @@ struct CollectiveMainloopFwdSm90 { } else { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } +#else + // Original path for all types + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; // cute::copy(softmax.finalize(v_descale), scores_scale); finalize_dispatch(scores_scale, v_descale); @@ -1411,6 +1421,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { if constexpr (!MmaPV_use_RS_WG1) { @@ -1444,6 +1455,15 @@ struct CollectiveMainloopFwdSm90 { flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } } +#else + // Original path for all types + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } +#endif if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V @@ -1566,6 +1586,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { auto temp_frag = cute::make_fragment_like(tOrO); @@ -1580,6 +1601,10 @@ struct CollectiveMainloopFwdSm90 { } else { flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } +#else + // Original path for all types + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1595,6 +1620,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.consumer_wait(smem_pipe_read, barrier_token); } +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { auto temp_frag = cute::make_fragment_like(tOrO); @@ -1609,6 +1635,10 @@ struct CollectiveMainloopFwdSm90 { } else { flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } +#else + // Original path for all types + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V }; diff --git a/hopper/setup.py b/hopper/setup.py index 611eac2853..7d1776f565 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -48,6 +48,7 @@ DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_FP8_TWO_LEVEL_ACCUMULATION = os.getenv("FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" @@ -476,6 +477,7 @@ 
def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else []) + (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else []) + (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else []) + + (["-DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION"] if DISABLE_FP8_TWO_LEVEL_ACCUMULATION else []) + (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else []) + (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else []) + (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else []) From 49d4c831607a66f015cc56ed51bcc6f65297fe0e Mon Sep 17 00:00:00 2001 From: Jonas Kuebler Date: Fri, 31 Oct 2025 06:14:52 +0000 Subject: [PATCH 4/5] make BF16 path cleaner Signed-off-by: Jonas Kuebler --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 66 +++++++++++------------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 718102749b..3a53802141 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1254,9 +1254,9 @@ struct CollectiveMainloopFwdSm90 { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } } -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + // Use temp fragment to avoid FP22 accumulation in tOrO auto temp_frag = cute::make_fragment_like(tOrO); cute::clear(temp_frag); flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), temp_frag); @@ -1266,13 +1266,12 @@ struct CollectiveMainloopFwdSm90 { for (int i = 0; i < cute::size(tOrO); ++i) { tOrO(i) += temp_frag(i); } +#else + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); +#endif } else { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); } -#else - // Original path for all types - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); -#endif warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1345,9 +1344,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + // Use temp fragment to avoid FP22 accumulation in tOrO auto temp_frag = cute::make_fragment_like(tOrO); cute::clear(temp_frag); flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), temp_frag); @@ -1357,13 +1356,12 @@ struct CollectiveMainloopFwdSm90 { for (int i = 0; i < cute::size(tOrO); ++i) { tOrO(i) += temp_frag(i); } +#else + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif } else { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } -#else - // Original path for all types - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); -#endif float 
const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; // cute::copy(softmax.finalize(v_descale), scores_scale); finalize_dispatch(scores_scale, v_descale); @@ -1421,9 +1419,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + // Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (!MmaPV_use_RS_WG1) { auto temp_frag = cute::make_fragment_like(tOrO); cute::clear(temp_frag); @@ -1446,24 +1444,22 @@ struct CollectiveMainloopFwdSm90 { tOrO(i) += temp_frag(i); } } - } else { - // Original BF16 path +#else if constexpr (!MmaPV_use_RS_WG1) { flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } else { TiledMmaPV_RS tiled_mma_pv_rs; flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } - } -#else - // Original path for all types - if constexpr (!MmaPV_use_RS_WG1) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif } else { - TiledMmaPV_RS tiled_mma_pv_rs; - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } } -#endif if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V @@ -1586,9 +1582,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + // Use temp fragment to avoid FP22 accumulation in tOrO auto temp_frag = cute::make_fragment_like(tOrO); cute::clear(temp_frag); flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); @@ -1598,13 +1594,12 @@ struct CollectiveMainloopFwdSm90 { for (int i = 0; i < cute::size(tOrO); ++i) { tOrO(i) += temp_frag(i); } +#else + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif } else { flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } -#else - // Original path for all types - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); -#endif cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1620,9 +1615,9 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.consumer_wait(smem_pipe_read, barrier_token); } -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // FP8: Use temp fragment to avoid FP22 accumulation in tOrO if constexpr (Is_FP8) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + // 
Use temp fragment to avoid FP22 accumulation in tOrO auto temp_frag = cute::make_fragment_like(tOrO); cute::clear(temp_frag); flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); @@ -1632,13 +1627,12 @@ struct CollectiveMainloopFwdSm90 { for (int i = 0; i < cute::size(tOrO); ++i) { tOrO(i) += temp_frag(i); } +#else + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); +#endif } else { flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } -#else - // Original path for all types - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); -#endif cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V }; From 96fe99234ee36efe11c6a4271640b34f55daa7dc Mon Sep 17 00:00:00 2001 From: Jonas Kuebler Date: Wed, 5 Nov 2025 14:16:36 +0000 Subject: [PATCH 5/5] move to within utils. functionally correct Signed-off-by: Jonas Kuebler --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 124 ++--------------------- hopper/utils.h | 24 +++++ 2 files changed, 32 insertions(+), 116 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 3a53802141..6e0d8b768b 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1253,25 +1253,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr(!HasQv) { if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } } - - if constexpr (Is_FP8) { -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // Use temp fragment to avoid FP22 accumulation in tOrO - auto temp_frag = cute::make_fragment_like(tOrO); - cute::clear(temp_frag); - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), temp_frag); - - // Add to main accumulator in true FP32 - #pragma unroll - for (int i = 0; i < cute::size(tOrO); ++i) { - tOrO(i) += temp_frag(i); - } -#else - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); -#endif - } else { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1343,25 +1325,7 @@ struct CollectiveMainloopFwdSm90 { cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } - - if constexpr (Is_FP8) { -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // Use temp fragment to avoid FP22 accumulation in tOrO - auto temp_frag = cute::make_fragment_like(tOrO); - cute::clear(temp_frag); - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), temp_frag); - - // Add to main accumulator in true FP32 - #pragma unroll - for (int i = 0; i < cute::size(tOrO); ++i) { - tOrO(i) += temp_frag(i); - } -#else - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); -#endif - } else { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; // cute::copy(softmax.finalize(v_descale), scores_scale); finalize_dispatch(scores_scale, v_descale); @@ -1418,47 +1382,11 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - - if constexpr (Is_FP8) { -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // Use temp fragment to avoid FP22 accumulation in tOrO - if constexpr (!MmaPV_use_RS_WG1) { - auto temp_frag = cute::make_fragment_like(tOrO); - cute::clear(temp_frag); - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), temp_frag); - - // Add to main accumulator in true FP32 - #pragma unroll - for (int i = 0; i < cute::size(tOrO); ++i) { - tOrO(i) += temp_frag(i); - } - } else { - TiledMmaPV_RS tiled_mma_pv_rs; - auto temp_frag = cute::make_fragment_like(tOrO); - cute::clear(temp_frag); - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); - - // Add to main accumulator in true FP32 - #pragma unroll - for (int i = 0; i < cute::size(tOrO); ++i) { - tOrO(i) += temp_frag(i); - } - } -#else - if constexpr (!MmaPV_use_RS_WG1) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } else { - TiledMmaPV_RS tiled_mma_pv_rs; - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } -#endif + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } else { - if constexpr (!MmaPV_use_RS_WG1) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } else { - TiledMmaPV_RS tiled_mma_pv_rs; - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); @@ -1581,25 +1509,7 @@ struct CollectiveMainloopFwdSm90 { // If HasQv, then by the time P is ready, V must have been 
ready as well if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - - if constexpr (Is_FP8) { -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // Use temp fragment to avoid FP22 accumulation in tOrO - auto temp_frag = cute::make_fragment_like(tOrO); - cute::clear(temp_frag); - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); - - // Add to main accumulator in true FP32 - #pragma unroll - for (int i = 0; i < cute::size(tOrO); ++i) { - tOrO(i) += temp_frag(i); - } -#else - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); -#endif - } else { - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1614,25 +1524,7 @@ struct CollectiveMainloopFwdSm90 { auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); pipeline_v.consumer_wait(smem_pipe_read, barrier_token); } - - if constexpr (Is_FP8) { -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - // Use temp fragment to avoid FP22 accumulation in tOrO - auto temp_frag = cute::make_fragment_like(tOrO); - cute::clear(temp_frag); - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), temp_frag); - - // Add to main accumulator in true FP32 - #pragma unroll - for (int i = 0; i < cute::size(tOrO); ++i) { - tOrO(i) += temp_frag(i); - } -#else - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); -#endif - } else { - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - } + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V }; diff --git a/hopper/utils.h b/hopper/utils.h index 3719eab920..efc755f7d7 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -235,6 +235,21 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { template CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + static constexpr bool Is_FP8 = cute::is_same_v + || cute::is_same_v; + static constexpr bool Use_Two_Level = Is_FP8 && !zero_init; + + auto tCrC_original = cute::make_fragment_like(tCrC); + if constexpr (Use_Two_Level) { + // Copy original values to backup + #pragma unroll + for (int i = 0; i < cute::size(tCrC); ++i) { + tCrC_original(i) = tCrC(i); + } + cute::clear(tCrC); + } +#endif if constexpr (M_slice >= 0) { static constexpr int MMA_M = decltype(size<1>(tCrC))::value; static_assert(M_slice < MMA_M); @@ -301,6 +316,15 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } } } + +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + if constexpr (Use_Two_Level) { + #pragma unroll + for (int i = 0; i < cute::size(tCrC); ++i) { + tCrC(i) = tCrC_original(i) + tCrC(i); // Add temp results to original + } + } +#endif } ////////////////////////////////////////////////////////////////////////////////////////////////////
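
Note on the final design (PATCH 5/5): the two-level accumulation now lives inside flash::gemm in hopper/utils.h, so every P*V call site in the mainloop keeps the single original code path, and the behavior can be reverted at build time via FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION (which defines FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION). The sketch below is an illustration of the technique only, not the patch itself: a plain float array stands in for the CuTe accumulator fragment, and mma_tile is a hypothetical stand-in for the WGMMA call.

    // Illustration of two-level accumulation: FP8 tensor cores accumulate in
    // reduced precision (~FP22 on Hopper), so each tile's partial product is
    // produced into a freshly cleared temporary and then folded into the
    // running output with true FP32 adds.
    template <int N, class MmaFunc>
    __device__ void gemm_two_level(float (&acc_main)[N], MmaFunc&& mma_tile) {
        float acc_temp[N];
        #pragma unroll
        for (int i = 0; i < N; ++i) { acc_temp[i] = 0.f; }           // zero_init for this tile only
        mma_tile(acc_temp);                                           // reduced-precision accumulation stays local to the tile
        #pragma unroll
        for (int i = 0; i < N; ++i) { acc_main[i] += acc_temp[i]; }   // FP32 promotion into the main accumulator
    }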