
[WIP] Add: Paged attention ST tests for PyPTO codegen #252

Closed
Crystal-wzy wants to merge 1 commit into hw-native-sys:main from Crystal-wzy:main

Conversation

@Crystal-wzy
Contributor

- Add `tests/st/codegen/test_paged_attention.py` with system-level tests
  for the four core kernels of the Paged Attention algorithm

- Add `QKMatmulTestCase`: verifies Q × K^T matrix multiplication using
  L1/L0 memory hierarchy and `pl.matmul`

- Add `SoftmaxPrepareTestCase`: verifies scale → row_max → exp → row_sum
  pipeline; scale is passed via a config tensor encoded as INT32 bits

- Add `PVMatmulTestCase`: verifies P × V matrix multiplication with the
  same L0A/L0B/L0C memory pattern as QK matmul

- Add `OnlineUpdateTestCase`: verifies the online softmax update kernel
  covering all four runtime cases driven by `is_first` / `is_last` flags:
  - `is_first=1, is_last=0`: copy mij/lij/oi_new to accumulators
  - `is_first=1, is_last=1`: copy + normalize (single-block path)
  - `is_first=0, is_last=0`: full online update, store intermediate oi
  - `is_first=0, is_last=1`: full online update + normalize final dst

- Add `TestPagedAttentionKernels` pytest class with parametrized methods
  for each kernel, enabling integration with the existing `test_runner`
  fixture
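As a cross-check on the four runtime cases listed above, the online softmax update rule can be modeled in plain NumPy. This is a hypothetical standalone sketch of the math only, independent of the PyPTO kernels; `online_update` and its argument names are illustrative, not part of the test file:

```python
import numpy as np

def online_update(mi, li, oi, mij, lij, oi_new, is_first, is_last):
    """Merge one block's softmax statistics (mij, lij, oi_new) into the
    running accumulators (mi, li, oi); normalize only on the last block."""
    if is_first:
        # is_first=1: accumulators are initialized straight from the block.
        mi, li, oi = mij, lij, oi_new
    else:
        # is_first=0: rescale both sides to the new running max, then merge.
        m_new = np.maximum(mi, mij)
        alpha = np.exp(mi - m_new)    # correction for the old accumulators
        beta = np.exp(mij - m_new)    # correction for the incoming block
        li = alpha * li + beta * lij
        oi = alpha * oi + beta * oi_new
        mi = m_new
    # is_last=1: produce the normalized output.
    dst = oi / li if is_last else None
    return mi, li, oi, dst
```

Because the online formulation is mathematically exact, merging per-block statistics this way reproduces the full softmax-weighted result, which is the property the expected-value computation in the tests relies on.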
@gemini-code-assist
Contributor

Summary of Changes

Hello @Crystal-wzy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a comprehensive suite of system-level tests for the fundamental kernels of the Paged Attention algorithm, specifically targeting the PyPTO codegen frontend. The new tests validate critical components such as QK and PV matrix multiplications, the softmax preparation pipeline, and the online update mechanism, ensuring the accuracy and robustness of the Paged Attention implementation.

Highlights

  • New Test File Added: A new system-level test file, tests/st/codegen/test_paged_attention.py, was added to validate the Paged Attention algorithm's core kernels.
  • QK Matmul Test Case: Introduced QKMatmulTestCase to verify Query (Q) multiplied by Key Transpose (K^T) matrix multiplication, utilizing L1/L0 memory hierarchy and pl.matmul.
  • Softmax Prepare Test Case: Implemented SoftmaxPrepareTestCase to validate the sequence of scaling, row maximum calculation, exponentiation, and row summation, with scale passed as an INT32-encoded config tensor.
  • PV Matmul Test Case: Added PVMatmulTestCase to verify Probability (P) multiplied by Value (V) matrix multiplication, mirroring the memory access patterns of the QK matmul.
  • Online Update Test Case: Created OnlineUpdateTestCase to cover all four runtime scenarios of the online softmax update kernel, managed by is_first and is_last flags.
  • Pytest Integration: Integrated all new test cases into a TestPagedAttentionKernels pytest class, enabling parameterization and use with the existing test_runner fixture.
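For reference, the "scale as an INT32-encoded config tensor" trick mentioned above can be modeled by reinterpreting the float's bit pattern. A minimal sketch, assuming an IEEE-754 float32 scale and a same-width INT32 config slot (the variable names are illustrative, not taken from the test file):

```python
import numpy as np

# Host side: reinterpret the float32 scale's bits as int32 (no value conversion).
scale = np.array([0.125], dtype=np.float32)
config = scale.view(np.int32)        # raw bit pattern carried in the config tensor

# Kernel/reference side: recover the exact float32 value from the int32 bits.
decoded = config.view(np.float32)
assert decoded[0] == np.float32(0.125)   # round-trip is bit-exact
```

Since `view` preserves the bits rather than converting the value, the round-trip is exact for every float32, which is what makes an INT32 config tensor a safe transport for a floating-point scale.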
Changelog
  • tests/st/codegen/test_paged_attention.py
    • Added a new file containing system-level tests for the Paged Attention algorithm's core kernels, including QKMatmul, SoftmaxPrepare, PVMatmul, and OnlineUpdate.
Activity
  • The pull request is marked as Work In Progress (WIP).
  • New system-level tests for Paged Attention kernels have been added.
  • Specific test cases for QKMatmul, SoftmaxPrepare, PVMatmul, and OnlineUpdate have been implemented.
  • These tests are integrated with the test_runner fixture for execution.

@coderabbitai

coderabbitai bot commented Feb 24, 2026

Caution

Review failed

The pull request is closed.

ℹ️ Recent review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 81383e1 and 05f939d.

📒 Files selected for processing (1)
  • tests/st/codegen/test_paged_attention.py

📝 Walkthrough

A new test module is added to provide comprehensive end-to-end testing scaffolding for Paged Attention kernels using PyPTO. The module defines four test case classes covering QK matmul, softmax preparation, PV matmul, and online update operations, along with a test suite that exercises each kernel with parameterized inputs.

Changes

Cohort / File(s): Paged Attention Test Module (tests/st/codegen/test_paged_attention.py)
Summary: Introduces QKMatmulTestCase, SoftmaxPrepareTestCase, PVMatmulTestCase, and OnlineUpdateTestCase classes with tensor schemas, PyPTO program definitions, and NumPy-based expected result computations. Adds TestPagedAttentionKernels test suite with parameterized test methods for each kernel component.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • Hzfengsy

Poem

🐰 With whiskers twitching, I draft tests anew,
For Paged Attention kernels, both tried and true,
QK matmuls, softmax flows, and PV delight,
Online updates computed, every result just right! ✨



@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds a comprehensive set of system-level tests for the paged attention kernels. The structure of the tests is good, with separate test cases for each kernel. However, I've identified a few areas for improvement. A recurring issue is the use of hardcoded shapes within the PyPTO programs, which makes the tests less robust and harder to maintain. More critically, the OnlineUpdateTestCase is not fully tested; its compute_expected method only covers one of four possible execution paths, and the corresponding pytest function isn't parameterized to trigger the other paths. I've provided detailed suggestions to address these points and improve test coverage and robustness.

Comment on lines +362 to +385
    def compute_expected(self, tensors, params=None):
        """Compute expected output based on config values."""
        # Read is_first and is_last from config tensor
        is_first = int(tensors["config"][0])
        is_last = int(tensors["config"][1])

        mij = tensors["mij"].numpy()
        lij = tensors["lij"].numpy()
        oi_new = tensors["oi_new"].numpy()
        mi = tensors["mi"].numpy().copy()
        li = tensors["li"].numpy().copy()
        oi = tensors["oi"].numpy().copy()

        # Default test case: is_first=0, is_last=1 (last block case)
        mi_new = np.maximum(mi, mij)
        alpha = np.exp(mi - mi_new)
        beta = np.exp(mij - mi_new)
        li_updated = alpha * li + beta * lij
        oi_updated = alpha * oi + beta * oi_new

        tensors["mi"][:] = torch.from_numpy(mi_new)
        tensors["li"][:] = torch.from_numpy(li_updated)
        tensors["oi"][:] = torch.from_numpy(oi_updated)
        tensors["dst"][:] = torch.from_numpy(oi_updated / li_updated)
critical

The compute_expected method is incomplete. It only implements the logic for the is_first=0, is_last=1 case, but the OnlineUpdateTestCase is described as a 'unified test case' for all four behaviors of the online update kernel. The implementation should handle all four cases based on the is_first and is_last flags read from the config tensor to correctly validate the kernel's output in all scenarios. This is critical for ensuring the correctness of the online update logic.

    def compute_expected(self, tensors, params=None):
        """Compute expected output based on config values."""
        # Read is_first and is_last from config tensor
        is_first = int(tensors["config"][0])
        is_last = int(tensors["config"][1])

        mij = tensors["mij"].numpy()
        lij = tensors["lij"].numpy()
        oi_new = tensors["oi_new"].numpy()
        mi = tensors["mi"].numpy().copy()
        li = tensors["li"].numpy().copy()
        oi = tensors["oi"].numpy().copy()
        dst = np.zeros_like(tensors["dst"].numpy())

        if is_first:
            # First block: copy mij->mi, lij->li, oi_new->oi
            mi_new = mij
            li_updated = lij
            oi_updated = oi_new
            if is_last:
                # Single block: normalize dst = oi_new / lij
                dst = oi_new / lij
        else:
            # Not first: full online update
            mi_new = np.maximum(mi, mij)
            alpha = np.exp(mi - mi_new)
            beta = np.exp(mij - mi_new)
            li_updated = alpha * li + beta * lij
            oi_updated = alpha * oi + beta * oi_new
            if is_last:
                # Last block: normalize dst = oi_updated / li_updated
                dst = oi_updated / li_updated

        tensors["mi"][:] = torch.from_numpy(mi_new)
        tensors["li"][:] = torch.from_numpy(li_updated)
        tensors["oi"][:] = torch.from_numpy(oi_updated)
        tensors["dst"][:] = torch.from_numpy(dst)

Comment on lines +51 to +64
    @pl.function(type=pl.FunctionType.InCore)
    def qk_matmul(
        self,
        qi: pl.Tensor[[16, 16], pl.FP32],
        kj_t: pl.Tensor[[16, 16], pl.FP32],
        sij: pl.Tensor[[16, 16], pl.FP32],
    ) -> pl.Tensor[[16, 16], pl.FP32]:
        qi_l1 = pl.load(qi, [0, 0], [16, 16], target_memory=2)
        kj_l1 = pl.load(kj_t, [0, 0], [16, 16], target_memory=2)
        qi_l0a = pl.move(qi_l1, target_memory=3)
        kj_l0b = pl.move(kj_l1, target_memory=4)
        sij_l0c = pl.matmul(qi_l0a, kj_l0b)
        out_sij = pl.l0c_store(sij_l0c, [0, 0], [16, 16], sij)
        return out_sij
high

The tensor shapes within the QKMatmulProgram are hardcoded (e.g., pl.Tensor[[16, 16], pl.FP32]). This makes the test case brittle and not truly parametric, as it only works for the (16, 16) case. The shapes should be derived from the test case's self.num_heads and self.head_dim attributes to align with define_tensors and make the test robust. This issue is present in all TestCase classes in this file.

Comment on lines +407 to +411
    @pytest.mark.parametrize("num_heads,head_dim", [(16, 16)])
    def test_online_update(self, test_runner, num_heads, head_dim):
        test_case = OnlineUpdateTestCase(num_heads=num_heads, head_dim=head_dim)
        result = test_runner.run(test_case)
        assert result.passed, f"Online update test failed: {result.error}"

high

To properly test the 'unified' OnlineUpdateTestCase and its different logic paths, this test should be parameterized over the four combinations of is_first and is_last flags. Currently, it only runs with the default configuration (is_first=0, is_last=1), leaving the other three kernel behaviors untested.

Suggested change

    @pytest.mark.parametrize("num_heads,head_dim", [(16, 16)])
    def test_online_update(self, test_runner, num_heads, head_dim):
        test_case = OnlineUpdateTestCase(num_heads=num_heads, head_dim=head_dim)
        result = test_runner.run(test_case)
        assert result.passed, f"Online update test failed: {result.error}"

    @pytest.mark.parametrize("num_heads,head_dim", [(16, 16)])
    @pytest.mark.parametrize("is_first,is_last", [
        (1, 0),  # First block, more to come
        (1, 1),  # Single block case
        (0, 0),  # Middle block
        (0, 1),  # Last block
    ])
    def test_online_update(self, test_runner, num_heads, head_dim, is_first, is_last):
        class ParametrizedOnlineUpdateTestCase(OnlineUpdateTestCase):
            def define_tensors(self) -> List[TensorSpec]:
                tensors = super().define_tensors()
                for spec in tensors:
                    if spec.name == "config":
                        spec.init_value = [is_first, is_last]
                return tensors

        test_case = ParametrizedOnlineUpdateTestCase(num_heads=num_heads, head_dim=head_dim)
        result = test_runner.run(test_case)
        assert result.passed, f"Online update test failed for is_first={is_first}, is_last={is_last}: {result.error}"

    def get_program(self) -> Any:
        import pypto.language as pl

        scale_value = self.scale

medium

The variable scale_value is assigned but never used. It should be removed to avoid dead code.
