Add two-level accumulation for SM90 FP8 FWD to mitigate long-context accuracy drops #104
Conversation
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Force-pushed from fe016a7 to 1339384
hopper/fp8_promotion_helper.hpp (outdated)
AccumFragment& accum_main;
AccumFragment accum_temp;
int mma_count = 0;
static constexpr int promotion_interval = 2; // Tunable: balance precision vs performance
This needs some further experimentation. If 1 works well perf-wise, that would be the best choice.
Is the promotion_interval working properly as written? If the promoter is being constructed immediately before the mma, won't mma_count be zero every time? It seems like the intent would be to construct the promoter when the other fragments are allocated.
Also, if 1 works well perf-wise, we might not need the promoter struct at all, but it is a nice way to encapsulate this feature.
You are absolutely right, that didn't work as intended.
Making this promoter work correctly would require a bit more work, I think. I have now pushed a v2 that is heavily simplified and just adds every tile, so the promotion interval is one.
I'll share some benchmarking numbers below... and then we can discuss how to do it correctly.
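To make the intent concrete, a correctly wired interval-based promoter might look roughly like this (untested sketch; plain float arrays stand in for the CuTe fragments, and TwoLevelPromoter / after_mma / flush are illustrative names, not code from this PR):

template <int kSize>
struct TwoLevelPromoter {
    float (&accum_main)[kSize];           // long-lived FP32 accumulator
    float accum_temp[kSize] = {};         // short-lived TC-side accumulator
    int mma_count = 0;
    static constexpr int promotion_interval = 2;  // tunable: precision vs. perf

    explicit TwoLevelPromoter(float (&main)[kSize]) : accum_main(main) {}

    // Must be constructed where the fragments are allocated and called
    // once per MMA, so that mma_count actually advances across MMAs.
    void after_mma() {
        ++mma_count;
        if (mma_count % promotion_interval == 0) { flush(); }
    }

    // Merge the temporary partial sums into FP32 and reset the buffer.
    void flush() {
        for (int i = 0; i < kSize; ++i) {
            accum_main[i] += accum_temp[i];
            accum_temp[i] = 0.0f;
        }
    }
};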
I like this approach! Have you had a chance to profile it? One concern @LucasWilkinson mentioned to me is that the register usage is already pretty high as-is, so adding the extra accumulator could push it over the limit.
Also, do we want to make it possible to enable / disable this with an env var? I suppose if it's a universal win then it isn't necessary.
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Since the first version had some dead code that wasn't actually working, I pushed a v2. It still mitigates the accuracy loss, but adds every tile's output directly to the actual FP32 accumulation register. In summary, the BF16 perf remains unchanged (expected), whereas the FP8 performance does indeed degrade. The severity depends on the specific attn configuration.
@jmkuebler cool, good to know! I think the fact that there's a bit of a perf hit lends weight to the idea of making this optional with an env var. Do you agree? Do you think it should be enabled or disabled by default? I'd still lean towards enabling it by default. Most of the slowdown seems to be at long context lengths, where, if this were disabled, the accuracy would plummet anyway.
@MatthewBonanni I would agree to have the default be safe on the accuracy side and have users opt in to more perf at the risk of accuracy degradation. How would it work to add this as an env var that we can control via vLLM? Or would it be an env var that goes directly to FA?
I would probably do it as an env var that directly goes to FA (prefixed with FLASH_ATTENTION_).
Like the other FLASH_ATTENTION_DISABLE_* flags.
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@MatthewBonanni I added the compile-time flag and confirmed that the numbers above are still valid. The flag is a bit verbose, but I wanted it to be descriptive. Happy to change it for something else, of course.
LGTM! I like the flag as is. Let's find someone with permissions to review.
if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
warp_scheduler_barrier_sync();
#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
Can we restrict the compile guards and changes to be within if constexpr (Is_FP8) { blocks as much as possible, to keep the BF16 path as untouched as possible?
Done. Makes it much easier to grasp that BF16 is not affected by the flag at all.
Unfortunately, we still have duplicate code, because FP8 with two-level accumulation turned off at compile time takes the exact same calls as BF16.
But since we went for the compile-time variable, I cannot think of a way to remove that redundancy. Let me know if I missed something @LucasWilkinson
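For illustration, the resulting nesting looks roughly like this (schematic sketch only; branch bodies elided):

if constexpr (Is_FP8) {
#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION
    // two-level accumulation path
#else
    // same calls as the BF16 branch below (the duplication in question)
#endif
} else {
    // BF16 path, untouched by the flag
}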
Ya, I see that now :/ I think the cleanest thing to do might be to add a constexpr branch in flash::gemm that does:
auto temp_frag = cute::make_fragment_like(tOrO);
cute::clear(temp_frag);
...
#pragma unroll
for (int i = 0; i < cute::size(tOrO); ++i) {
tOrO(i) += temp_frag(i);
}
when tensor A and C are fp8. That way we can leave the mainloop file untouched and compatible with upstream. I'll try to think about this more in a bit 👍
Untested, but maybe something like this (with the compile guards too):
diff --git a/hopper/utils.h b/hopper/utils.h
index 3719eab..ee680b7 100644
--- a/hopper/utils.h
+++ b/hopper/utils.h
@@ -235,6 +235,15 @@ auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
template <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,
typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {
+ static constexpr bool Is_FP8 = std::is_same_v<typename Tensor0::value_type, cutlass::float_e4m3_t>
+ || std::is_same_v<typename Tensor0::value_type, cutlass::float_e5m2_t>;
+
+ auto& tCrC_out = tCrC;
+ if constexpr (Is_FP8) {
+ tCrC = cute::make_fragment_like(tCrC);
+ cute::clear(tCrC);
+ }
+
if constexpr (M_slice >= 0) {
static constexpr int MMA_M = decltype(size<1>(tCrC))::value;
static_assert(M_slice < MMA_M);
@@ -301,6 +310,13 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const
}
}
}
+
+ if constexpr (Is_FP8) {
+ #pragma unroll
+ for (int i = 0; i < cute::size(tCrC); ++i) {
+ tCrC_out(i) += tCrC(i);
+ }
+ }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@jmkuebler did you get a chance to try something like this?
@LucasWilkinson, I got it running functionally but am still testing a bit to see if I can improve it. Hope to get it done by today or tomorrow.
Posted an update in the main conversation. Thank you very much for this nice suggestion!
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@LucasWilkinson I now implemented a small modification of your proposal. Thanks a ton for that idea. Your version wasn't yet functional since tCrC can't be rebound in C++. Instead, I copy the original values into a temporary fragment, accumulate into a cleared local buffer, and then merge the results back afterward. [Note: I figured this out with an LLM and don't grasp the C++ details completely, so please take a critical look.] The other modification I made was to only use the two-level when …
✅ I confirmed that this code still fixes our original issues, and performance-wise it is similar to my previous version (see numbers below).
⏭️ I bumped the git tag to run CI again with the new refactor: vllm-project/vllm#27889
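A rough sketch of that save/clear/merge pattern inside flash::gemm (hypothetical and untested; not necessarily the exact code in this PR, and tCrC_saved is an illustrative name):

static constexpr bool Is_FP8 = std::is_same_v<typename Tensor0::value_type, cutlass::float_e4m3_t>
    || std::is_same_v<typename Tensor0::value_type, cutlass::float_e5m2_t>;

auto tCrC_saved = cute::make_fragment_like(tCrC);
if constexpr (Is_FP8) {
    cute::copy(tCrC, tCrC_saved);  // keep the running FP32 state
    cute::clear(tCrC);             // the MMAs below accumulate just this tile
}

// ... issue the MMAs into tCrC as before ...

if constexpr (Is_FP8) {
    #pragma unroll
    for (int i = 0; i < cute::size(tCrC); ++i) {
        tCrC(i) += tCrC_saved(i);  // merge the saved values back in FP32
    }
}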
LucasWilkinson left a comment:
Amazing; thanks for all the hard work! This is very clean now 😄
Problem
The FP8 TensorCore accumulator does not actually accumulate in full FP32 precision but in some reduced format. This materializes as a drastic accuracy loss on long-context needle-in-a-haystack tasks, as discussed in vllm-project/vllm#23813 (comment).
Solution:
Similar to SageAttention2, we implement a two-level accumulator, where only a single tile is combined with the FP8 TC accumulator and then flushed to an actual FP32 accumulator. This contains floating-point errors to a single tile instead of the full matmul. The solution also follows the approach here: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp#L600
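To see why containing the error to one tile helps, here is a small self-contained demo (illustrative only, not kernel code; emulating the reduced accumulator with roughly 14 mantissa bits is an assumption made for the demo, not a confirmed hardware detail):

#include <cmath>
#include <cstdio>

// Hypothetical stand-in for a reduced-precision accumulator:
// round to roughly 14 mantissa bits after every addition.
float round_reduced(float x) {
    if (x == 0.0f) return x;
    int e;
    float m = std::frexp(x, &e);              // x = m * 2^e, 0.5 <= |m| < 1
    m = std::round(m * 16384.0f) / 16384.0f;  // keep ~14 mantissa bits
    return std::ldexp(m, e);
}

int main() {
    const int tiles = 1024, per_tile = 16;
    float single = 0.0f;   // one reduced-precision accumulator for everything
    float two_lvl = 0.0f;  // true FP32 accumulator, flushed once per tile
    double exact = 0.0;
    for (int t = 0; t < tiles; ++t) {
        float tile_acc = 0.0f;
        for (int i = 0; i < per_tile; ++i) {
            const float v = 0.001f;
            single = round_reduced(single + v);      // error grows with context length
            tile_acc = round_reduced(tile_acc + v);  // error confined to one tile
            exact += v;
        }
        two_lvl += tile_acc;  // promotion happens in full FP32
    }
    std::printf("exact %.6f  single-level %.6f  two-level %.6f\n",
                exact, single, two_lvl);
}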
Testing:
We ran LM Eval on niah_multikey_2, where the largest accuracy drops were observed.
Performance Impact:
We measured the FP8 decoding performance, which degrades slightly by around 5% but is still much faster than BF16. Only the setting with 4Q:1KV at head dim 128 is not competitive (see details in comments).
We added a compile-time env flag (FLASH_ATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION=TRUE) to restore the legacy behavior, but the default is changed by this PR to favor accuracy over performance.
Thanks @MatthewBonanni for the quick iterations and support!