
Conversation

@qsang-nv (Collaborator) commented on Sep 25, 2025

📌 Description

Add XQA FP8 MHA and FP8 KV cache support.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@qsang-nv changed the title from "add xqa fp8 mha" to "add xqa fp8 mha and fp8 kv cache" on Sep 25, 2025
gemini-code-assist bot (Contributor) commented:

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer library by integrating FP8 support for both Multi-Head Attention computations and the Key-Value cache within the XQA framework. These changes are primarily aimed at leveraging the advanced capabilities of NVIDIA Hopper GPUs (SM90+) to achieve substantial performance and memory efficiency gains. The implementation includes new CUDA kernels utilizing GMMA and TMA, along with Python-side modifications to enable configurable FP8 execution paths, ensuring that users can opt into these optimizations while maintaining numerical stability.

Highlights

  • FP8 Multi-Head Attention (MHA) Support: Introduced support for FP8 (8-bit floating point) Multi-Head Attention, enabling more efficient computations on compatible hardware.
  • FP8 Key-Value (KV) Cache: Added functionality for using FP8 for the Key-Value cache, which can significantly reduce memory footprint and improve performance.
  • NVIDIA Hopper (SM90+) Optimizations: Integrated specialized Hopper GPU features like GMMA (warpgroup matrix multiply-accumulate) and TMA (Tensor Memory Accelerator) for optimized FP8 operations and efficient memory access patterns.
  • Configurable FP8 Execution: The XQA module generation and runtime execution now allow explicit control over whether FP8 MHA and FP8 KV cache are utilized, providing flexibility for different precision requirements.
  • Numerical Stability Adjustments: Modified the safeInitRowMax value and adjusted test tolerances to account for the numerical characteristics of lower precision FP8 computations, ensuring stability.
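
As a rough illustration of what the adjusted test tolerances mean in practice, the sketch below compares an FP8-path output against a higher-precision reference with loosened tolerances. It is illustrative only; the shapes, noise level, and tolerance values are placeholders, not what test_xqa.py actually uses.

# Illustrative only: an FP8-path output is expected to match an fp32 reference
# within loosened tolerances. Shapes, noise level, and tolerances are placeholders.
import torch

ref_out = torch.randn(4, 32, 128, dtype=torch.float32)                     # stand-in fp32 reference output
fp8_out = (ref_out + 0.01 * torch.randn_like(ref_out)).to(torch.bfloat16)  # stand-in lower-precision kernel output
torch.testing.assert_close(fp8_out.float(), ref_out, rtol=1e-1, atol=1e-1)  # relaxed check passes
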
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature                Command               Description
Code Review            /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary   /gemini summary       Provides a summary of the current pull request in its current state.
Comment                @gemini-code-assist   Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                   /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot (Contributor) commented:

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer library by introducing support for FP8 Multi-Head Attention (MHA) and FP8 Key-Value (KV) cache. These additions leverage advanced features of NVIDIA Hopper GPUs, such as GMMA and TMA, to achieve higher performance and memory efficiency for large language model inference. The changes span the CUDA C++ backend, including new kernel implementations and memory management utilities, as well as updates to the Python AOT compilation and testing framework to ensure robust integration and validation of the new FP8 capabilities.

Highlights

  • FP8 MHA Integration: Introduced a new "run_fp8_mha" boolean parameter to the "xqa_wrapper" function, enabling conditional execution of FP8 Multi-Head Attention kernels.
  • FP8 KV Cache Support: Added "fp8_kv_cache" parameter to the AOT compilation and Python interface, allowing the use of FP8 for Key-Value cache storage (a short quantization sketch follows this list).
  • New Low-Level CUDA Kernels: Incorporated "gmma.cuh" for warpgroup matrix multiply-accumulate (GMMA) operations and "tma.h" for Tensor Memory Accelerator (TMA) asynchronous memory transfers, leveraging Hopper (SM90) architecture features for optimized FP8 performance.
  • CUDA Tensor Map Utilities: Introduced "tensorMap.cpp" and "tensorMap.h" to facilitate efficient access and management of KV caches using CUDA Tensor Maps.
  • Numerical Stability Improvements: Adjusted "safeInitRowMax" in "utils.cuh" and modified attention mask application in "mha.cu" to enhance numerical stability, especially relevant for lower precision formats.
  • Expanded Testing: Updated "test_xqa.py" to include comprehensive tests for FP8 MHA and FP8 KV cache, with adjusted numerical tolerance checks to account for reduced precision.
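
On the FP8 KV cache bullet above: a separate scale factor (the kvCacheScale argument visible in the wrapper code quoted below) is what makes an 8-bit cache usable. The following is a minimal sketch of the quantize/dequantize round trip, assuming PyTorch's float8_e4m3fn dtype; it is not FlashInfer's API, only the general idea the scale accounts for.

# Minimal sketch of FP8 (e4m3) KV-cache quantization with a per-tensor scale.
# Not FlashInfer code: the shapes and the 448.0 constant (largest finite e4m3 value)
# are only used to illustrate why a kvCacheScale factor travels with the cache.
import torch

kv = torch.randn(16, 64, dtype=torch.bfloat16)           # a toy tile of cache heads
scale = kv.float().abs().max() / 448.0                   # map the observed range onto e4m3
kv_fp8 = (kv.float() / scale).to(torch.float8_e4m3fn)    # what gets stored in the paged cache
kv_dequant = kv_fp8.float() * scale                      # what the kernel effectively consumes
print((kv.float() - kv_dequant).abs().max())             # quantization error stays small
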

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache within the XQA kernel, primarily targeting the NVIDIA Hopper architecture. This is a significant feature addition, enabled by new CUDA primitives for Hopper's Tensor Memory Accelerator (TMA) and warpgroup MMA (GMMA) instructions. The changes are well-implemented, including new CUDA headers for hardware abstraction, a dispatch mechanism for the new FP8 kernel path, and corresponding updates to the Python build system and tests. The tests have been thoughtfully adjusted with relaxed tolerances for FP8 precision. My review includes one suggestion to refactor a small piece of duplicated code to enhance maintainability.

Comment on lines 37 to 72
  if (run_fp8_mha) {
    launchHopperF8MHAFlashInfer(
        multiProcessorCount, nbKHeads, slidingWinSize, qScale,
        reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
        reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
        reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
        reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
        reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
        reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
        reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
        qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
        reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
        reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
        reinterpret_cast<void*>(scratch.data_ptr()), stream);
  } else {
    launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
                        reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
                        reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
                        reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
                        reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
                        reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
                        maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
                        reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
                        qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
                        reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
                        reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
                        reinterpret_cast<void*>(scratch.data_ptr()), stream);
  }
gemini-code-assist bot (Contributor):


Severity: medium

The if and else blocks contain identical arguments passed to two different functions (launchHopperF8MHAFlashInfer and launchMHAFlashInfer). This code duplication can be reduced to improve maintainability. Since both functions share the same signature, you can use a function pointer to select the appropriate kernel and then call it once with the common set of arguments.

  using mha_launcher_t = decltype(&launchMHAFlashInfer);
  mha_launcher_t launcher = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
  launcher(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
           reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
           reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
           reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
           reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
           reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
           maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
           reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
           qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
           reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
           reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
           reinterpret_cast<void*>(scratch.data_ptr()), stream);

gemini-code-assist bot (Contributor) commented:

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the XQA kernels by introducing support for FP8 Multi-Head Attention and an FP8 Key-Value cache. These changes are designed to optimize performance and memory usage on NVIDIA Hopper (SM90+) GPUs through low-level CUDA programming, including asynchronous tensor memory access and matrix operations. The integration ensures that the system can efficiently handle lower precision data types, with comprehensive testing to maintain accuracy within acceptable bounds.

Highlights

  • FP8 Multi-Head Attention (MHA) Support: Introduced a new launchHopperF8MHAFlashInfer function and integrated it into the xqa_wrapper to enable FP8 MHA execution, controlled by a new run_fp8_mha boolean parameter.
  • FP8 KV Cache Implementation: Added support for FP8 Key-Value (KV) cache, including new CUtensorMap utilities for efficient memory access and configuration options in the AOT/JIT compilation system.
  • Low-Level CUDA Optimizations: Incorporated new CUDA kernel files (gmma.cuh, tma.h) that define asynchronous matrix multiply accumulate (GMMA) operations and tensor memory access (TMA) functions, leveraging Hopper architecture features for performance.
  • Numerical Stability Improvements: Adjusted the safeInitRowMax constant and its usage in applyMaskFromInput to enhance numerical stability, especially for large values, in attention calculations.
  • AOT/JIT Compilation and Testing: Updated the AOT and JIT compilation infrastructure to generate kernels for various FP8 configurations and expanded the test suite to thoroughly validate the new FP8 MHA and KV cache functionalities, including adjusted precision checks.
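
To make the "various FP8 configurations" above concrete, the sketch below enumerates a hypothetical build matrix over the FP8 toggles. Only the fp8_kv_cache / FP8-MHA flags come from this PR's description; gen_xqa_module and the head-dim axis are placeholder names for illustration, not the project's actual generation API.

# Hypothetical sketch of enumerating XQA build variants over the FP8 flags.
# Only the fp8_kv_cache / FP8-MHA toggles come from this PR; gen_xqa_module and
# the head_dim values are placeholders for illustration.
from itertools import product

def gen_xqa_module(head_dim: int, fp8_kv_cache: bool, use_fp8_mha: bool) -> str:
    # Stand-in for an AOT generation step; returns a unique module tag.
    return f"xqa_hd{head_dim}_kvfp8{int(fp8_kv_cache)}_mhafp8{int(use_fp8_mha)}"

variants = [
    gen_xqa_module(head_dim, fp8_kv, fp8_mha)
    for head_dim, fp8_kv, fp8_mha in product((64, 128), (False, True), (False, True))
]
print(variants)  # eight build variants covering the FP8 on/off combinations
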

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache in the XQA kernels, targeting Hopper architecture for performance improvements. The changes include new low-level CUDA files (gmma.cuh, tma.h, tensorMap.cpp) with Hopper-specific WGMMA and TMA instructions, a new FP8 MHA kernel entry point, and updates to the AOT compilation scripts and Python wrappers to handle the new FP8 variants. The tests have also been updated to include FP8 configurations and use a more lenient assertion method to account for precision differences.

My review focuses on code maintainability and clarity. I've suggested refactoring a duplicated code block in the C++ wrapper to improve readability and proposed adding a comment in the Python tests to clarify a magic number used for data scaling. Overall, the changes are well-structured and the addition of FP8 support is a valuable performance enhancement.

Comment on lines 37 to 72
  if (run_fp8_mha) {
    launchHopperF8MHAFlashInfer(
        multiProcessorCount, nbKHeads, slidingWinSize, qScale,
        reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
        reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
        reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
        reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
        reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
        reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
        reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
        qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
        reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
        reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
        reinterpret_cast<void*>(scratch.data_ptr()), stream);
  } else {
    launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
                        reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
                        reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
                        reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
                        reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
                        reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
                        maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
                        reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
                        qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
                        reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
                        reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
                        reinterpret_cast<void*>(scratch.data_ptr()), stream);
  }
gemini-code-assist bot (Contributor):


Severity: medium

There's a large block of duplicated code for launching the MHA kernels. The only difference between the if and else blocks is the function being called (launchHopperF8MHAFlashInfer vs. launchMHAFlashInfer). This could be refactored to improve maintainability and reduce redundancy.

Consider using a function pointer to select the kernel, and then make a single call. This would make the code cleaner and easier to manage if more arguments are added in the future.

For example:

void (*mha_func)(uint32_t, uint32_t, ...); // function pointer to the selected launcher (full signature in the suggestion below)

if (run_fp8_mha) {
    mha_func = &launchHopperF8MHAFlashInfer;
} else {
    mha_func = &launchMHAFlashInfer;
}

mha_func(
    multiProcessorCount,
    nbKHeads,
    slidingWinSize,
    // ... other arguments
);

  using mha_func_t = void (*)(uint32_t, uint32_t, uint32_t, float, OutputHead*,
#if LOW_PREC_OUTPUT
                              float const*,
#endif
                              InputHead const*, float const*, GMemCacheHead*,
                              KVCachePageIndex const*, uint32_t, uint32_t const*, uint32_t,
                              float const* __restrict__,
#if SPEC_DEC
                              uint32_t, uint32_t const*, MaskType const*,
#endif
                              uint32_t*, void*, cudaStream_t);

  mha_func_t mha_func = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;

  mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
           reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
           reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
           reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
           reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
           reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
           reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
           reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
           qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
           reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
           reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
           reinterpret_cast<void*>(scratch.data_ptr()), stream);

Comment on lines 230 to 231
if fp8_kv_cache:
    cache_heads /= 4.0
gemini-code-assist bot (Contributor):


Severity: medium

The value 4.0 is used to scale down the cache_heads tensor when fp8_kv_cache is enabled. This appears to be a magic number. To improve code clarity and maintainability, please add a comment explaining the rationale for this specific scaling factor. For example, explaining that it's to prevent overflow and how 4.0 was determined would be very helpful for future readers.

Suggested change:

if fp8_kv_cache:
    cache_heads /= 4.0

if fp8_kv_cache:
    # Scale down the cache heads to keep values within the representable range of FP8
    # and prevent overflow during computation. The factor 4.0 is chosen empirically.
    cache_heads /= 4.0
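
For reference, the e4m3 format typically used for FP8 KV caches tops out at a largest finite value of 448 and carries only 3 mantissa bits, so keeping test inputs small relative to that range leaves headroom for the dot-product accumulation in the kernel; the specific factor of 4.0 looks like an empirical choice, as the suggested comment says.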

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache in the XQA kernels, targeting NVIDIA's Hopper architecture. This is a significant feature addition that leverages low-level hardware capabilities like TMA and GMMA for performance. The changes include new CUDA files for these Hopper-specific features, along with updates to the Python build system and tests to accommodate the new configurations. The review identified a critical bug in the new TMA storeAsync implementation and a high-severity correctness issue related to the handling of masked values in the softmax computation.

Comment on lines +220 to +243
: "memory");
} else if constexpr (nbDims == 5) {
asm volatile(
"cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], "
"[%6];\n"
:
: "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
"r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
: "memory");
} else {
static_assert(nbDims >= 1 && nbDims <= 5);
}
}

__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) {
asm volatile(
"tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap),
"l"(ptr)
: "memory");
}

__device__ inline void commitGroup() {
asm volatile("cp.async.bulk.commit_group;\n" : : : "memory");
}
gemini-code-assist bot (Contributor):


Severity: critical

There appears to be a copy-paste error in the storeAsync template function. For nbDims of 3, 4, and 5, the inline assembly instruction is cp.async.bulk.tensor.2d..., but it should be cp.async.bulk.tensor.3d..., cp.async.bulk.tensor.4d..., and cp.async.bulk.tensor.5d... respectively. This will lead to incorrect memory access patterns and likely data corruption for higher-dimensional tensors.

    else if constexpr (nbDims == 3)
    {
        asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
                     :
                     : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
                     "l"(__cvta_generic_to_shared(src))
                     : "memory");
    }
    else if constexpr (nbDims == 4)
    {
        asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
                     :
                     : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
                     "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
                     : "memory");
    }
    else if constexpr (nbDims == 5)
    {
        asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n"
                     :
                     : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
                     "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
                     : "memory");
    }

      ? true
      : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
- acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY;
+ acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
gemini-code-assist bot (Contributor):


Severity: high

Using safeInitRowMax for masked elements can lead to incorrect results. When an entire row/sequence is masked, all attention scores become safeInitRowMax. In the softmax computation, maxVal also becomes safeInitRowMax, and exp(score - maxVal) evaluates to 1 for all masked positions. This results in a uniform attention distribution over masked tokens, and the output becomes the average of values in V, instead of zero.

A correct implementation should ensure that the softmax output for masked tokens is zero. If the entire row is masked, the final output should also be zero. This might require changes in the softmax function to handle safeInitRowMax specially, and in the final normalization step to handle a row sum of zero.
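
A quick numeric illustration of this concern (not code from the PR): when every score in a fully masked row equals the same finite constant, the row max cancels and softmax degenerates to uniform weights instead of collapsing to zero.

# Illustration of the reviewer's point, not code from this PR: with a fully masked
# row whose scores are all set to the same finite safeInitRowMax, softmax yields
# uniform weights, so the row's output becomes the mean of V rather than zero.
import numpy as np

safe_init_row_max = -1.0e30                # stand-in for the kernel's constant
scores = np.full(4, safe_init_row_max)     # every column of the row is masked
row_max = scores.max()
weights = np.exp(scores - row_max)         # exp(0) == 1 at every masked position
print(weights / weights.sum())             # [0.25 0.25 0.25 0.25]: uniform, not zero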

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@qsang-nv requested a review from yzh119 on September 25, 2025