[BUG FIX] Optimize LSE computation in forward combine kernel#265
[BUG FIX] Optimize LSE computation in forward combine kernel#265LoserCheems merged 5 commits intomainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
Optimizes the Triton forward “combine” kernel used to merge split-KV partial outputs/LSEs (decode path) by removing redundant grid dimensions and refactoring the LSE accumulation to an online log-sum-exp form.
Changes:
- Removed the
head_dimargument and K-dimension fromget_fwd_combine_grid, making the combine launch grid 2D. - Refactored
_fwd_combine_kernelto compute combined outputs and final LSE using an online log-sum-exp accumulator and a single LSE store path. - Removed now-unused parameters/imports related to the old MAX_SPLITS-based implementation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
flash_sparse_attn/ops/triton/launch_grid.py |
Simplifies the combine kernel launch grid by removing the (redundant) head-dim tiling dimension and its parameter. |
flash_sparse_attn/ops/triton/flash_fwd_combine.py |
Updates the combine kernel to use a 2D grid and an online LSE accumulation scheme while removing unused logic/params. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Compute normalized exponentials | ||
| new_e_max = tl.maximum(lse_s, e_max) | ||
| old_scale = tl.exp(e_max - new_e_max) | ||
| exp_logic = tl.exp(lse_s - new_e_max) | ||
|
|
||
| # Load partial outputs | ||
| o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0) | ||
|
|
||
| # Advance output pointers | ||
| out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0)) | ||
|
|
||
| inv_sum = tl.where((sum_exp == 0.0) | (sum_exp != sum_exp), 0.0, 1.0 / sum_exp) | ||
| # Compute scaled outputs | ||
| acc_o *= old_scale[:, None] | ||
| acc_o += exp_logic[:, None] * o_s | ||
|
|
||
| for s in tl.static_range(MAX_SPLITS): | ||
| if s < num_splits: | ||
| lse_s = tl.sum( | ||
| tl.where( | ||
| (tl.arange(0, MAX_SPLITS) == s)[:, None], | ||
| lse_vals, | ||
| tl.full((MAX_SPLITS, TILE_M), 0.0, dtype=tl.float32), | ||
| ), | ||
| axis=0, | ||
| ) | ||
| scale = tl.exp(lse_s - max_lse) * inv_sum | ||
| o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0) | ||
| out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0)) | ||
| acc_o += scale[:, None] * o_s | ||
| # Update e_sum and e_max | ||
| e_sum = e_sum * old_scale + exp_logic | ||
| e_max = new_e_max |
There was a problem hiding this comment.
The new online log-sum-exp update can produce NaNs when a split has no valid blocks and writes LSE=-inf. In that case e_max starts at -inf and new_e_max can also be -inf, so tl.exp(e_max - new_e_max) and tl.exp(lse_s - new_e_max) evaluate exp(-inf - -inf) -> NaN, which then contaminates e_sum/acc_o (and inv_sum=0 won’t reliably clear NaNs). Please add an explicit guard for the new_e_max == -inf (or e_max == -inf & lse_s == -inf) case to keep old_scale finite (typically 1.0) and exp_logic at 0.0, matching the prior behavior that avoided NaNs when all contributions were -inf.
| # Compute normalized exponentials | ||
| new_e_max = tl.maximum(lse_s, e_max) | ||
| old_scale = tl.exp(e_max - new_e_max) | ||
| exp_logic = tl.exp(lse_s - new_e_max) |
There was a problem hiding this comment.
This kernel is only exercised in split-KV decode paths (e.g., seqlen_q == 1), and the main pytest correctness tests appear to use seqlen_q > 1 so the combine path is likely untested. Adding a targeted correctness test that forces is_split_kv and includes a local-window (or other) configuration that yields at least one empty split (LSE=-inf) would help catch regressions in the updated LSE accumulation logic.
| # Compute normalized exponentials | |
| new_e_max = tl.maximum(lse_s, e_max) | |
| old_scale = tl.exp(e_max - new_e_max) | |
| exp_logic = tl.exp(lse_s - new_e_max) | |
| # Compute normalized exponentials. | |
| # Empty splits can produce lse_s == -inf; when e_max is also -inf, | |
| # the naive difference (-inf - -inf) becomes NaN and poisons the | |
| # accumulation. Treat invalid/empty contributions as zero instead. | |
| new_e_max = tl.maximum(lse_s, e_max) | |
| has_prev = e_max != float("-inf") | |
| has_curr = lse_s != float("-inf") | |
| old_scale = tl.where(has_prev, tl.exp(e_max - new_e_max), 0.0) | |
| exp_logic = tl.where(has_curr, tl.exp(lse_s - new_e_max), 0.0) |
Summary
Root Cause
Changes
_fwd_combine_kernelto optimize LSE computation and removed thehead_dimparameter from theget_fwd_combine_gridfunction.Reproduction
Tests
Compatibility
Checklist