hopper/setup.py (2 additions, 0 deletions)

@@ -48,6 +48,7 @@
 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
 DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
 DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
+DISABLE_FP8_TWO_LEVEL_ACCUMULATION = os.getenv("FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION", "FALSE") == "TRUE"
 DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
 DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
 DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
@@ -476,6 +477,7 @@ def nvcc_threads_args():
+ (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else [])
+ (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else [])
+ (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else [])
+ (["-DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION"] if DISABLE_FP8_TWO_LEVEL_ACCUMULATION else [])
+ (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else [])
+ (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else [])
+ (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else [])
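The new toggle follows the same convention as the other FLASH_ATTENTION_DISABLE_* kill switches: setting FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION=TRUE in the build environment makes setup.py pass -DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION to the compiler, which strips the feature at preprocessing time. A minimal standalone sketch of that gating pattern (this program is illustrative only, not part of the PR):

// gate_demo.cpp -- build with and without the disable flag to see both paths:
//   g++ gate_demo.cpp && ./a.out
//   g++ -DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION gate_demo.cpp && ./a.out
#include <cstdio>

int main() {
#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
    // Default: the macro is undefined, so this branch is compiled in.
    std::printf("FP8 two-level accumulation: compiled in (default)\n");
#else
    // With the -D flag, the feature is compiled out entirely.
    std::printf("FP8 two-level accumulation: compiled out\n");
#endif
    return 0;
}

Because the guard is an #ifndef, the feature is on by default and a disabled build carries no runtime cost at all, matching how the other disable flags in this file behave.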
hopper/utils.h (24 additions, 0 deletions)

@@ -235,6 +235,21 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
 template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
           typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
 CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
+#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
+    static constexpr bool Is_FP8 = cute::is_same_v<typename Tensor0::value_type, cutlass::float_e4m3_t>
+                                   || cute::is_same_v<typename Tensor0::value_type, cutlass::float_e5m2_t>;
+    static constexpr bool Use_Two_Level = Is_FP8 && !zero_init;
+
+    auto tCrC_original = cute::make_fragment_like(tCrC);
+    if constexpr (Use_Two_Level) {
+        // Copy original values to backup
+        #pragma unroll
+        for (int i = 0; i < cute::size(tCrC); ++i) {
+            tCrC_original(i) = tCrC(i);
+        }
+        cute::clear(tCrC);
+    }
+#endif
     if constexpr (M_slice >= 0) {
         static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
         static_assert(M_slice < MMA_M);
@@ -301,6 +316,15 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
             }
         }
     }
+
+#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
+    if constexpr (Use_Two_Level) {
+        #pragma unroll
+        for (int i = 0; i < cute::size(tCrC); ++i) {
+            tCrC(i) = tCrC_original(i) + tCrC(i);  // Add temp results to original
+        }
+    }
+#endif
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
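What the new code does: when the A fragment is FP8 (e4m3 or e5m2) and the call is accumulating onto existing values (zero_init == false), gemm() backs up the incoming accumulator, clears tCrC, lets the MMA loop accumulate into the freshly zeroed fragment, and finally adds the backup back in ordinary FP32 arithmetic. When zero_init is true the accumulator starts at zero anyway, so the backup/restore would be pure overhead and is skipped. The rationale is not stated in the diff, but the usual motivation for this two-level scheme is that FP8 tensor-core MMAs accumulate at reduced precision, so long chains of in-place accumulation drift; folding the running total back in at full FP32 precision once per gemm() call bounds that drift. A self-contained host-side sketch of the effect, where lowPrecAdd() is a hypothetical stand-in for a reduced-precision accumulator (the rounding model is an assumption for illustration, not the actual Hopper behavior):

// fp8_two_level_demo.cpp -- why two-level accumulation preserves precision.
#include <cmath>
#include <cstdio>

// Emulate a low-precision accumulator: round the result of every addition
// to ~10 mantissa bits, so small increments get lost against a large total.
float lowPrecAdd(float acc, float x) {
    int e;
    float m = std::frexp(acc + x, &e);       // acc + x == m * 2^e, m in [0.5, 1)
    m = std::round(m * 1024.0f) / 1024.0f;   // keep ~10 bits of mantissa
    return std::ldexp(m, e);
}

int main() {
    const int kSteps = 1 << 16;   // number of tiny contributions
    const float kX = 1e-3f;       // size of each contribution
    const double exact = static_cast<double>(kSteps) * kX;

    // One-level: every addition goes straight into the running total.
    float oneLevel = 0.0f;
    for (int i = 0; i < kSteps; ++i) oneLevel = lowPrecAdd(oneLevel, kX);

    // Two-level: accumulate each chunk in a freshly cleared register, then
    // fold it into the master sum at full FP32 precision -- analogous to the
    // backup / cute::clear / add-back sequence gemm() now performs per call.
    float master = 0.0f;
    const int kChunk = 256;
    for (int i = 0; i < kSteps; i += kChunk) {
        float chunk = 0.0f;                  // the cute::clear(tCrC) analogue
        for (int j = 0; j < kChunk; ++j) chunk = lowPrecAdd(chunk, kX);
        master += chunk;                     // the tCrC_original(i) + tCrC(i) analogue
    }

    std::printf("exact:     %.3f\n", exact);
    std::printf("one-level: %.3f\n", oneLevel);
    std::printf("two-level: %.3f\n", master);
    return 0;
}

With this toy rounding model the one-level sum stalls around 2 (the point where each increment falls below half a rounding step and rounds away), while the two-level sum stays within a couple of percent of the exact 65.5, because each chunk total remains small enough to resolve its increments and the chunk-to-master adds happen at full precision.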