Rename forward combine functions and clarify comments #270
Conversation
Pull request overview
This PR aims to clarify naming around the “forward combine” path by renaming it to “decode combine”, and updates related comments/usages across Triton kernels.
Changes:
- Renames launch/grid helpers from `get_fwd_combine_*` to `get_dec_combine_*`.
- Renames the Triton combine kernel/entrypoint in `flash_dec_combine.py` and updates imports in forward paths.
- Refactors softmax helpers by changing `online_softmax`'s signature and introducing `online_sparse_softmax`, updating several kernels accordingly.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| flash_sparse_attn/ops/triton/launch_template.py | Renames combine launch-config helper and updates associated comments. |
| flash_sparse_attn/ops/triton/launch_grid.py | Renames combine grid helper and updates docstring. |
| flash_sparse_attn/ops/triton/flash_sparse_fwd.py | Updates combine-module import and sparse softmax call site. |
| flash_sparse_attn/ops/triton/flash_gated_fwd.py | Updates combine-module import and sparse softmax call site. |
| flash_sparse_attn/ops/triton/flash_gated_bwd.py | Comment-only change (“pointer” → “pointers”). |
| flash_sparse_attn/ops/triton/flash_dense_fwd.py | Removes block_max plumbing and updates softmax call/returns; updates combine-module import. |
| flash_sparse_attn/ops/triton/flash_dec_combine.py | Renames combine kernel and Python entrypoint; updates helper calls to renamed launch/grid funcs. |
| flash_sparse_attn/ops/triton/flash_bwd_preprocess.py | Comment-only change (“pointer” → “pointers”). |
| flash_sparse_attn/ops/triton/flash_bwd_postprocess.py | Comment-only change (“pointer” → “pointers”). |
| flash_sparse_attn/ops/triton/activations.py | Changes online_softmax API and introduces online_sparse_softmax. |
```diff
 if is_split_kv:
-    flash_fwd_combine._flash_attn_fwd_combine(
+    flash_dec_combine._flash_attn_fwd_combine(
```

The varlen path still calls `flash_dec_combine._flash_attn_fwd_combine`, but the renamed function is `_flash_attn_dec_combine`. This will break split-KV combine at runtime.

```suggestion
flash_dec_combine._flash_attn_dec_combine(
```
```diff
 if is_split_kv:
-    flash_fwd_combine._flash_attn_fwd_combine(
+    flash_dec_combine._flash_attn_fwd_combine(
```

This call uses `flash_dec_combine._flash_attn_fwd_combine`, but `flash_dec_combine.py` defines `_flash_attn_dec_combine` after the rename. Update the call (or add an alias) to avoid runtime failures in the split-KV path.

```suggestion
flash_dec_combine._flash_attn_dec_combine(
```
```diff
 if is_split_kv:
-    flash_fwd_combine._flash_attn_fwd_combine(
+    flash_dec_combine._flash_attn_fwd_combine(
```

Varlen split-KV combine still calls `_flash_attn_fwd_combine`, but the function was renamed to `_flash_attn_dec_combine` in `flash_dec_combine.py`. Update this call to match the new name.

```suggestion
flash_dec_combine._flash_attn_dec_combine(
```
```python
    scale_log2,
    SOFTMAX_THRESHOLD_LOG2: tl.constexpr,
    CHECK_INF: tl.constexpr,
):
```

`SOFTMAX_THRESHOLD_LOG2` is declared as `tl.constexpr`, but call sites pass a runtime value computed inside the kernel (e.g., via `seqlen_info.get_softmax_threshold`). Triton requires `tl.constexpr` arguments to be compile-time constants, so this is likely to fail compilation; make the threshold a regular runtime argument (remove `tl.constexpr`) or restructure so the threshold is known at specialization time.
```diff
 def online_softmax(
     acc_s,
     block_max,
     row_max,
     row_sum,
     scale_log2,
     softmax_threshold_log2,
     CHECK_INF: tl.constexpr,
     RESCALE_THRESHOLD: tl.constexpr,
 ):
     """
-    Apply online softmax to acc_s, and update block_max, row_max and row_sum.
+    Apply online softmax to acc_s, and update row_max and row_sum.
```

The PR description says this is a pure rename/comment-clarification refactor, but this file also changes the softmax API (the signature of `online_softmax` changes, and `online_sparse_softmax` is introduced). If this is intentional, please update the PR description to reflect the behavioral/API surface changes; otherwise, consider limiting the PR to naming/comment changes only.
```diff
 if is_split_kv:
-    flash_fwd_combine._flash_attn_fwd_combine(
+    flash_dec_combine._flash_attn_fwd_combine(
```

`flash_dec_combine` no longer defines `_flash_attn_fwd_combine` (it was renamed to `_flash_attn_dec_combine`), so this call will raise an `AttributeError`. Update the call site (and any similar ones) to use the renamed entrypoint, or add a backwards-compatible alias in `flash_dec_combine.py`.

```suggestion
flash_dec_combine._flash_attn_dec_combine(
```
```diff
 if is_split_kv:
-    flash_fwd_combine._flash_attn_fwd_combine(
+    flash_dec_combine._flash_attn_fwd_combine(
```

`flash_dec_combine` exports `_flash_attn_dec_combine`, not `_flash_attn_fwd_combine`, so this varlen combine path will fail at runtime. Update to the renamed function (or keep an alias).

```suggestion
flash_dec_combine._flash_attn_dec_combine(
```
```diff
 if is_split_kv:
-    flash_fwd_combine._flash_attn_fwd_combine(
+    flash_dec_combine._flash_attn_fwd_combine(
```

This still calls `flash_dec_combine._flash_attn_fwd_combine`, but the function was renamed to `_flash_attn_dec_combine` in `flash_dec_combine.py`. Update the call to prevent runtime failures when `is_split_kv` is true.

```suggestion
flash_dec_combine._flash_attn_dec_combine(
```
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
```diff
@@ -75,15 +75,14 @@ def _fwd_inner_sparse_base_kernel(
     # Apply online softmax
```

The comment says "Apply online softmax" but the code now calls `online_sparse_softmax`, which includes sparse-specific behavior (block-level skip). Update the comment to avoid confusion when debugging/maintaining this kernel.

```suggestion
# Apply online sparse softmax (may skip fully masked/sparse blocks)
```
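For context on the "may skip" caveat: a block-threshold online softmax can drop an entire score block when its maximum is so far below the running row maximum that every exponential in the block would underflow to zero. A hedged NumPy sketch of that predicate — the log2 domain mirrors the PR's `softmax_threshold_log2` naming, but the exact in-kernel logic is an assumption:

```python
import numpy as np

def block_contributes(block_scores, row_max, threshold_log2=20.0):
    """Return per-row flags: does this score block affect the row sums?

    In the log2 domain, exp2(block_max - row_max) underflows once the gap
    exceeds the threshold, so such a block adds ~nothing and can be skipped.
    """
    block_max = block_scores.max(axis=-1)
    return block_max >= row_max - threshold_log2
```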
```diff
@@ -99,15 +99,14 @@ def _fwd_inner_gated_base_kernel(
     # Apply online softmax
```

The comment says "Apply online softmax" but the implementation uses `online_sparse_softmax` (the sparse/block-threshold variant). Please update the comment so it matches the actual routine being used.

```suggestion
# Apply online sparse softmax (sparse/block-threshold variant)
```
```python
def online_softmax(
    acc_s,
    row_max,
    row_sum,
    scale_log2,
    CHECK_INF: tl.constexpr,
    RESCALE_THRESHOLD: tl.constexpr,
):
```

This PR is described as a pure rename/comment-clarity change, but this file introduces a new `online_softmax` implementation and renames the previous implementation to `online_sparse_softmax` (plus signature/return-value changes used by dense kernels). Please update the PR description to reflect these functional refactors, or split the refactor into a separate PR for easier review/risk assessment.
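As a reference for what the dense `online_softmax` variant computes, here is a hedged NumPy sketch of one standard streaming-softmax update step, using base-2 exponentials to mirror `scale_log2`; the real kernel's `RESCALE_THRESHOLD` and `CHECK_INF` handling are omitted:

```python
import numpy as np

def online_softmax_step(acc_s, row_max, row_sum, acc_o, scale_log2=1.0):
    """One online-softmax step over a new score block acc_s.

    Updates the running row_max/row_sum and rescales the output
    accumulator acc_o so the full softmax is recovered at the end.
    """
    new_max = np.maximum(row_max, (scale_log2 * acc_s).max(axis=-1))
    p = np.exp2(scale_log2 * acc_s - new_max[:, None])  # safe: exponent <= 0
    correction = np.exp2(row_max - new_max)             # rescale old partial sums
    row_sum = row_sum * correction + p.sum(axis=-1)
    acc_o = acc_o * correction[:, None]
    return p, new_max, row_sum, acc_o
```

Running this over every key block (adding `p @ v_block` to `acc_o` after each step) and finally dividing `acc_o` by `row_sum` reproduces `softmax(scores) @ V` computed block by block, which is the invariant both the dense and sparse variants must preserve.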
Summary
Root Cause
Changes
- Renamed `fwd_combine` to `dec_combine` and updated related comments for clarity.

Reproduction
Tests
Compatibility
Checklist