diff --git a/hopper/setup.py b/hopper/setup.py index 611eac2853..7d1776f565 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -48,6 +48,7 @@ DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_FP8_TWO_LEVEL_ACCUMULATION = os.getenv("FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" @@ -476,6 +477,7 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else []) + (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else []) + (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else []) + + (["-DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION"] if DISABLE_FP8_TWO_LEVEL_ACCUMULATION else []) + (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else []) + (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else []) + (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else []) diff --git a/hopper/utils.h b/hopper/utils.h index 3719eab920..efc755f7d7 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -235,6 +235,21 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { template CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + static constexpr bool Is_FP8 = cute::is_same_v + || cute::is_same_v; + static constexpr bool Use_Two_Level = Is_FP8 && !zero_init; + + auto tCrC_original = cute::make_fragment_like(tCrC); + if constexpr (Use_Two_Level) { + // Copy original values to backup + #pragma unroll + for (int i = 0; i < cute::size(tCrC); ++i) { + tCrC_original(i) = tCrC(i); + } + cute::clear(tCrC); + } +#endif if constexpr (M_slice >= 0) { static constexpr int MMA_M = decltype(size<1>(tCrC))::value; static_assert(M_slice < MMA_M); @@ -301,6 +316,15 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } } } + +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + if constexpr (Use_Two_Level) { + #pragma unroll + for (int i = 0; i < cute::size(tCrC); ++i) { + tCrC(i) = tCrC_original(i) + tCrC(i); // Add temp results to original + } + } +#endif } ////////////////////////////////////////////////////////////////////////////////////////////////////