[WIP] Add: Paged attention ST tests for PyPTO codegen #252

Crystal-wzy wants to merge 1 commit into hw-native-sys:main from
Conversation
- Add `tests/st/codegen/test_paged_attention.py` with system-level tests for the four core kernels of the Paged Attention algorithm
- Add `QKMatmulTestCase`: verifies Q × K^T matrix multiplication using the L1/L0 memory hierarchy and `pl.matmul`
- Add `SoftmaxPrepareTestCase`: verifies the scale → row_max → exp → row_sum pipeline; the scale is passed via a config tensor encoded as INT32 bits
- Add `PVMatmulTestCase`: verifies P × V matrix multiplication with the same L0A/L0B/L0C memory pattern as QK matmul
- Add `OnlineUpdateTestCase`: verifies the online softmax update kernel covering all four runtime cases driven by `is_first` / `is_last` flags:
  - `is_first=1, is_last=0`: copy mij/lij/oi_new to the accumulators
  - `is_first=1, is_last=1`: copy + normalize (single-block path)
  - `is_first=0, is_last=0`: full online update, store intermediate oi
  - `is_first=0, is_last=1`: full online update + normalize final dst
- Add `TestPagedAttentionKernels` pytest class with parametrized methods for each kernel, enabling integration with the existing `test_runner` fixture
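For context on the four flag-driven cases above, the online softmax update follows the standard streaming-softmax recurrence. Below is a minimal NumPy reference of its semantics (a sketch only — the helper name and signature are illustrative, not the test's actual code):

```python
import numpy as np

def online_update(mi, li, oi, mij, lij, oi_new, is_first, is_last):
    """Reference semantics for one online-softmax accumulator step.

    mi/li/oi are the running max, exp-sum, and weighted output; mij/lij/oi_new
    are the current block's local statistics. Names are illustrative.
    """
    if is_first:
        # First block: accumulators are simply initialized from the block.
        mi, li, oi = mij, lij, oi_new
    else:
        # Merge the block into the accumulators, rescaling both sides
        # to the new running max for numerical stability.
        mi_new = np.maximum(mi, mij)
        alpha = np.exp(mi - mi_new)   # rescales the old accumulators
        beta = np.exp(mij - mi_new)   # rescales the new block
        li = alpha * li + beta * lij
        oi = alpha * oi + beta * oi_new
        mi = mi_new
    # Only the last block normalizes into the final destination.
    dst = oi / li if is_last else None
    return mi, li, oi, dst
```

Iterating this over score blocks reproduces a direct softmax-weighted sum, which is what the `is_first=0, is_last=1` path ultimately checks.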
Summary of Changes

Hello @Crystal-wzy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a comprehensive suite of system-level tests for the fundamental kernels of the Paged Attention algorithm, specifically targeting the PyPTO codegen frontend. The new tests validate critical components such as the QK and PV matrix multiplications, the softmax preparation pipeline, and the online update mechanism, ensuring the accuracy and robustness of the Paged Attention implementation.
Caution: Review failed. The pull request is closed.
Walkthrough

A new test module is added to provide comprehensive end-to-end testing scaffolding for Paged Attention kernels using PyPTO. The module defines four test case classes covering QK matmul, softmax preparation, PV matmul, and online update operations, along with a test suite that exercises each kernel with parameterized inputs.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Code Review
This pull request adds a comprehensive set of system-level tests for the paged attention kernels. The structure of the tests is good, with separate test cases for each kernel. However, I've identified a few areas for improvement. A recurring issue is the use of hardcoded shapes within the PyPTO programs, which makes the tests less robust and harder to maintain. More critically, the OnlineUpdateTestCase is not fully tested; its compute_expected method only covers one of four possible execution paths, and the corresponding pytest function isn't parameterized to trigger the other paths. I've provided detailed suggestions to address these points and improve test coverage and robustness.
```python
def compute_expected(self, tensors, params=None):
    """Compute expected output based on config values."""
    # Read is_first and is_last from config tensor
    is_first = int(tensors["config"][0])
    is_last = int(tensors["config"][1])

    mij = tensors["mij"].numpy()
    lij = tensors["lij"].numpy()
    oi_new = tensors["oi_new"].numpy()
    mi = tensors["mi"].numpy().copy()
    li = tensors["li"].numpy().copy()
    oi = tensors["oi"].numpy().copy()

    # Default test case: is_first=0, is_last=1 (last block case)
    mi_new = np.maximum(mi, mij)
    alpha = np.exp(mi - mi_new)
    beta = np.exp(mij - mi_new)
    li_updated = alpha * li + beta * lij
    oi_updated = alpha * oi + beta * oi_new

    tensors["mi"][:] = torch.from_numpy(mi_new)
    tensors["li"][:] = torch.from_numpy(li_updated)
    tensors["oi"][:] = torch.from_numpy(oi_updated)
    tensors["dst"][:] = torch.from_numpy(oi_updated / li_updated)
```
The compute_expected method is incomplete. It only implements the logic for the is_first=0, is_last=1 case, but the OnlineUpdateTestCase is described as a 'unified test case' for all four behaviors of the online update kernel. The implementation should handle all four cases based on the is_first and is_last flags read from the config tensor to correctly validate the kernel's output in all scenarios. This is critical for ensuring the correctness of the online update logic.
```python
def compute_expected(self, tensors, params=None):
    """Compute expected output based on config values."""
    # Read is_first and is_last from config tensor
    is_first = int(tensors["config"][0])
    is_last = int(tensors["config"][1])
    mij = tensors["mij"].numpy()
    lij = tensors["lij"].numpy()
    oi_new = tensors["oi_new"].numpy()
    mi = tensors["mi"].numpy().copy()
    li = tensors["li"].numpy().copy()
    oi = tensors["oi"].numpy().copy()
    dst = np.zeros_like(tensors["dst"].numpy())
    if is_first:
        # First block: copy mij->mi, lij->li, oi_new->oi
        mi_new = mij
        li_updated = lij
        oi_updated = oi_new
        if is_last:
            # Single block: normalize dst = oi_new / lij
            dst = oi_new / lij
    else:
        # Not first: full online update
        mi_new = np.maximum(mi, mij)
        alpha = np.exp(mi - mi_new)
        beta = np.exp(mij - mi_new)
        li_updated = alpha * li + beta * lij
        oi_updated = alpha * oi + beta * oi_new
        if is_last:
            # Last block: normalize dst = oi_updated / li_updated
            dst = oi_updated / li_updated
    tensors["mi"][:] = torch.from_numpy(mi_new)
    tensors["li"][:] = torch.from_numpy(li_updated)
    tensors["oi"][:] = torch.from_numpy(oi_updated)
    tensors["dst"][:] = torch.from_numpy(dst)
```

```python
@pl.function(type=pl.FunctionType.InCore)
def qk_matmul(
    self,
    qi: pl.Tensor[[16, 16], pl.FP32],
    kj_t: pl.Tensor[[16, 16], pl.FP32],
    sij: pl.Tensor[[16, 16], pl.FP32],
) -> pl.Tensor[[16, 16], pl.FP32]:
    qi_l1 = pl.load(qi, [0, 0], [16, 16], target_memory=2)
    kj_l1 = pl.load(kj_t, [0, 0], [16, 16], target_memory=2)
    qi_l0a = pl.move(qi_l1, target_memory=3)
    kj_l0b = pl.move(kj_l1, target_memory=4)
    sij_l0c = pl.matmul(qi_l0a, kj_l0b)
    out_sij = pl.l0c_store(sij_l0c, [0, 0], [16, 16], sij)
    return out_sij
```
The tensor shapes within the QKMatmulProgram are hardcoded (e.g., pl.Tensor[[16, 16], pl.FP32]). This makes the test case brittle and not truly parametric, as it only works for the (16, 16) case. The shapes should be derived from the test case's self.num_heads and self.head_dim attributes to align with define_tensors and make the test robust. This issue is present in all TestCase classes in this file.
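One way to derive the shapes from the test case could look like the following. This is a sketch only: it assumes the `pl.Tensor` annotations and the `pl.load`/`pl.l0c_store` region arguments accept values computed at program-construction time, and it assumes a dimension mapping for `kj_t`; if the decorator requires literal shapes, the function would need to be built inside `get_program`, where the attributes are in scope.

```python
def get_program(self) -> Any:
    import pypto.language as pl

    h, d = self.num_heads, self.head_dim  # derive shapes from the test case

    @pl.function(type=pl.FunctionType.InCore)
    def qk_matmul(
        qi: pl.Tensor[[h, d], pl.FP32],
        kj_t: pl.Tensor[[d, h], pl.FP32],  # dim mapping is an assumption
        sij: pl.Tensor[[h, h], pl.FP32],
    ) -> pl.Tensor[[h, h], pl.FP32]:
        qi_l1 = pl.load(qi, [0, 0], [h, d], target_memory=2)
        kj_l1 = pl.load(kj_t, [0, 0], [d, h], target_memory=2)
        qi_l0a = pl.move(qi_l1, target_memory=3)
        kj_l0b = pl.move(kj_l1, target_memory=4)
        sij_l0c = pl.matmul(qi_l0a, kj_l0b)
        return pl.l0c_store(sij_l0c, [0, 0], [h, h], sij)

    return qk_matmul
```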
```python
@pytest.mark.parametrize("num_heads,head_dim", [(16, 16)])
def test_online_update(self, test_runner, num_heads, head_dim):
    test_case = OnlineUpdateTestCase(num_heads=num_heads, head_dim=head_dim)
    result = test_runner.run(test_case)
    assert result.passed, f"Online update test failed: {result.error}"
```
To properly test the 'unified' OnlineUpdateTestCase and its different logic paths, this test should be parameterized over the four combinations of is_first and is_last flags. Currently, it only runs with the default configuration (is_first=0, is_last=1), leaving the other three kernel behaviors untested.
Suggested change:

```python
@pytest.mark.parametrize("num_heads,head_dim", [(16, 16)])
@pytest.mark.parametrize("is_first,is_last", [
    (1, 0),  # First block, more to come
    (1, 1),  # Single block case
    (0, 0),  # Middle block
    (0, 1),  # Last block
])
def test_online_update(self, test_runner, num_heads, head_dim, is_first, is_last):
    class ParametrizedOnlineUpdateTestCase(OnlineUpdateTestCase):
        def define_tensors(self) -> List[TensorSpec]:
            tensors = super().define_tensors()
            for spec in tensors:
                if spec.name == "config":
                    spec.init_value = [is_first, is_last]
            return tensors

    test_case = ParametrizedOnlineUpdateTestCase(num_heads=num_heads, head_dim=head_dim)
    result = test_runner.run(test_case)
    assert result.passed, f"Online update test failed for is_first={is_first}, is_last={is_last}: {result.error}"
```
```python
def get_program(self) -> Any:
    import pypto.language as pl

    scale_value = self.scale
```
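On the INT32-encoded scale mentioned in the PR description: a common convention for passing a float through an integer config tensor is to reinterpret its IEEE-754 float32 bit pattern as an int32. A minimal sketch of that round trip (the exact encoding used by `SoftmaxPrepareTestCase` is an assumption here; helper names are illustrative):

```python
import struct

def encode_scale(scale: float) -> int:
    """Pack a float32 scale into its raw IEEE-754 bit pattern as an int32."""
    return struct.unpack("<i", struct.pack("<f", scale))[0]

def decode_scale(bits: int) -> float:
    """Recover the float32 scale from its int32 bit pattern."""
    return struct.unpack("<f", struct.pack("<i", bits))[0]
```

Values exactly representable in float32 (e.g. 0.125) round-trip bit-exactly.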