Conversation

@jmkuebler jmkuebler commented Oct 29, 2025

Problem

The FP8 TensorCore accumulator does not actually accumulate in FP32 precision but in a reduced-precision format. This manifests as a drastic accuracy loss on long-context needle-in-a-haystack tasks, as discussed in
vllm-project/vllm#23813 (comment)

Solution:

Similar to SageAttention2, we implement a two-level accumulator: only a small number of tiles (currently a single tile) is combined in the FP8 TensorCore accumulator before being flushed to an actual FP32 accumulator. This confines the floating-point error to a few tiles instead of the full matmul.
The solution also follows the approach in https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp#L600
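
For intuition, here is a minimal, self-contained sketch of the two-level idea in plain C++ (placeholder types and a dummy MMA stand-in, not the actual kernel code; the real implementation works on CuTe fragments inside the mainloop):

```cpp
#include <array>
#include <cstddef>

constexpr int kTileSize = 64;          // hypothetical tile width
constexpr int kPromotionInterval = 1;  // flush after every tile (the behavior this PR ends up with)

// Stand-in for the FP8 TensorCore MMA; in the real kernel this accumulation has
// limited effective precision, which is why its result is flushed frequently.
void fp8_mma_tile(std::array<float, kTileSize>& acc, int /*tile_idx*/) {
    for (auto& v : acc) v += 1.0f;  // dummy contribution
}

void two_level_accumulate(std::array<float, kTileSize>& accum_main, int num_tiles) {
    std::array<float, kTileSize> accum_temp{};  // short-lived second-level accumulator
    int mma_count = 0;
    for (int t = 0; t < num_tiles; ++t) {
        fp8_mma_tile(accum_temp, t);              // TensorCore accumulation (reduced precision)
        if (++mma_count == kPromotionInterval) {  // promote into the true FP32 accumulator
            for (std::size_t i = 0; i < accum_temp.size(); ++i) {
                accum_main[i] += accum_temp[i];
                accum_temp[i] = 0.0f;
            }
            mma_count = 0;
        }
    }
    for (std::size_t i = 0; i < accum_temp.size(); ++i) {  // flush any remainder
        accum_main[i] += accum_temp[i];
    }
}
```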

Testing:

We ran lm_eval on niah_multikey_2, where the largest accuracy drops were observed.

```bash
model=meta-llama/Llama-3.1-8B-Instruct

lm_eval \
    --batch_size auto \
    --model vllm \
    --model_args pretrained=$model,kv_cache_dtype=fp8,dtype=auto,add_bos_token=False,max_model_len=131072,tensor_parallel_size=8,gpu_memory_utilization=0.9,enable_chunked_prefill=True,trust_remote_code=True \
    --tasks niah_multikey_2 \
    --metadata='{"max_seq_lengths":[128000]}' \
    --seed 0
```
| type | accuracy | wall clock time in minutes |
|---|---|---|
| BF16 | 91.20% | |
| FP8 before | 13.00% | 22.20814 |
| FP8 this PR v1 @ 1339384 | 89.20% | 22.12524 |
| FP8 this PR v2 @ 2ba74fa | 89.20% | 21.82062 |

Performance Impact

We measured the FP8 decoding performance, which degrades slightly (by around 5%) but is still much faster than BF16. Only the 4Q:1KV setting at head dim 128 is not competitive (see details in the comments).

We added a compile-time env flag (FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION=TRUE) to restore the legacy behavior, but the default is changed by this PR to favor accuracy over performance.

--
Thanks @MatthewBonanni for quick iterations and support!

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@jmkuebler jmkuebler force-pushed the two-level-accumulation_dev branch from fe016a7 to 1339384 on October 29, 2025 12:50
AccumFragment& accum_main;
AccumFragment accum_temp;
int mma_count = 0;
static constexpr int promotion_interval = 2; // Tunable: balance precision vs performance
Author

this needs some further experimentation. If 1 works well perf-wise that would be the best choice.

@MatthewBonanni MatthewBonanni Oct 29, 2025

Is the promotion_interval working properly as written? If the promoter is being constructed immediately before the mma won't mma_count be zero every time? Seems like the intent would be to construct the promoter when the other fragments are allocated.

@MatthewBonanni MatthewBonanni Oct 29, 2025

Also, if 1 works well perf-wise we might not need the promoter struct at all, but it is a nice way to encapsulate this feature

Author

You are absolutely right, that didn't work as intended.

Making this promoter work correctly would require a bit more work, I think. I have now pushed a v2 that is heavily simplified and just adds every tile, so the promotion interval is one.
I'll share some benchmarking numbers below... and then we can discuss how to do it correctly.

@MatthewBonanni MatthewBonanni left a comment

I like this approach! Have you had a chance to profile it? One concern @LucasWilkinson mentioned to me is that the register usage is already pretty high as-is, so adding the extra accumulator could push it over the limit.

Also, do we want to make it possible to enable / disable this with an env var? I suppose if it's a universal win then it isn't necessary

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@jmkuebler
Author

Since the first version had some dead code that wasn't actually working, I pushed a v2. It still mitigates the accuracy loss, but adds every tile's output directly to the actual FP32 accumulation register.
I did a decoding benchmark (with cudagraph and num_splits=32, as is the default in vLLM for decoding), varied the context length, and looked at different variants.

In summary, the BF16 perf remains unchanged (expected), whereas the FP8 performance does indeed degrade. The severity depends on the specific attention configuration.

Speedup Summary (v2 vs main):
==================================================
gqa32to4_hdim128_b1:
  BF16 avg speedup: 0.996x
  FP8 avg speedup:  0.955x
gqa64to8_hdim64_b1:
  BF16 avg speedup: 0.998x
  FP8 avg speedup:  0.945x
gqa32to8_hdim128_b1:
  BF16 avg speedup: 0.997x
  FP8 avg speedup:  0.953x
gqa4to1_hdim128_b1:
  BF16 avg speedup: 1.012x
  FP8 avg speedup:  0.867x
| config | seq_len | bf16_main | bf16_v2 | fp8_main | fp8_v2 | bf16_speedup | fp8_speedup |
|---|---|---|---|---|---|---|---|
| gqa32to4_hdim128_b1 | 1000 | 11.39 | 11.4 | 10.57 | 10.81 | 1.00 | 0.98 |
| gqa32to4_hdim128_b1 | 2000 | 11.88 | 11.87 | 11.3 | 11.63 | 1.00 | 0.97 |
| gqa32to4_hdim128_b1 | 5000 | 13.45 | 13.37 | 13.05 | 13.32 | 1.01 | 0.98 |
| gqa32to4_hdim128_b1 | 10000 | 15.68 | 16.34 | 14.12 | 15.4 | 0.96 | 0.92 |
| gqa32to4_hdim128_b1 | 20000 | 20.69 | 21.09 | 16.28 | 17.96 | 0.98 | 0.91 |
| gqa32to4_hdim128_b1 | 50000 | 45.61 | 45.57 | 29.47 | 30.96 | 1.00 | 0.95 |
| gqa32to4_hdim128_b1 | 100000 | 78.48 | 78.47 | 46.12 | 48.06 | 1.00 | 0.96 |
| gqa32to4_hdim128_b1 | 200000 | 146.61 | 147.27 | 80.06 | 82.46 | 1.00 | 0.97 |
| gqa32to4_hdim128_b1 | 400000 | 279.71 | 281.53 | 144.73 | 151.05 | 0.99 | 0.96 |
| gqa32to4_hdim128_b1 | 800000 | 545.84 | 545.83 | 273.9 | 286.54 | 1.00 | 0.96 |
| gqa32to4_hdim128_b1 | 1600000 | 1077.62 | 1060.55 | 535.81 | 561.38 | 1.02 | 0.95 |
| gqa64to8_hdim64_b1 | 1000 | 10.51 | 10.59 | 10.29 | 10.25 | 0.99 | 1.00 |
| gqa64to8_hdim64_b1 | 2000 | 10.9 | 10.97 | 11.08 | 11.42 | 0.99 | 0.97 |
| gqa64to8_hdim64_b1 | 5000 | 12.42 | 12.53 | 12.33 | 12.79 | 0.99 | 0.96 |
| gqa64to8_hdim64_b1 | 10000 | 14.82 | 14.87 | 14.29 | 15.1 | 1.00 | 0.95 |
| gqa64to8_hdim64_b1 | 20000 | 19.1 | 19.23 | 17.62 | 18.97 | 0.99 | 0.93 |
| gqa64to8_hdim64_b1 | 50000 | 44.86 | 44.13 | 31.33 | 33.59 | 1.02 | 0.93 |
| gqa64to8_hdim64_b1 | 100000 | 76.96 | 78.46 | 50.9 | 54.18 | 0.98 | 0.94 |
| gqa64to8_hdim64_b1 | 200000 | 145.03 | 145.05 | 88.31 | 95.61 | 1.00 | 0.92 |
| gqa64to8_hdim64_b1 | 400000 | 281.03 | 278.15 | 163.92 | 177.53 | 1.01 | 0.92 |
| gqa64to8_hdim64_b1 | 800000 | 547.05 | 549.76 | 316.83 | 340.98 | 1.00 | 0.93 |
| gqa64to8_hdim64_b1 | 1600000 | 1090.39 | 1084.99 | 623.42 | 670.98 | 1.00 | 0.93 |
| gqa32to8_hdim128_b1 | 1000 | 11.28 | 11.22 | 10.8 | 10.88 | 1.01 | 0.99 |
| gqa32to8_hdim128_b1 | 2000 | 11.85 | 11.86 | 11.78 | 11.9 | 1.00 | 0.99 |
| gqa32to8_hdim128_b1 | 5000 | 13.89 | 14 | 13.08 | 14.21 | 0.99 | 0.92 |
| gqa32to8_hdim128_b1 | 10000 | 18.05 | 18.79 | 15.44 | 16.83 | 0.96 | 0.92 |
| gqa32to8_hdim128_b1 | 20000 | 38.07 | 37.86 | 22.04 | 23.94 | 1.01 | 0.92 |
| gqa32to8_hdim128_b1 | 50000 | 76.83 | 76.78 | 44.79 | 47 | 1.00 | 0.95 |
| gqa32to8_hdim128_b1 | 100000 | 144.66 | 142.64 | 77.66 | 80.73 | 1.01 | 0.96 |
| gqa32to8_hdim128_b1 | 200000 | 277.14 | 277.11 | 143.36 | 149.26 | 1.00 | 0.96 |
| gqa32to8_hdim128_b1 | 400000 | 545.98 | 546.81 | 273.81 | 284.8 | 1.00 | 0.96 |
| gqa32to8_hdim128_b1 | 800000 | 1078.41 | 1085.93 | 537.9 | 563.59 | 0.99 | 0.95 |
| gqa32to8_hdim128_b1 | 1600000 | 2149.55 | 2155.51 | 1067.84 | 1122.13 | 1.00 | 0.95 |
| gqa4to1_hdim128_b1 | 1000 | 11.36 | 11.3 | 10.79 | 10.86 | 1.01 | 0.99 |
| gqa4to1_hdim128_b1 | 2000 | 11.62 | 11.53 | 11.12 | 11.37 | 1.01 | 0.98 |
| gqa4to1_hdim128_b1 | 5000 | 12.6 | 12.52 | 12.48 | 12.94 | 1.01 | 0.96 |
| gqa4to1_hdim128_b1 | 10000 | 14.46 | 14.55 | 13.67 | 14.73 | 0.99 | 0.93 |
| gqa4to1_hdim128_b1 | 20000 | 16.76 | 17.12 | 15.88 | 17.45 | 0.98 | 0.91 |
| gqa4to1_hdim128_b1 | 50000 | 24.45 | 24.37 | 22.5 | 26.28 | 1.00 | 0.86 |
| gqa4to1_hdim128_b1 | 100000 | 37.74 | 37.78 | 32.77 | 40.1 | 1.00 | 0.82 |
| gqa4to1_hdim128_b1 | 200000 | 63.45 | 63.18 | 54.76 | 69.15 | 1.00 | 0.79 |
| gqa4to1_hdim128_b1 | 400000 | 112.97 | 113.39 | 97.3 | 125.23 | 1.00 | 0.78 |
| gqa4to1_hdim128_b1 | 800000 | 224.86 | 212.04 | 180.7 | 236.84 | 1.06 | 0.76 |
| gqa4to1_hdim128_b1 | 1600000 | 436.42 | 405.46 | 346.28 | 454.87 | 1.08 | 0.76 |

@MatthewBonanni

MatthewBonanni commented Oct 30, 2025

@jmkuebler cool, good to know! I think the fact that there's a bit of a perf hit lends weight to the idea of making this optional with an env var. Do you agree? Do you think it should be enabled or disabled by default? I'd still lean towards enabling it by default. Most of the slowdown seems to be at long context lengths, where, if this were disabled, the accuracy would plummet anyway

@jmkuebler
Author

@MatthewBonanni I would agree to make the default safe on the accuracy side and have users opt in for more perf at the risk of accuracy degradation.

How would it work to add this as an env var that we can control via vLLM? Or would it be an env var that directly goes to FA?

@MatthewBonanni

I would probably do it as an env var that directly goes to FA (prefixed with FLASH_ATTENTION_)

@MatthewBonanni

MatthewBonanni commented Oct 30, 2025

Like the other FLASH_ATTENTION_DISABLE_<X> variables, you could make a compile-time check in setup.py that registers a -DFLASHATTENTION_DISABLE_<X> build argument, and pick it up with an #ifdef in the code
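
For illustration only, this is roughly how such a flag surfaces in the code once setup.py adds the -DFLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION define (the function name here is hypothetical; the real guard sits in the kernel code):

```cpp
#include <cstdio>

void describe_fp8_accumulation_mode() {
#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
    // Default: two-level path, temporary fragment promoted into an FP32 accumulator.
    std::puts("FP8 two-level accumulation: enabled");
#else
    // Legacy: accumulate directly in the FP8 TensorCore accumulator.
    std::puts("FP8 two-level accumulation: disabled");
#endif
}
```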

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@jmkuebler
Author

jmkuebler commented Oct 30, 2025

@MatthewBonanni I added the compile-time flag and confirmed that the numbers above are still valid.
In particular,
FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION=TRUE python setup.py install --force then restores the original performance numbers.

The flag is a bit verbose, but I wanted it to be descriptive. Happy to change it to something else, of course.

@MatthewBonanni

MatthewBonanni commented Oct 30, 2025

LGTM! I like the flag as is. Let's find someone with permissions to review

if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
warp_scheduler_barrier_sync();

#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
Collaborator

can we restrict the compile guards and changes to be within if constexpr (Is_FP8) { blocks as much as possible? to keep the BF16 path as untouched as possible?

Author
@jmkuebler jmkuebler Oct 31, 2025

Done. Makes it much easier to grasp that BF16 is not affected by the flag at all.

Unfortunately, we still have duplicate code, because FP8 with two-level accumulation turned off at compile time takes the exact same calls as BF16.
But since we went for the compile-time variable, I cannot think of a way to remove that redundancy. Let me know if I missed something @LucasWilkinson

Collaborator
@LucasWilkinson LucasWilkinson Oct 31, 2025

Ya I see that now :/ I think the cleanest thing to do might be to add a constexpr branch in flash::gemm that does:

                    auto temp_frag = cute::make_fragment_like(tOrO);
                    cute::clear(temp_frag);
                    ...
                    #pragma unroll
                    for (int i = 0; i < cute::size(tOrO); ++i) {
                        tOrO(i) += temp_frag(i);
                    }

when tensors A and C are fp8. That way we can leave the mainloop file untouched and compatible with upstream. I'll try to think about this more in a bit 👍

Collaborator

untested but maybe something like (with the compile guards too)

diff --git a/hopper/utils.h b/hopper/utils.h
index 3719eab..ee680b7 100644
--- a/hopper/utils.h
+++ b/hopper/utils.h
@@ -235,6 +235,15 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
 template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
         typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
 CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
+    static constexpr bool Is_FP8 = std::is_same_v<typename Tensor0::value_type, cutlass::float_e4m3_t> 
+        || std::is_same_v<typename Tensor0::value_type, cutlass::float_e5m2_t>;
+
+    auto& tCrC_out = tCrC;
+    if constexpr (Is_FP8) {
+        tCrC = cute::make_fragment_like(tCrC);
+        cute::clear(tCrC);
+    }
+    
     if constexpr (M_slice >= 0) {
         static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
         static_assert(M_slice < MMA_M);
@@ -301,6 +310,13 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const
             }
         }
     }
+
+    if constexpr (Is_FP8) {
+        #pragma unroll
+        for (int i = 0; i < cute::size(tCrC); ++i) {
+            tCrC_out(i) += tCrC(i);
+        }
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

Collaborator

@jmkuebler did you get a chance to try something like this?

Author
@jmkuebler jmkuebler Nov 5, 2025

@LucasWilkinson, I got it to run functionally, but I'm still testing a bit to see if I can improve it. Hope to get it done today or tomorrow.
Posted an update in the main conversation. Thank you very much for this nice suggestion!

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@jmkuebler
Author

jmkuebler commented Nov 5, 2025

@LucasWilkinson I now implemented a small modification of your proposal. Thanks a ton for that idea. Your version wasn't yet functional, since tCrC can't be rebound in C++. Instead, I copy the original values into a temporary fragment, accumulate into a cleared local buffer, and then merge the results back afterward. [Note: I figured this out with an LLM and don't grasp the C++ details completely, so please take a critical look.]

The other modification I made was to only use the two-level accumulation when zero_init is false. When zero_init is true there is nothing to accumulate on the second level, so we skip it. That's particularly relevant because, now that we moved the logic into the GEMM, it would otherwise also be called unnecessarily during Q x K^T, where I think it is not so critical since the contraction dimension there is always the head_dim, which is quite small.
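
To make this concrete, here is a rough, hypothetical sketch of the idea written as a standalone wrapper (gemm_two_level is an illustrative name; the merged change instead edits flash::gemm in hopper/utils.h directly and, as described above, stashes the original values in a temporary fragment because tCrC cannot be rebound):

```cpp
#include <type_traits>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
// assumes the existing flash::gemm from hopper/utils.h is available

template <bool zero_init = false, typename TiledMma,
          typename Tensor0, typename Tensor1, typename Tensor2>
CUTLASS_DEVICE void gemm_two_level(TiledMma& tiled_mma, Tensor0 const& tCrA,
                                   Tensor1 const& tCrB, Tensor2& tCrC) {
    static constexpr bool Is_FP8 =
        std::is_same_v<typename Tensor0::value_type, cutlass::float_e4m3_t> ||
        std::is_same_v<typename Tensor0::value_type, cutlass::float_e5m2_t>;

    if constexpr (Is_FP8 && !zero_init) {
        // This call's contribution, accumulated by the FP8 tensor cores from zero.
        auto tCrC_partial = cute::make_fragment_like(tCrC);
        flash::gemm</*zero_init=*/true>(tiled_mma, tCrA, tCrB, tCrC_partial);
        // Promote the partial result into the long-lived FP32 accumulator.
        #pragma unroll
        for (int i = 0; i < cute::size(tCrC); ++i) {
            tCrC(i) += tCrC_partial(i);
        }
    } else {
        // BF16 inputs, or a fresh accumulation (e.g. Q x K^T): unchanged path.
        flash::gemm<zero_init>(tiled_mma, tCrA, tCrB, tCrC);
    }
}
```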

✅ I confirmed that this code still fixes our original issues, and performance wise it is similar to my previous version (see numbers below).

⏭️ I bumped the git tag to run CI again with the new refactor vllm-project/vllm#27889

| config | seq_len | fp8_main | fp8 (this PR originally) | fp8 (this PR now) |
|---|---|---|---|---|
| gqa32to4_hdim128_b1 | 1000 | 10.57 | 10.81 | 10.73 |
| gqa32to4_hdim128_b1 | 2000 | 11.3 | 11.63 | 11.41 |
| gqa32to4_hdim128_b1 | 5000 | 13.05 | 13.32 | 13.34 |
| gqa32to4_hdim128_b1 | 10000 | 14.12 | 15.4 | 15.43 |
| gqa32to4_hdim128_b1 | 20000 | 16.28 | 17.96 | 17.97 |
| gqa32to4_hdim128_b1 | 50000 | 29.47 | 30.96 | 30.68 |
| gqa32to4_hdim128_b1 | 100000 | 46.12 | 48.06 | 48.13 |
| gqa32to4_hdim128_b1 | 200000 | 80.06 | 82.46 | 82.01 |
| gqa32to4_hdim128_b1 | 400000 | 144.73 | 151.05 | 150.09 |
| gqa32to4_hdim128_b1 | 800000 | 273.9 | 286.54 | 287.3 |
| gqa32to4_hdim128_b1 | 1600000 | 535.81 | 561.38 | 560.14 |
| gqa32to8_hdim128_b1 | 1000 | 10.8 | 10.88 | 11.14 |
| gqa32to8_hdim128_b1 | 2000 | 11.78 | 11.9 | 12.42 |
| gqa32to8_hdim128_b1 | 5000 | 13.08 | 14.21 | 14.27 |
| gqa32to8_hdim128_b1 | 10000 | 15.44 | 16.83 | 17.07 |
| gqa32to8_hdim128_b1 | 20000 | 22.04 | 23.94 | 24 |
| gqa32to8_hdim128_b1 | 50000 | 44.79 | 47 | 47.14 |
| gqa32to8_hdim128_b1 | 100000 | 77.66 | 80.73 | 80.81 |
| gqa32to8_hdim128_b1 | 200000 | 143.36 | 149.26 | 149.18 |
| gqa32to8_hdim128_b1 | 400000 | 273.81 | 284.8 | 284.03 |
| gqa32to8_hdim128_b1 | 800000 | 537.9 | 563.59 | 564.56 |
| gqa32to8_hdim128_b1 | 1600000 | 1067.84 | 1122.13 | 1123.68 |
| gqa4to1_hdim128_b1 | 1000 | 10.79 | 10.86 | 10.86 |
| gqa4to1_hdim128_b1 | 2000 | 11.12 | 11.37 | 11.37 |
| gqa4to1_hdim128_b1 | 5000 | 12.48 | 12.94 | 12.96 |
| gqa4to1_hdim128_b1 | 10000 | 13.67 | 14.73 | 14.61 |
| gqa4to1_hdim128_b1 | 20000 | 15.88 | 17.45 | 17.46 |
| gqa4to1_hdim128_b1 | 50000 | 22.5 | 26.28 | 26.28 |
| gqa4to1_hdim128_b1 | 100000 | 32.77 | 40.1 | 40.09 |
| gqa4to1_hdim128_b1 | 200000 | 54.76 | 69.15 | 69.17 |
| gqa4to1_hdim128_b1 | 400000 | 97.3 | 125.23 | 124.92 |
| gqa4to1_hdim128_b1 | 800000 | 180.7 | 236.84 | 234.15 |
| gqa4to1_hdim128_b1 | 1600000 | 346.28 | 454.87 | 455.41 |
| gqa64to8_hdim64_b1 | 1000 | 10.29 | 10.25 | 10.3 |
| gqa64to8_hdim64_b1 | 2000 | 11.08 | 11.42 | 11.37 |
| gqa64to8_hdim64_b1 | 5000 | 12.33 | 12.79 | 12.73 |
| gqa64to8_hdim64_b1 | 10000 | 14.29 | 15.1 | 15.09 |
| gqa64to8_hdim64_b1 | 20000 | 17.62 | 18.97 | 18.84 |
| gqa64to8_hdim64_b1 | 50000 | 31.33 | 33.59 | 33.33 |
| gqa64to8_hdim64_b1 | 100000 | 50.9 | 54.18 | 54.26 |
| gqa64to8_hdim64_b1 | 200000 | 88.31 | 95.61 | 96.08 |
| gqa64to8_hdim64_b1 | 400000 | 163.92 | 177.53 | 179.44 |
| gqa64to8_hdim64_b1 | 800000 | 316.83 | 340.98 | 342.44 |
| gqa64to8_hdim64_b1 | 1600000 | 623.42 | 670.98 | 682.94 |

Collaborator
@LucasWilkinson LucasWilkinson left a comment

Amazing; thanks for all the hard work! This is very clean now 😄

@LucasWilkinson LucasWilkinson merged commit 8e1b01d into vllm-project:main Nov 5, 2025
1 check passed