Rename forward combine functions and clarify comments #270
Changes from all commits: 79a7670, 7650a66, 4dc2796, b511322, c6507bb, af18e0c, 31ee171, 318de29, f772c1f, 40e1baf, a795d01, e7f411d, 7ef17e6, 41c6f09
```diff
@@ -9,17 +9,74 @@ def check_inf(x):
 
+@triton.jit
+def online_softmax(
+    acc_s,
+    row_max,
+    row_sum,
+    scale_log2,
+    CHECK_INF: tl.constexpr,
+    RESCALE_THRESHOLD: tl.constexpr,
+):
+    """
+    Apply online softmax to acc_s, and update row_max and row_sum.
+
+    :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
+    :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf.
+    :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0.
+    :param scale_log2: Log2 of the scaling factor to be applied to acc_s.
+    :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0.
+    :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled.
+
+    :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N].
+    :return row_max_new: Updated maximum values per row of shape [BLOCK_M].
+    :return row_sum_new: Updated sum values per row of shape [BLOCK_M].
+    :return row_scale: Scaling factors per row of shape [BLOCK_M].
+    """
+    # Compute current row max
+    row_max_curr = tl.max(acc_s, axis=1)
+
+    # Update row max
+    row_max_new = tl.maximum(row_max_curr, row_max)
+
+    # Avoid exp(-inf - (-inf)) = nan by clamping -inf to 0
+    if CHECK_INF:
+        row_max_new = check_inf(row_max_new)
+
+    # Compute scaled differences to new row max
+    acc_scale_log2 = (row_max - row_max_new) * scale_log2
+
+    # Compute row scale
+    if RESCALE_THRESHOLD > 0.0:
+        # Triton can only skip computation at block granularity
+        if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD:
+            row_scale = tl.exp2(acc_scale_log2)
+        else:
+            row_max_new = row_max
+            row_scale = acc_scale_log2 * 0.0 + 1.0
+    else:
+        row_scale = tl.exp2(acc_scale_log2)
+
+    # Compute attention weights
+    p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)
+
+    # Update row sum
+    row_sum_cur = tl.sum(p, axis=1)
+    row_sum_new = row_sum * row_scale + row_sum_cur
+
+    return p, row_max_new, row_sum_new, row_scale
+
+
 @triton.jit
-def online_softmax(
+def online_sparse_softmax(
     acc_s,
     block_max,
     row_max,
     row_sum,
     scale_log2,
     softmax_threshold_log2,
     CHECK_INF: tl.constexpr,
-    RESCALE_THRESHOLD: tl.constexpr,
 ):
     """
-    Apply online softmax to acc_s, and update block_max, row_max and row_sum.
+    Apply online sparse softmax to acc_s, and update block_max, row_max and row_sum.
 
     :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
     :param block_max: Running block-wise maximum scalar, init to -inf.
```
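To make the new function's recurrence easier to review, here is a minimal NumPy sketch of the same math (ours, not code from this PR). Names mirror the kernel's; `rescale_threshold` stands in for `RESCALE_THRESHOLD` (the default of 0.0 disables the skip, making the recurrence exact), and the `CHECK_INF` clamp for fully masked rows is noted but omitted:

```python
import numpy as np

def online_softmax_reference(acc_s, row_max, row_sum, scale_log2, rescale_threshold=0.0):
    """One step of the online-softmax recurrence, mirroring the kernel's math."""
    # Running max per row after seeing this block
    row_max_new = np.maximum(acc_s.max(axis=1), row_max)
    # (CHECK_INF clamp omitted: the kernel maps -inf maxima to 0 for fully masked rows)

    # Log2-domain correction that renormalizes earlier partial sums to the new max
    acc_scale_log2 = (row_max - row_max_new) * scale_log2

    if rescale_threshold > 0.0 and acc_scale_log2.min() >= -rescale_threshold:
        # Max barely moved for every row in the block: keep the old max and
        # skip the exp2 (the block-granularity shortcut in the kernel)
        row_max_new = row_max
        row_scale = np.ones_like(acc_scale_log2)
    else:
        row_scale = np.exp2(acc_scale_log2)

    # Probabilities for this block, relative to the (possibly kept) running max
    p = np.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)

    # Fold this block into the running row sums
    row_sum_new = row_sum * row_scale + p.sum(axis=1)
    return p, row_max_new, row_sum_new, row_scale
```

On our reading, the skip is mathematically exact, not an approximation: softmax is shift-invariant, so consistently keeping the old max changes nothing in the final result. The threshold only bounds how large the unnormalized `p` entries may grow (up to `2**RESCALE_THRESHOLD`), which is what keeps the shortcut overflow-safe.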
```diff
@@ -28,7 +85,6 @@ def online_softmax(
     :param scale_log2: Log2 of the scaling factor to be applied to acc_s.
     :param softmax_threshold_log2: Threshold in log2-domain for block-level skip. If > -inf and block max is below threshold relative to running max, skip softmax update.
     :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0.
-    :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled.
 
     :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N].
     :return block_max_new: Updated block-wise maximum scalar.
```
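A brief gloss on the log2-domain convention shared by `scale_log2` and `softmax_threshold_log2` (our reading of how the parameters are used, not text from the PR): the kernels use `tl.exp2` rather than `exp`, so the softmax scale $c$ (e.g. $1/\sqrt{d}$) is pre-multiplied by $\log_2 e$:

$$\exp\big(c\,(s-m)\big) \;=\; 2^{\,c\,\log_2(e)\,(s-m)} \;=\; \texttt{exp2}\big(\texttt{scale\_log2}\cdot(s-m)\big), \qquad \texttt{scale\_log2} = c\,\log_2 e,$$

which is why thresholds such as `softmax_threshold_log2` and `RESCALE_THRESHOLD` are also expressed directly in base-2 log units.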
```diff
@@ -69,15 +125,7 @@ def online_softmax(
     acc_scale_log2 = (row_max - row_max_new) * scale_log2
 
     # Compute row scale
-    if RESCALE_THRESHOLD > 0.0:
-        # Triton can only skip computation at block granularity
-        if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD:
-            row_scale = tl.exp2(acc_scale_log2)
-        else:
-            row_max_new = row_max
-            row_scale = acc_scale_log2 * 0.0 + 1.0
-    else:
-        row_scale = tl.exp2(acc_scale_log2)
+    row_scale = tl.exp2(acc_scale_log2)
 
     # Compute attention weights
     p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)
```
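And a hypothetical end-to-end check (again ours, not part of the PR, reusing `online_softmax_reference` from the sketch above) showing how a caller streams over column blocks and what the returned `row_scale` is for; a real attention kernel would apply it to its output accumulator on every step:

```python
import numpy as np

# Stream over 32-wide column blocks of a score matrix and confirm the
# running sums match the direct softmax denominator.
rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 128)).astype(np.float32)
scale_log2 = np.float32(1.0 / np.sqrt(64.0) * np.log2(np.e))  # e.g. head_dim = 64

row_max = np.full(4, -np.inf, dtype=np.float32)
row_sum = np.zeros(4, dtype=np.float32)
for start in range(0, 128, 32):
    p, row_max, row_sum, row_scale = online_softmax_reference(
        scores[:, start:start + 32], row_max, row_sum, scale_log2)
    # A real kernel would also rescale its output accumulator here:
    # acc_o = acc_o * row_scale[:, None]; acc_o += p @ v_block

scaled = scores * scale_log2
direct = np.exp2(scaled - scaled.max(axis=1, keepdims=True)).sum(axis=1)
assert np.allclose(row_sum, direct, rtol=1e-5)
```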
Review comment:

This PR is described as a pure rename/comment-clarity change, but this file introduces a new `online_softmax` implementation and renames the previous implementation to `online_sparse_softmax` (plus signature and return-value changes used by dense kernels). Please update the PR description to reflect these functional refactors, or split the refactor into a separate PR for easier review and risk assessment.