[BUG FIX] Optimize LSE computation in forward combine kernel#265

Merged
LoserCheems merged 5 commits into main from optim-combine-func
Apr 14, 2026

Conversation

@LoserCheems (Collaborator)

Summary

  • This fixes inefficiencies in the LSE computation within the forward combine kernel.

Root Cause

  • The original implementation contained unused variables and unnecessary parameters that hindered performance.

Changes

  • Refactored the _fwd_combine_kernel to optimize LSE computation and removed the head_dim parameter from the get_fwd_combine_grid function.
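The refactored accumulation can be sketched outside Triton. The following NumPy model is an assumption-laden illustration, not the kernel's exact code: the function name `combine_splits`, the final normalization by `e_sum`, and the input layout (per-split outputs normalized within each split, with `LSE = -inf` for empty splits) are all assumed here to show how an online log-sum-exp merges split-KV partials in a single pass:

```python
import numpy as np

def combine_splits(o_parts, lse_parts):
    """Merge split-KV partial outputs with an online log-sum-exp.

    o_parts:   (num_splits, seqlen_q, head_dim) per-split outputs,
               each normalized within its own split
    lse_parts: (num_splits, seqlen_q) per-split log-sum-exp values
               (-inf for a split that saw no valid blocks)
    """
    num_splits, seqlen_q, head_dim = o_parts.shape
    e_max = np.full(seqlen_q, -np.inf)
    e_sum = np.zeros(seqlen_q)
    acc_o = np.zeros((seqlen_q, head_dim))
    for s in range(num_splits):
        lse_s = lse_parts[s]
        new_e_max = np.maximum(lse_s, e_max)
        # Guard the all--inf case so empty splits contribute zero, not NaN.
        empty = np.isneginf(new_e_max)
        old_scale = np.where(empty, 1.0, np.exp(e_max - new_e_max))
        exp_logic = np.where(empty, 0.0, np.exp(lse_s - new_e_max))
        acc_o = acc_o * old_scale[:, None] + exp_logic[:, None] * o_parts[s]
        e_sum = e_sum * old_scale + exp_logic
        e_max = new_e_max
    # Final output is the weighted average; final LSE folds the running max back in.
    return acc_o / e_sum[:, None], e_max + np.log(e_sum)
```

Because the running maximum is folded back in at the end, the result is identical to the two-pass formulation (global max first, then weighted sum) but needs only one sweep over the splits.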

Reproduction

  • Test the forward combine kernel with various input sizes to observe performance improvements.

Tests

  • Validated changes through existing tests for the forward combine functionality.

Compatibility

  • No backward compatibility issues identified.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions

Copilot AI review requested due to automatic review settings April 14, 2026 01:43
Contributor

Copilot AI left a comment


Pull request overview

Optimizes the Triton forward “combine” kernel used to merge split-KV partial outputs/LSEs (decode path) by removing redundant grid dimensions and refactoring the LSE accumulation to an online log-sum-exp form.

Changes:

  • Removed the head_dim argument and K-dimension from get_fwd_combine_grid, making the combine launch grid 2D.
  • Refactored _fwd_combine_kernel to compute combined outputs and final LSE using an online log-sum-exp accumulator and a single LSE store path.
  • Removed now-unused parameters/imports related to the old MAX_SPLITS-based implementation.
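The grid change can be illustrated with a small sketch. Only `get_fwd_combine_grid` and the removed `head_dim` parameter come from the PR; the signatures, tile parameters, and the `_3d`/`_2d` suffixes below are hypothetical, used purely to contrast the two shapes:

```python
import math

def get_fwd_combine_grid_3d(batch, num_heads, seqlen_q, head_dim, tile_m, tile_k):
    # Old shape (sketch): an extra head_dim tiling axis that the kernel
    # never needed, since each program already covers the full head dim.
    return (math.ceil(seqlen_q / tile_m),
            math.ceil(head_dim / tile_k),
            batch * num_heads)

def get_fwd_combine_grid_2d(batch, num_heads, seqlen_q, tile_m):
    # New shape (sketch): one program per (query tile, batch*head) pair.
    return (math.ceil(seqlen_q / tile_m), batch * num_heads)
```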

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • flash_sparse_attn/ops/triton/launch_grid.py: Simplifies the combine kernel launch grid by removing the (redundant) head-dim tiling dimension and its parameter.
  • flash_sparse_attn/ops/triton/flash_fwd_combine.py: Updates the combine kernel to use a 2D grid and an online LSE accumulation scheme while removing unused logic/params.


Comment on lines +141 to +158
```diff
+        # Compute normalized exponentials
+        new_e_max = tl.maximum(lse_s, e_max)
+        old_scale = tl.exp(e_max - new_e_max)
+        exp_logic = tl.exp(lse_s - new_e_max)
+
+        # Load partial outputs
+        o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0)
+
+        # Advance output pointers
+        out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0))
+
-    inv_sum = tl.where((sum_exp == 0.0) | (sum_exp != sum_exp), 0.0, 1.0 / sum_exp)
+        # Compute scaled outputs
+        acc_o *= old_scale[:, None]
+        acc_o += exp_logic[:, None] * o_s
+
-    for s in tl.static_range(MAX_SPLITS):
-        if s < num_splits:
-            lse_s = tl.sum(
-                tl.where(
-                    (tl.arange(0, MAX_SPLITS) == s)[:, None],
-                    lse_vals,
-                    tl.full((MAX_SPLITS, TILE_M), 0.0, dtype=tl.float32),
-                ),
-                axis=0,
-            )
-            scale = tl.exp(lse_s - max_lse) * inv_sum
-            o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0)
-            out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0))
-            acc_o += scale[:, None] * o_s
+        # Update e_sum and e_max
+        e_sum = e_sum * old_scale + exp_logic
+        e_max = new_e_max
```

Copilot AI Apr 14, 2026


The new online log-sum-exp update can produce NaNs when a split has no valid blocks and writes LSE=-inf. In that case e_max starts at -inf and new_e_max can also be -inf, so tl.exp(e_max - new_e_max) and tl.exp(lse_s - new_e_max) evaluate exp(-inf - -inf) -> NaN, which then contaminates e_sum/acc_o (and inv_sum=0 won’t reliably clear NaNs). Please add an explicit guard for the new_e_max == -inf (or e_max == -inf & lse_s == -inf) case to keep old_scale finite (typically 1.0) and exp_logic at 0.0, matching the prior behavior that avoided NaNs when all contributions were -inf.
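The failure mode described above is easy to reproduce with plain floats. This sketch (helper names are illustrative, not kernel code) contrasts the naive scale computation with a guarded one along the lines the review suggests:

```python
import math

NEG_INF = float("-inf")

def unguarded_scales(e_max, lse_s):
    # Naive update: when both running max and new LSE are -inf,
    # (-inf) - (-inf) is NaN and exp(NaN) is NaN.
    new_e_max = max(lse_s, e_max)
    return math.exp(e_max - new_e_max), math.exp(lse_s - new_e_max)

def guarded_scales(e_max, lse_s):
    # Guarded update: if nothing has been accumulated and the current
    # split is empty, keep the old accumulator (scale 1.0) and add nothing.
    new_e_max = max(lse_s, e_max)
    if new_e_max == NEG_INF:
        return 1.0, 0.0
    return math.exp(e_max - new_e_max), math.exp(lse_s - new_e_max)
```

Note that the mixed case is already safe without a guard: with `e_max = -inf` and a finite `lse_s`, `exp(-inf - lse_s)` is a clean 0.0; only the both-`-inf` case produces NaN.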

Comment on lines +141 to +144
```python
        # Compute normalized exponentials
        new_e_max = tl.maximum(lse_s, e_max)
        old_scale = tl.exp(e_max - new_e_max)
        exp_logic = tl.exp(lse_s - new_e_max)
```

Copilot AI Apr 14, 2026


This kernel is only exercised in split-KV decode paths (e.g., seqlen_q == 1), and the main pytest correctness tests appear to use seqlen_q > 1 so the combine path is likely untested. Adding a targeted correctness test that forces is_split_kv and includes a local-window (or other) configuration that yields at least one empty split (LSE=-inf) would help catch regressions in the updated LSE accumulation logic.

Suggested change

```diff
-        # Compute normalized exponentials
-        new_e_max = tl.maximum(lse_s, e_max)
-        old_scale = tl.exp(e_max - new_e_max)
-        exp_logic = tl.exp(lse_s - new_e_max)
+        # Compute normalized exponentials.
+        # Empty splits can produce lse_s == -inf; when e_max is also -inf,
+        # the naive difference (-inf - -inf) becomes NaN and poisons the
+        # accumulation. Treat invalid/empty contributions as zero instead.
+        new_e_max = tl.maximum(lse_s, e_max)
+        has_prev = e_max != float("-inf")
+        has_curr = lse_s != float("-inf")
+        old_scale = tl.where(has_prev, tl.exp(e_max - new_e_max), 0.0)
+        exp_logic = tl.where(has_curr, tl.exp(lse_s - new_e_max), 0.0)
```
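The empty-split behavior that such a test would pin down can be prototyped against a pure-Python reference before wiring it into the project's pytest suite. The function below is a sketch of the guarded online combine (per-row scalars, assumed input layout), not the Triton kernel; the check is that inserting an empty split (`O = 0`, `LSE = -inf`) leaves the combined result unchanged and NaN-free:

```python
import numpy as np

def combine_online(o_parts, lse_parts):
    # Guarded online LSE combine over a list of per-split (output, lse) pairs.
    e_max, e_sum = -np.inf, 0.0
    acc = np.zeros_like(o_parts[0])
    for o_s, lse_s in zip(o_parts, lse_parts):
        new_e_max = max(lse_s, e_max)
        if new_e_max == -np.inf:
            # Nothing accumulated yet and this split is empty: no-op.
            old_scale, exp_logic = 1.0, 0.0
        else:
            old_scale = np.exp(e_max - new_e_max)
            exp_logic = np.exp(lse_s - new_e_max)
        acc = acc * old_scale + exp_logic * o_s
        e_sum = e_sum * old_scale + exp_logic
        e_max = new_e_max
    return acc / e_sum, e_max + np.log(e_sum)
```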

@LoserCheems LoserCheems merged commit 0411f0e into main Apr 14, 2026