
[WIP] test(st): add Paged Attention system tests for four PTO kernels #254

Open
Crystal-wzy wants to merge 1 commit into hw-native-sys:main from Crystal-wzy:main
Conversation

@Crystal-wzy

Add system-level tests for the Paged Attention algorithm using the PTOTestCase framework, covering four kernels:

  • QKMatmulTestCase: sij = qi @ kj_t via L1 → L0A/L0B → L0C memory flow
  • SoftmaxPrepareTestCase: scale → row_max → row_expand_sub → exp → row_sum pipeline; scale passed via INT32-encoded config tensor
  • PVMatmulTestCase: oi_new = pij @ vj with the same L0C path
  • OnlineUpdateTestCase: online normalizer update with is_first/is_last flags covering all four runtime branches; PyTorch golden validation

Each test case implements define_tensors(), get_program(), and compute_expected() and is driven by the test_runner fixture.
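For reference, the online-softmax recurrence that these four kernels implement together (scale/softmax stats → PV matmul → normalizer update → final divide) can be checked against a plain full-softmax baseline. The sketch below is framework-independent NumPy with generic shapes, not the repository's PTO code:

```python
import numpy as np

rng = np.random.default_rng(0)
rows, block, nblocks, head_dim = 4, 8, 3, 5
scores = rng.standard_normal((rows, nblocks * block))  # concatenated sij rows
v = rng.standard_normal((nblocks * block, head_dim))   # concatenated vj blocks

mi = li = oi = None
for j in range(nblocks):
    s = scores[:, j * block:(j + 1) * block]
    vj = v[j * block:(j + 1) * block]
    # softmax_prepare: per-block row max, exp, row sum
    mij = s.max(axis=1, keepdims=True)
    pij = np.exp(s - mij)
    lij = pij.sum(axis=1, keepdims=True)
    oi_new = pij @ vj                      # pv_matmul
    if j == 0:                             # is_first branch: plain copy
        mi, li, oi = mij, lij, oi_new
    else:                                  # online_update branch
        mi_new = np.maximum(mi, mij)
        alpha, beta = np.exp(mi - mi_new), np.exp(mij - mi_new)
        li = alpha * li + beta * lij
        oi = alpha * oi + beta * oi_new
        mi = mi_new
dst = oi / li                              # is_last branch: normalize

# Baseline: full softmax over all blocks at once
p = np.exp(scores - scores.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v
assert np.allclose(dst, ref)
```

The final `assert` is what the PyTorch golden validation in OnlineUpdateTestCase checks per block.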

@gemini-code-assist
Contributor

Summary of Changes

Hello @Crystal-wzy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a comprehensive suite of system-level tests for the Paged Attention algorithm's core kernels. By leveraging the PTOTestCase framework, these tests validate the correctness of critical matrix multiplication, softmax preparation, and online update operations, ensuring the reliability and accuracy of the underlying PyPTO implementations. The new tests cover various memory access patterns and conditional logic within the kernels, providing robust verification.

Highlights

  • QK Matmul Kernel Test: Added a system test for the QK Matmul kernel, verifying the sij = qi @ kj_t computation and its memory flow from Global Memory (GM) through L1, L0A/L0B, to L0C.
  • Softmax Prepare Kernel Test: Implemented a system test for the Softmax Prepare kernel, covering the sij * scale, row_max, exp, and row_sum pipeline, with the scale factor passed via an INT32-encoded config tensor.
  • PV Matmul Kernel Test: Introduced a system test for the PV Matmul kernel, confirming the oi_new = pij @ vj computation and its memory flow, similar to the QK Matmul kernel.
  • Online Update Kernel Test: Created a unified system test for the Online Update kernel, validating all four runtime branches based on is_first and is_last flags, and including PyTorch golden validation for correctness.
  • PTOTestCase Framework Integration: All new tests are built upon the PTOTestCase framework, implementing define_tensors(), get_program(), and compute_expected() methods, and are driven by the test_runner fixture for consistent execution and validation.
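The "INT32-encoded config tensor" mentioned for the Softmax Prepare test presumably carries the FP32 scale factor via bit reinterpretation, since the config tensor's dtype is INT32 but `pl.tensor.read` yields an FP32 scalar; the PR does not spell the encoding out, so the round trip below is an assumption, illustrated in NumPy:

```python
import numpy as np

scale = np.float32(0.125)
# Host side: reinterpret the FP32 bit pattern as INT32 so the value
# fits into the INT32 config tensor without conversion.
config = np.array([scale], dtype=np.float32).view(np.int32)
# Kernel side: reinterpret the same bits back to FP32 before use.
decoded = config.view(np.float32)[0]
assert decoded == scale
```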
Changelog
  • tests/st/codegen/test_paged_attention.py
    • Added QKMatmulTestCase for testing qi @ kj_t operations.
    • Added SoftmaxPrepareTestCase to validate softmax scaling, row maximum, exponential, and row sum calculations.
    • Added PVMatmulTestCase for verifying pij @ vj computations.
    • Added OnlineUpdateTestCase to test the online normalizer update logic across different block scenarios (first, middle, last, single).
    • Implemented TestPagedAttentionKernels class to orchestrate and run these new system tests using pytest.


coderabbitai bot commented Feb 24, 2026

📝 Walkthrough

Introduces a new end-to-end test file for paged attention kernels using PyPTO frontend, containing four test case classes that validate QK matmul, softmax preparation, value projection matmul, and online update operations, along with reference implementations for correctness verification.

Changes

Cohort / File(s): Paged Attention Test Suite — tests/st/codegen/test_paged_attention.py

Summary: New file with 455 lines introducing the QKMatmulTestCase, SoftmaxPrepareTestCase, PVMatmulTestCase, and OnlineUpdateTestCase classes for testing paged attention kernel components. Each test case defines tensor inputs, instantiates a PTO program, and provides a reference compute_expected implementation. TestPagedAttentionKernels integrates all four test methods with parametrized execution via the test_runner fixture.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • Hzfengsy

Poem

🐰 Hops of joy through kernels four!
QK matmul, softmax lore,
PV dancing, updates bright—
PyPTO tests shining light!
Attention perfected in sight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 3.13%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The title clearly describes the main change: adding system tests for Paged Attention kernels using the PTO framework, which aligns with the changeset introducing test_paged_attention.py.
  • Description check ✅ Passed — The description provides relevant context about the four test cases and their purposes, directly relating to the new test file added in the changeset.


@gemini-code-assist bot left a comment

Code Review

This pull request adds a valuable set of system tests for Paged Attention kernels. The structure of the tests is clear and follows the PTOTestCase framework well. My review focuses on two main areas for improvement: making the PyPTO programs within the test cases more flexible by using dynamic shapes instead of hardcoded values, and significantly improving the test coverage for the OnlineUpdateTestCase to validate all its logical branches. All original comments are valid and align with good testing practices, so they have been kept as is.

Comment on lines +391 to +423
    def compute_expected(self, tensors, params=None):
        """Compute expected output for the last-block case (is_first=0, is_last=1).

        config[0]=is_first and config[1]=is_last are read but not used for
        branching; the computation is fixed to the last-block path matching
        the default config [0, 1].
        """
        # Read is_first and is_last from config tensor (unused in computation below)
        is_first = int(tensors["config"][0])
        is_last = int(tensors["config"][1])

        mij = tensors["mij"].numpy()
        lij = tensors["lij"].numpy()
        oi_new = tensors["oi_new"].numpy()
        mi = tensors["mi"].numpy().copy()
        li = tensors["li"].numpy().copy()
        oi = tensors["oi"].numpy().copy()

        # Default test case: is_first=0, is_last=1 (last block case)
        # mi_new = max(mi, mij); alpha = exp(mi - mi_new); beta = exp(mij - mi_new)
        # li_updated = alpha * li + beta * lij
        # oi_updated = alpha * oi + beta * oi_new
        # dst = oi_updated / li_updated
        mi_new = np.maximum(mi, mij)
        alpha = np.exp(mi - mi_new)
        beta = np.exp(mij - mi_new)
        li_updated = alpha * li + beta * lij
        oi_updated = alpha * oi + beta * oi_new

        tensors["mi"][:] = torch.from_numpy(mi_new)
        tensors["li"][:] = torch.from_numpy(li_updated)
        tensors["oi"][:] = torch.from_numpy(oi_updated)
        tensors["dst"][:] = torch.from_numpy(oi_updated / li_updated)
Severity: high

The compute_expected method for OnlineUpdateTestCase is hardcoded to test only the is_first=0, is_last=1 case. This leaves the other three logical branches of the online_update kernel untested. The implementation should be updated to handle all four combinations of is_first and is_last to ensure full test coverage.

To properly test this, OnlineUpdateTestCase should be parameterized for is_first and is_last, and test_online_update should be updated to run all combinations (e.g., using @pytest.mark.parametrize).

    def compute_expected(self, tensors, params=None):
        """Compute expected output for all four is_first/is_last cases."""
        is_first = int(tensors["config"][0])
        is_last = int(tensors["config"][1])

        mij = tensors["mij"].numpy()
        lij = tensors["lij"].numpy()
        oi_new = tensors["oi_new"].numpy()
        mi = tensors["mi"].numpy().copy()
        li = tensors["li"].numpy().copy()
        oi = tensors["oi"].numpy().copy()

        dst = np.zeros_like(tensors["dst"].numpy())

        if is_first:
            # First block
            mi[:] = mij
            li[:] = lij
            oi[:] = oi_new
            if is_last:
                # Single block case
                dst[:] = oi_new / lij
        else:
            # Not first block: full online update
            mi_new = np.maximum(mi, mij)
            alpha = np.exp(mi - mi_new)
            beta = np.exp(mij - mi_new)
            li_updated = alpha * li + beta * lij
            oi_updated = alpha * oi + beta * oi_new

            mi[:] = mi_new
            li[:] = li_updated
            oi[:] = oi_updated

            if is_last:
                # Last block
                dst[:] = oi_updated / li_updated

        tensors["mi"][:] = torch.from_numpy(mi)
        tensors["li"][:] = torch.from_numpy(li)
        tensors["oi"][:] = torch.from_numpy(oi)
        tensors["dst"][:] = torch.from_numpy(dst)

Comment on lines +57 to +84
    def get_program(self) -> Any:
        import pypto.language as pl

        @pl.program
        class QKMatmulProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def qk_matmul(
                self,
                qi: pl.Tensor[[16, 16], pl.FP32],
                kj_t: pl.Tensor[[16, 16], pl.FP32],
                sij: pl.Tensor[[16, 16], pl.FP32],
            ) -> pl.Tensor[[16, 16], pl.FP32]:
                qi_l1 = pl.load(qi, [0, 0], [16, 16], target_memory=2)    # Load qi to L1
                kj_l1 = pl.load(kj_t, [0, 0], [16, 16], target_memory=2)  # Load kj_t to L1
                qi_l0a = pl.move(qi_l1, target_memory=3)                  # Move qi L1 -> L0A
                kj_l0b = pl.move(kj_l1, target_memory=4)                  # Move kj_t L1 -> L0B
                sij_l0c = pl.matmul(qi_l0a, kj_l0b)                       # Compute qi @ kj_t in L0C
                out_sij = pl.l0c_store(sij_l0c, [0, 0], [16, 16], sij)    # Store L0C -> GM
                return out_sij

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self, qi: pl.Tensor[[16, 16], pl.FP32], kj_t: pl.Tensor[[16, 16], pl.FP32]
            ) -> pl.Tensor[[16, 16], pl.FP32]:
                out_sij = self.qk_matmul(qi, kj_t)
                return out_sij

        return QKMatmulProgram
Severity: medium

The QKMatmulProgram uses hardcoded shapes (e.g., [16, 16]). This limits the test's flexibility and makes it difficult to parameterize with different num_heads and head_dim values. The program should be made dynamic by capturing these values from the test case instance.

    def get_program(self) -> Any:
        import pypto.language as pl

        num_heads = self.num_heads
        head_dim = self.head_dim

        @pl.program
        class QKMatmulProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def qk_matmul(
                self,
                qi: pl.Tensor[[num_heads, head_dim], pl.FP32],
                kj_t: pl.Tensor[[head_dim, num_heads], pl.FP32],
                sij: pl.Tensor[[num_heads, num_heads], pl.FP32],
            ) -> pl.Tensor[[num_heads, num_heads], pl.FP32]:
                qi_l1 = pl.load(qi, [0, 0], [num_heads, head_dim], target_memory=2)    # Load qi to L1
                kj_l1 = pl.load(kj_t, [0, 0], [head_dim, num_heads], target_memory=2)  # Load kj_t to L1
                qi_l0a = pl.move(qi_l1, target_memory=3)                   # Move qi L1 -> L0A
                kj_l0b = pl.move(kj_l1, target_memory=4)                   # Move kj_t L1 -> L0B
                sij_l0c = pl.matmul(qi_l0a, kj_l0b)                        # Compute qi @ kj_t in L0C
                out_sij = pl.l0c_store(sij_l0c, [0, 0], [num_heads, num_heads], sij)    # Store L0C -> GM
                return out_sij

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self, qi: pl.Tensor[[num_heads, head_dim], pl.FP32], kj_t: pl.Tensor[[head_dim, num_heads], pl.FP32]
            ) -> pl.Tensor[[num_heads, num_heads], pl.FP32]:
                out_sij = self.qk_matmul(qi, kj_t)
                return out_sij

        return QKMatmulProgram

Comment on lines +122 to +177
    def get_program(self) -> Any:
        import pypto.language as pl

        scale_value = self.scale  # not referenced in the kernel; scale is passed via config tensor at runtime

        @pl.program
        class SoftmaxPrepareProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def softmax_prepare(
                self,
                sij: pl.Tensor[[16, 16], pl.FP32],
                scale: pl.Scalar[pl.FP32],
                pij: pl.Tensor[[16, 16], pl.FP32],
                mij: pl.Tensor[[16, 1], pl.FP32],
                lij: pl.Tensor[[16, 1], pl.FP32],
            ) -> tuple[pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32]]:
                # Load sij to UB (target_memory=1)
                sij_tile = pl.load(sij, [0, 0], [16, 16], target_memory=1)

                # Scale: sij * scale_factor
                sij_scaled = pl.muls(sij_tile, scale)

                # Create temp tile for row reduction
                tmp_tile = pl.create_tile([16, 16], dtype=pl.FP32, target_memory=1)

                # Row max: mij = max(sij_scaled, axis=1) -> [16, 1] DN format
                mij_tile = pl.row_max(sij_scaled, tmp_tile)

                # Row broadcast subtraction: sij_scaled - mij
                sij_centered = pl.row_expand_sub(sij_scaled, mij_tile)

                # Exp: exp(sij_centered)
                pij_tile = pl.exp(sij_centered)

                # Row sum: lij = sum(pij, axis=1) -> [16, 1] DN format
                lij_tile = pl.row_sum(pij_tile, tmp_tile)

                # Store results
                pij_out = pl.store(pij_tile, [0, 0], [16, 16], pij)
                mij_out = pl.store(mij_tile, [0, 0], [16, 1], mij)
                lij_out = pl.store(lij_tile, [0, 0], [16, 1], lij)

                return pij_out, mij_out, lij_out

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self,
                sij: pl.Tensor[[16, 16], pl.FP32],
                config: pl.Tensor[[1], pl.INT32],
            ) -> tuple[pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32]]:
                # Read scale value from config tensor
                scale: pl.Scalar[pl.FP32] = pl.tensor.read(config, [0])
                pij_out, mij_out, lij_out = self.softmax_prepare(sij, scale)
                return pij_out, mij_out, lij_out

        return SoftmaxPrepareProgram
Severity: medium

The SoftmaxPrepareProgram uses hardcoded shapes (e.g., [16, 16]). This prevents parameterizing the test with different num_heads and block_size values. The program should use dynamic shapes from the test case instance to improve flexibility. Additionally, the unused variable scale_value on line 125 can be removed.

    def get_program(self) -> Any:
        import pypto.language as pl
        
        num_heads = self.num_heads
        block_size = self.block_size

        @pl.program
        class SoftmaxPrepareProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def softmax_prepare(
                self,
                sij: pl.Tensor[[num_heads, block_size], pl.FP32],
                scale: pl.Scalar[pl.FP32],
                pij: pl.Tensor[[num_heads, block_size], pl.FP32],
                mij: pl.Tensor[[num_heads, 1], pl.FP32],
                lij: pl.Tensor[[num_heads, 1], pl.FP32],
            ) -> tuple[pl.Tensor[[num_heads, block_size], pl.FP32], pl.Tensor[[num_heads, 1], pl.FP32], pl.Tensor[[num_heads, 1], pl.FP32]]:
                # Load sij to UB (target_memory=1)
                sij_tile = pl.load(sij, [0, 0], [num_heads, block_size], target_memory=1)
                
                # Scale: sij * scale_factor
                sij_scaled = pl.muls(sij_tile, scale)
                
                # Create temp tile for row reduction
                tmp_tile = pl.create_tile([num_heads, block_size], dtype=pl.FP32, target_memory=1)
                
                # Row max: mij = max(sij_scaled, axis=1) -> [num_heads, 1] DN format
                mij_tile = pl.row_max(sij_scaled, tmp_tile)
                
                # Row broadcast subtraction: sij_scaled - mij
                sij_centered = pl.row_expand_sub(sij_scaled, mij_tile)
                
                # Exp: exp(sij_centered)
                pij_tile = pl.exp(sij_centered)
                
                # Row sum: lij = sum(pij, axis=1) -> [num_heads, 1] DN format
                lij_tile = pl.row_sum(pij_tile, tmp_tile)
                
                # Store results
                pij_out = pl.store(pij_tile, [0, 0], [num_heads, block_size], pij)
                mij_out = pl.store(mij_tile, [0, 0], [num_heads, 1], mij)
                lij_out = pl.store(lij_tile, [0, 0], [num_heads, 1], lij)
                
                return pij_out, mij_out, lij_out

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self,
                sij: pl.Tensor[[num_heads, block_size], pl.FP32],
                config: pl.Tensor[[1], pl.INT32],
            ) -> tuple[pl.Tensor[[num_heads, block_size], pl.FP32], pl.Tensor[[num_heads, 1], pl.FP32], pl.Tensor[[num_heads, 1], pl.FP32]]:
                # Read scale value from config tensor
                scale: pl.Scalar[pl.FP32] = pl.tensor.read(config, [0])
                pij_out, mij_out, lij_out = self.softmax_prepare(sij, scale)
                return pij_out, mij_out, lij_out

        return SoftmaxPrepareProgram

Comment on lines +216 to +243
    def get_program(self) -> Any:
        import pypto.language as pl

        @pl.program
        class PVMatmulProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def pv_matmul(
                self,
                pij: pl.Tensor[[16, 16], pl.FP32],
                vj: pl.Tensor[[16, 16], pl.FP32],
                oi_new: pl.Tensor[[16, 16], pl.FP32],
            ) -> pl.Tensor[[16, 16], pl.FP32]:
                pij_l1 = pl.load(pij, [0, 0], [16, 16], target_memory=2)  # Load pij to L1
                vj_l1 = pl.load(vj, [0, 0], [16, 16], target_memory=2)    # Load vj to L1
                pij_l0a = pl.move(pij_l1, target_memory=3)                # Move pij L1 -> L0A
                vj_l0b = pl.move(vj_l1, target_memory=4)                  # Move vj L1 -> L0B
                oi_l0c = pl.matmul(pij_l0a, vj_l0b)                       # Compute pij @ vj in L0C
                out_oi = pl.l0c_store(oi_l0c, [0, 0], [16, 16], oi_new)   # Store L0C -> GM
                return out_oi

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self, pij: pl.Tensor[[16, 16], pl.FP32], vj: pl.Tensor[[16, 16], pl.FP32]
            ) -> pl.Tensor[[16, 16], pl.FP32]:
                out_oi = self.pv_matmul(pij, vj)
                return out_oi

        return PVMatmulProgram
Severity: medium

The PVMatmulProgram uses hardcoded shapes (e.g., [16, 16]). This makes it difficult to extend the test for different num_heads and head_dim values. The program should be made dynamic by using the self.num_heads and self.head_dim from the test case instance.

    def get_program(self) -> Any:
        import pypto.language as pl

        num_heads = self.num_heads
        head_dim = self.head_dim

        @pl.program
        class PVMatmulProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def pv_matmul(
                self,
                pij: pl.Tensor[[num_heads, num_heads], pl.FP32],
                vj: pl.Tensor[[num_heads, head_dim], pl.FP32],
                oi_new: pl.Tensor[[num_heads, head_dim], pl.FP32],
            ) -> pl.Tensor[[num_heads, head_dim], pl.FP32]:
                pij_l1 = pl.load(pij, [0, 0], [num_heads, num_heads], target_memory=2)  # Load pij to L1
                vj_l1 = pl.load(vj, [0, 0], [num_heads, head_dim], target_memory=2)    # Load vj to L1
                pij_l0a = pl.move(pij_l1, target_memory=3)                  # Move pij L1 -> L0A
                vj_l0b = pl.move(vj_l1, target_memory=4)                    # Move vj L1 -> L0B
                oi_l0c = pl.matmul(pij_l0a, vj_l0b)                         # Compute pij @ vj in L0C
                out_oi = pl.l0c_store(oi_l0c, [0, 0], [num_heads, head_dim], oi_new)   # Store L0C -> GM
                return out_oi

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self, pij: pl.Tensor[[num_heads, num_heads], pl.FP32], vj: pl.Tensor[[num_heads, head_dim], pl.FP32]
            ) -> pl.Tensor[[num_heads, head_dim], pl.FP32]:
                out_oi = self.pv_matmul(pij, vj)
                return out_oi

        return PVMatmulProgram

Comment on lines +283 to +389
    def get_program(self) -> Any:
        import pypto.language as pl

        @pl.program
        class OnlineUpdateProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def online_update(
                self,
                mij: pl.Tensor[[16, 1], pl.FP32],
                lij: pl.Tensor[[16, 1], pl.FP32],
                oi_new: pl.Tensor[[16, 16], pl.FP32],
                mi: pl.Tensor[[16, 1], pl.FP32],
                li: pl.Tensor[[16, 1], pl.FP32],
                oi: pl.Tensor[[16, 16], pl.FP32],
                is_first: pl.Scalar[pl.BOOL],
                is_last: pl.Scalar[pl.BOOL],
                dst: pl.Tensor[[16, 16], pl.FP32],
            ) -> tuple[pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32],
                       pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 16], pl.FP32]]:
                # Load all inputs
                mij_tile = pl.load(mij, [0, 0], [16, 1], target_memory=1)
                lij_tile = pl.load(lij, [0, 0], [16, 1], target_memory=1)
                oi_new_tile = pl.load(oi_new, [0, 0], [16, 16], target_memory=1)
                mi_tile = pl.load(mi, [0, 0], [16, 1], target_memory=1)
                li_tile = pl.load(li, [0, 0], [16, 1], target_memory=1)
                oi_tile = pl.load(oi, [0, 0], [16, 16], target_memory=1)

                if is_first:
                    # First block: copy mij->mi, lij->li, oi_new->oi
                    mi_out = pl.store(mij_tile, [0, 0], [16, 1], mi)
                    li_out = pl.store(lij_tile, [0, 0], [16, 1], li)
                    oi_out = pl.store(oi_new_tile, [0, 0], [16, 16], oi)
                    if is_last:
                        # Single block: normalize dst = oi_new / lij
                        dst_tile = pl.row_expand_div(oi_new_tile, lij_tile)
                        dst_out = pl.store(dst_tile, [0, 0], [16, 16], dst)
                    else:
                        # First but not last: no dst output
                        zero_tile = pl.create_tile([16, 16], dtype=pl.FP32, target_memory=1)
                        dst_out = pl.store(zero_tile, [0, 0], [16, 16], dst)
                else:
                    # Not first: full online update
                    # Reshape DN [16,1] -> ND [1,16] for element-wise ops
                    mi_tile_nd = pl.reshape(mi_tile, [1, 16])
                    mij_tile_nd = pl.reshape(mij_tile, [1, 16])
                    li_tile_nd = pl.reshape(li_tile, [1, 16])
                    lij_tile_nd = pl.reshape(lij_tile, [1, 16])

                    # mi_new = max(mi, mij): new running row maximum
                    mi_new = pl.maximum(mi_tile_nd, mij_tile_nd)
                    # alpha = exp(mi - mi_new): rescale factor for accumulated oi
                    mi_diff = pl.sub(mi_tile_nd, mi_new)
                    alpha = pl.exp(mi_diff)
                    # beta = exp(mij - mi_new): rescale factor for current oi_new
                    mij_diff = pl.sub(mij_tile_nd, mi_new)
                    beta = pl.exp(mij_diff)

                    # li_updated = alpha * li + beta * lij: updated normalizer
                    li_scaled = pl.mul(alpha, li_tile_nd)
                    lij_scaled = pl.mul(beta, lij_tile_nd)
                    li_updated = pl.add(li_scaled, lij_scaled)

                    # oi_updated = alpha * oi + beta * oi_new: updated attention output
                    alpha_dn = pl.reshape(alpha, [16, 1])  # Reshape [1,16] -> [16,1] DN for row_expand_mul
                    oi_scaled = pl.row_expand_mul(oi_tile, alpha_dn)
                    beta_dn = pl.reshape(beta, [16, 1])    # Reshape [1,16] -> [16,1] DN for row_expand_mul
                    oi_new_scaled = pl.row_expand_mul(oi_new_tile, beta_dn)
                    oi_updated = pl.add(oi_scaled, oi_new_scaled)

                    mi_new_dn = pl.reshape(mi_new, [16, 1])          # Reshape back to DN [16,1] for store
                    li_updated_dn = pl.reshape(li_updated, [16, 1])  # Reshape back to DN [16,1] for store

                    mi_out = pl.store(mi_new_dn, [0, 0], [16, 1], mi)
                    li_out = pl.store(li_updated_dn, [0, 0], [16, 1], li)

                    if is_last:
                        # Last block: normalize dst = oi / li
                        dst_tile = pl.row_expand_div(oi_updated, li_updated_dn)
                        dst_out = pl.store(dst_tile, [0, 0], [16, 16], dst)
                        oi_out = pl.store(oi_updated, [0, 0], [16, 16], oi)
                    else:
                        # Middle block: no normalize
                        oi_out = pl.store(oi_updated, [0, 0], [16, 16], oi)
                        zero_tile = pl.create_tile([16, 16], dtype=pl.FP32, target_memory=1)
                        dst_out = pl.store(zero_tile, [0, 0], [16, 16], dst)

                return mi_out, li_out, oi_out, dst_out

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self,
                mij: pl.Tensor[[16, 1], pl.FP32],
                lij: pl.Tensor[[16, 1], pl.FP32],
                oi_new: pl.Tensor[[16, 16], pl.FP32],
                config: pl.Tensor[[2], pl.INT64],
                mi: pl.Tensor[[16, 1], pl.FP32],
                li: pl.Tensor[[16, 1], pl.FP32],
                oi: pl.Tensor[[16, 16], pl.FP32],
            ) -> tuple[pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32],
                       pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 16], pl.FP32]]:
                # Read is_first and is_last from config tensor
                is_first: pl.Scalar[pl.INT64] = pl.tensor.read(config, [0])
                is_last: pl.Scalar[pl.INT64] = pl.tensor.read(config, [1])
                mi, li, oi, dst = self.online_update(mij, lij, oi_new, mi, li, oi, is_first, is_last)
                return mi, li, oi, dst

        return OnlineUpdateProgram
Severity: medium

The OnlineUpdateProgram uses hardcoded shapes (e.g., [16, 16]). This makes it difficult to extend the test for different num_heads and head_dim values. The program should be made dynamic by using the self.num_heads and self.head_dim from the test case instance. This will make the test more robust and easier to parameterize.

    def get_program(self) -> Any:
        import pypto.language as pl

        num_heads = self.num_heads
        head_dim = self.head_dim

        @pl.program
        class OnlineUpdateProgram:
            @pl.function(type=pl.FunctionType.InCore)
            def online_update(
                self,
                mij: pl.Tensor[[num_heads, 1], pl.FP32],
                lij: pl.Tensor[[num_heads, 1], pl.FP32],
                oi_new: pl.Tensor[[num_heads, head_dim], pl.FP32],
                mi: pl.Tensor[[num_heads, 1], pl.FP32],
                li: pl.Tensor[[num_heads, 1], pl.FP32],
                oi: pl.Tensor[[num_heads, head_dim], pl.FP32],
                is_first: pl.Scalar[pl.BOOL],
                is_last: pl.Scalar[pl.BOOL],
                dst: pl.Tensor[[num_heads, head_dim], pl.FP32],
            ) -> tuple[pl.Tensor[[num_heads, 1], pl.FP32], pl.Tensor[[num_heads, 1], pl.FP32],
                       pl.Tensor[[num_heads, head_dim], pl.FP32], pl.Tensor[[num_heads, head_dim], pl.FP32]]:
                # Load all inputs
                mij_tile = pl.load(mij, [0, 0], [num_heads, 1], target_memory=1)
                lij_tile = pl.load(lij, [0, 0], [num_heads, 1], target_memory=1)
                oi_new_tile = pl.load(oi_new, [0, 0], [num_heads, head_dim], target_memory=1)
                mi_tile = pl.load(mi, [0, 0], [num_heads, 1], target_memory=1)
                li_tile = pl.load(li, [0, 0], [num_heads, 1], target_memory=1)
                oi_tile = pl.load(oi, [0, 0], [num_heads, head_dim], target_memory=1)

                if is_first:
                    # First block: copy mij->mi, lij->li, oi_new->oi
                    mi_out = pl.store(mij_tile, [0, 0], [num_heads, 1], mi)
                    li_out = pl.store(lij_tile, [0, 0], [num_heads, 1], li)
                    oi_out = pl.store(oi_new_tile, [0, 0], [num_heads, head_dim], oi)
                    if is_last:
                        # Single block: normalize dst = oi_new / lij
                        dst_tile = pl.row_expand_div(oi_new_tile, lij_tile)
                        dst_out = pl.store(dst_tile, [0, 0], [num_heads, head_dim], dst)
                    else:
                        # First but not last: no dst output
                        zero_tile = pl.create_tile([num_heads, head_dim], dtype=pl.FP32, target_memory=1)
                        dst_out = pl.store(zero_tile, [0, 0], [num_heads, head_dim], dst)
                else:
                    # Not first: full online update
                    # Reshape DN [num_heads,1] -> ND [1,num_heads] for element-wise ops
                    mi_tile_nd = pl.reshape(mi_tile, [1, num_heads])
                    mij_tile_nd = pl.reshape(mij_tile, [1, num_heads])
                    li_tile_nd = pl.reshape(li_tile, [1, num_heads])
                    lij_tile_nd = pl.reshape(lij_tile, [1, num_heads])

                    # mi_new = max(mi, mij): new running row maximum
                    mi_new = pl.maximum(mi_tile_nd, mij_tile_nd)
                    # alpha = exp(mi - mi_new): rescale factor for accumulated oi
                    mi_diff = pl.sub(mi_tile_nd, mi_new)
                    alpha = pl.exp(mi_diff)
                    # beta = exp(mij - mi_new): rescale factor for current oi_new
                    mij_diff = pl.sub(mij_tile_nd, mi_new)
                    beta = pl.exp(mij_diff)

                    # li_updated = alpha * li + beta * lij: updated normalizer
                    li_scaled = pl.mul(alpha, li_tile_nd)
                    lij_scaled = pl.mul(beta, lij_tile_nd)
                    li_updated = pl.add(li_scaled, lij_scaled)

                    # oi_updated = alpha * oi + beta * oi_new: updated attention output
                    alpha_dn = pl.reshape(alpha, [num_heads, 1])              # Reshape [1,num_heads] -> [num_heads,1] DN for row_expand_mul
                    oi_scaled = pl.row_expand_mul(oi_tile, alpha_dn)
                    beta_dn = pl.reshape(beta, [num_heads, 1])                # Reshape [1,num_heads] -> [num_heads,1] DN for row_expand_mul
                    oi_new_scaled = pl.row_expand_mul(oi_new_tile, beta_dn)
                    oi_updated = pl.add(oi_scaled, oi_new_scaled)

                    mi_new_dn = pl.reshape(mi_new, [num_heads, 1])            # Reshape back to DN [num_heads,1] for store
                    li_updated_dn = pl.reshape(li_updated, [num_heads, 1])    # Reshape back to DN [num_heads,1] for store

                    mi_out = pl.store(mi_new_dn, [0, 0], [num_heads, 1], mi)
                    li_out = pl.store(li_updated_dn, [0, 0], [num_heads, 1], li)

                    if is_last:
                        # Last block: normalize dst = oi / li
                        dst_tile = pl.row_expand_div(oi_updated, li_updated_dn)
                        dst_out = pl.store(dst_tile, [0, 0], [num_heads, head_dim], dst)
                        oi_out = pl.store(oi_updated, [0, 0], [num_heads, head_dim], oi)
                    else:
                        # Middle block: no normalize
                        oi_out = pl.store(oi_updated, [0, 0], [num_heads, head_dim], oi)
                        zero_tile = pl.create_tile([num_heads, head_dim], dtype=pl.FP32, target_memory=1)
                        dst_out = pl.store(zero_tile, [0, 0], [num_heads, head_dim], dst)

                return mi_out, li_out, oi_out, dst_out

            @pl.function(type=pl.FunctionType.Orchestration)
            def orchestrator(
                self,
                mij: pl.Tensor[[num_heads, 1], pl.FP32],
                lij: pl.Tensor[[num_heads, 1], pl.FP32],
                oi_new: pl.Tensor[[num_heads, head_dim], pl.FP32],
                config: pl.Tensor[[2], pl.INT64],
                mi: pl.Tensor[[num_heads, 1], pl.FP32],
                li: pl.Tensor[[num_heads, 1], pl.FP32],
                oi: pl.Tensor[[num_heads, head_dim], pl.FP32],
            ) -> tuple[pl.Tensor[[num_heads, 1], pl.FP32], pl.Tensor[[num_heads, 1], pl.FP32],
                       pl.Tensor[[num_heads, head_dim], pl.FP32], pl.Tensor[[num_heads, head_dim], pl.FP32]]:
                # Read is_first and is_last from config tensor
                is_first: pl.Scalar[pl.INT64] = pl.tensor.read(config, [0])
                is_last: pl.Scalar[pl.INT64] = pl.tensor.read(config, [1])
                mi, li, oi, dst = self.online_update(mij, lij, oi_new, mi, li, oi, is_first, is_last)
                return mi, li, oi, dst

        return OnlineUpdateProgram
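The not-first branch of `online_update` above is the standard streaming-softmax (FlashAttention-style) recurrence. A minimal NumPy sketch of the same math, independent of the PTO DSL (function and variable names here are illustrative, not framework API), checks that two chained block updates reproduce softmax over the concatenated scores:

```python
import numpy as np

def online_update(mi, li, oi, mij, lij, oi_new):
    """One streaming-softmax step: merge block stats (mij, lij, oi_new)
    into the running stats (mi, li, oi). Shapes: mi/li/mij/lij are
    [num_heads, 1]; oi/oi_new are [num_heads, head_dim]."""
    mi_new = np.maximum(mi, mij)         # new running row maximum
    alpha = np.exp(mi - mi_new)          # rescale factor for accumulated oi
    beta = np.exp(mij - mi_new)          # rescale factor for the new block
    li_new = alpha * li + beta * lij     # updated normalizer
    oi_upd = alpha * oi + beta * oi_new  # updated (unnormalized) output
    return mi_new, li_new, oi_upd

def block_stats(s, v):
    """Per-block max, exp-sum, and unnormalized attention output."""
    m = s.max(axis=1, keepdims=True)
    p = np.exp(s - m)
    return m, p.sum(axis=1, keepdims=True), p @ v

# Two-block run must agree with softmax over the concatenated scores.
rng = np.random.default_rng(0)
s1, s2 = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
v1, v2 = rng.normal(size=(8, 3)), rng.normal(size=(8, 3))

mi, li, oi = block_stats(s1, v1)                       # is_first: seed with block stats
mi, li, oi = online_update(mi, li, oi, *block_stats(s2, v2))
out = oi / li                                          # is_last: final normalization

s = np.concatenate([s1, s2], axis=1)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ np.concatenate([v1, v2], axis=0)
assert np.allclose(out, ref)
```

The is_first branch corresponds to seeding (mi, li, oi) from the first block, and the is_last branch to the final oi / li normalization.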

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/st/codegen/test_paged_attention.py`:
- Around line 258-397: The test currently fixes config to [0,1] and
compute_expected only implements the last-block path; update the test to
parameterize the config/TensorSpec (config) so the harness can run all four
combinations and modify compute_expected to read the runtime booleans
(is_first/is_last) and branch to compute the correct golden outputs for each
case (single-block first+last, first-but-not-last, not-first+last, middle)
matching the logic in orchestrator.online_update and using the same intermediate
names (mi, li, oi, dst) so expected outputs align with the OnlineUpdateProgram
behavior.
- Line 125: Remove the unused local variable assignment "scale_value =
self.scale" (which triggers Ruff F841); either delete that line altogether or,
if it's intentional for clarity, rename it to a throwaway name like
"_scale_value" or use the value directly where needed. Locate the assignment to
"scale_value" in the method that contains the comment "scale is passed via
config tensor at runtime" and remove or rename it accordingly.
- Around line 271-277: The "config" TensorSpec in define_tensors uses
init_value=[0,1], which will be passed to TensorSpec.create_array and cause
torch.full(fill_value=list) to raise; change the init_value for the TensorSpec
named "config" to a tensor or a callable that returns a tensor (e.g.,
torch.tensor([0,1]) or a lambda creating that tensor) so TensorSpec.create_array
uses a proper tensor fill value; update the define_tensors return entry for
"config" (and any similar specs) rather than changing torch.full usage inside
TensorSpec.create_array.
- Around line 383-386: The values read into is_first and is_last are
pl.Scalar[pl.INT64] but online_update(...) expects pl.Scalar[pl.BOOL]; add an
explicit cast/bitcast of the results of pl.tensor.read(config, [0]) and
pl.tensor.read(config, [1]) to pl.Scalar[pl.BOOL] before passing them into
self.online_update (or confirm the DSL auto-converts and add an explicit no-op
cast if required). Locate the pl.tensor.read calls that produce is_first/is_last
and replace or wrap them with an explicit cast to BOOL so the types match the
online_update function signature.
- Around line 110-174: The config tensor is defined as INT32 but read as
pl.Scalar[pl.FP32] in orchestrator, causing a dtype mismatch; fix define_tensors
to declare the "config" TensorSpec with DataType.FP32 (remove the struct
pack/scale_bits logic in define_tensors) so the runtime pl.tensor.read(config,
[0]) returns a FP32 scalar, or alternatively keep INT32 and change orchestrator
to read an INT32 and bitcast to FP32 if a bitcast primitive exists; update
references in define_tensors and get_program/orchestrator (config,
pl.tensor.read, scale usage) to match the chosen approach.
- Around line 42-46: Orchestrator calls must pass the explicit output tensors
required by the InCore functions and avoid hard-coded shapes: update calls to
QKMatmul.qk_matmul(...) to include the sij output tensor argument, update
PVMatmul/SoftmaxPrepare/OnlineUpdate orchestrators to pass pij/mij/lij, oi_new
and dst outputs respectively (match the pattern used in harness.core.harness
InCore invocations), and in OnlineUpdateTestCase fix the scalar type conversion
from INT64 to BOOL for boolean scalars; also replace hard-coded shapes like
[16,16] and [16,1] with parameters derived from the test class (use
self.num_heads, self.head_dim, self.block_size or make the program constructors
accept these values) so dimensions remain consistent across orchestration and
InCore function signatures.

ℹ️ Review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 81383e1 and 4e5fd64.

📒 Files selected for processing (1)
  • tests/st/codegen/test_paged_attention.py

Comment on lines +42 to +46
def __init__(self, num_heads: int = 16, head_dim: int = 16, **kwargs):
    super().__init__(**kwargs)
    self.num_heads = num_heads
    self.head_dim = head_dim

⚠️ Potential issue | 🟠 Major

Fix orchestrator calls to pass all required output tensor parameters.
The InCore functions accept explicit output tensors (e.g., sij, pij/mij/lij, oi_new, dst), but several orchestrators invoke them without passing those outputs. For example, QKMatmul.qk_matmul() signature requires sij as a parameter, but the orchestrator calls it with only qi and kj_t. Similarly, OnlineUpdate.online_update() requires dst but the orchestrator omits it. Compare against tests/st/harness/core/harness.py (lines 202–205), where output tensors are passed explicitly to InCore functions.

Additionally, the program annotations use hard-coded shapes ([16, 16], [16, 1]) while constructors accept configurable num_heads, head_dim, and block_size. Enforce consistent dimensions throughout or make programs parameterizable.

💡 Example fix for QKMatmulProgram
-            def orchestrator(
-                self, qi: pl.Tensor[[16, 16], pl.FP32], kj_t: pl.Tensor[[16, 16], pl.FP32]
-            ) -> pl.Tensor[[16, 16], pl.FP32]:
-                out_sij = self.qk_matmul(qi, kj_t)
+            def orchestrator(
+                self,
+                qi: pl.Tensor[[16, 16], pl.FP32],
+                kj_t: pl.Tensor[[16, 16], pl.FP32],
+                sij: pl.Tensor[[16, 16], pl.FP32],
+            ) -> pl.Tensor[[16, 16], pl.FP32]:
+                out_sij = self.qk_matmul(qi, kj_t, sij)
                 return out_sij

Also applies to: SoftmaxPrepareTestCase (lines 128–175), PVMatmulTestCase (lines 219–241), OnlineUpdateTestCase (lines 286–387; also fix INT64→BOOL scalar conversion and dst parameter).


Comment on lines +110 to +174
def define_tensors(self) -> List[TensorSpec]:
    import struct
    # Pack scale as float32 bits into int32
    scale_bits = struct.unpack('i', struct.pack('f', self.scale))[0]
    return [
        TensorSpec("sij", [self.num_heads, self.block_size], DataType.FP32, init_value=1.0),
        TensorSpec("config", [1], DataType.INT32, init_value=scale_bits),
        TensorSpec("pij", [self.num_heads, self.block_size], DataType.FP32, is_output=True),
        TensorSpec("mij", [self.num_heads, 1], DataType.FP32, is_output=True),
        TensorSpec("lij", [self.num_heads, 1], DataType.FP32, is_output=True),
    ]

def get_program(self) -> Any:
    import pypto.language as pl

    scale_value = self.scale  # not referenced in the kernel; scale is passed via config tensor at runtime

    @pl.program
    class SoftmaxPrepareProgram:
        @pl.function(type=pl.FunctionType.InCore)
        def softmax_prepare(
            self,
            sij: pl.Tensor[[16, 16], pl.FP32],
            scale: pl.Scalar[pl.FP32],
            pij: pl.Tensor[[16, 16], pl.FP32],
            mij: pl.Tensor[[16, 1], pl.FP32],
            lij: pl.Tensor[[16, 1], pl.FP32],
        ) -> tuple[pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32]]:
            # Load sij to UB (target_memory=1)
            sij_tile = pl.load(sij, [0, 0], [16, 16], target_memory=1)

            # Scale: sij * scale_factor
            sij_scaled = pl.muls(sij_tile, scale)

            # Create temp tile for row reduction
            tmp_tile = pl.create_tile([16, 16], dtype=pl.FP32, target_memory=1)

            # Row max: mij = max(sij_scaled, axis=1) -> [16, 1] DN format
            mij_tile = pl.row_max(sij_scaled, tmp_tile)

            # Row broadcast subtraction: sij_scaled - mij
            sij_centered = pl.row_expand_sub(sij_scaled, mij_tile)

            # Exp: exp(sij_centered)
            pij_tile = pl.exp(sij_centered)

            # Row sum: lij = sum(pij, axis=1) -> [16, 1] DN format
            lij_tile = pl.row_sum(pij_tile, tmp_tile)

            # Store results
            pij_out = pl.store(pij_tile, [0, 0], [16, 16], pij)
            mij_out = pl.store(mij_tile, [0, 0], [16, 1], mij)
            lij_out = pl.store(lij_tile, [0, 0], [16, 1], lij)

            return pij_out, mij_out, lij_out

        @pl.function(type=pl.FunctionType.Orchestration)
        def orchestrator(
            self,
            sij: pl.Tensor[[16, 16], pl.FP32],
            config: pl.Tensor[[1], pl.INT32],
        ) -> tuple[pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32]]:
            # Read scale value from config tensor
            scale: pl.Scalar[pl.FP32] = pl.tensor.read(config, [0])
            pij_out, mij_out, lij_out = self.softmax_prepare(sij, scale)
⚠️ Potential issue | 🔴 Critical

Store config tensor as FP32 instead of INT32, or use explicit bitcast if available.

The code packs float32 bits into an INT32 tensor at setup, but the orchestrator reads it as pl.Scalar[pl.FP32]. Based on codebase patterns (e.g., OnlineUpdateTestCase where config is INT64 and reads match as INT64), pl.tensor.read preserves the source tensor's dtype. Reading INT32 as FP32 will either cause a type mismatch error or trigger an implicit numeric conversion that corrupts the float representation.

Recommended fix: Define config as DataType.FP32 and remove the struct packing at setup, or apply an explicit bitcast operation if the DSL supports it (none found in codebase).
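The failure mode is easy to demonstrate in plain Python: the INT32 bit pattern of a float32 is numerically unrelated to the float value, so only a true bitcast (not a numeric conversion) recovers the scale. A small sketch of the round trip that define_tensors relies on:

```python
import struct

scale = 0.125  # e.g. 1.0 / sqrt(head_dim) for head_dim = 64
bits = struct.unpack('i', struct.pack('f', scale))[0]  # float32 bits as int32

# A numeric INT32 -> FP32 conversion yields garbage as a scale factor:
assert float(bits) == 1040187392.0 and float(bits) != scale

# Only reinterpreting the same bits as float32 recovers the value:
roundtrip = struct.unpack('f', struct.pack('i', bits))[0]
assert roundtrip == scale
```

If pl.tensor.read performs a numeric conversion (or rejects the dtype outright), the FP32-config variant is the safer choice.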

🧰 Tools
🪛 Ruff (0.15.2)

[error] 125-125: Local variable scale_value is assigned to but never used

Remove assignment to unused variable scale_value

(F841)


def get_program(self) -> Any:
    import pypto.language as pl

    scale_value = self.scale  # not referenced in the kernel; scale is passed via config tensor at runtime
⚠️ Potential issue | 🟡 Minor

Remove unused scale_value to satisfy Ruff F841.

🧹 Suggested cleanup
-        scale_value = self.scale  # not referenced in the kernel; scale is passed via config tensor at runtime
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
scale_value = self.scale # not referenced in the kernel; scale is passed via config tensor at runtime
🧰 Tools
🪛 Ruff (0.15.2)

[error] 125-125: Local variable scale_value is assigned to but never used

Remove assignment to unused variable scale_value

(F841)


Comment on lines +258 to +397
The default config [is_first=0, is_last=1] exercises the last-block path.
compute_expected is hardcoded to the last-block computation and does not
branch on config values.
"""

def __init__(self, num_heads: int = 16, head_dim: int = 16, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.head_dim = head_dim

def get_name(self) -> str:
return f"online_update_{self.num_heads}h_{self.head_dim}d"

def define_tensors(self) -> List[TensorSpec]:
return [
TensorSpec("mij", [self.num_heads, 1], DataType.FP32, init_value=0.5),
TensorSpec("lij", [self.num_heads, 1], DataType.FP32, init_value=1.5),
TensorSpec("oi_new", [self.num_heads, self.head_dim], DataType.FP32, init_value=0.3),
TensorSpec("config", [2], DataType.INT64, init_value=[0, 1]), # [is_first=0, is_last=1]
TensorSpec("mi", [self.num_heads, 1], DataType.FP32, init_value=0.4, is_output=True),
TensorSpec("li", [self.num_heads, 1], DataType.FP32, init_value=2.0, is_output=True),
TensorSpec("oi", [self.num_heads, self.head_dim], DataType.FP32, init_value=0.2, is_output=True),
TensorSpec("dst", [self.num_heads, self.head_dim], DataType.FP32, is_output=True),
]

def get_program(self) -> Any:
import pypto.language as pl

@pl.program
class OnlineUpdateProgram:
@pl.function(type=pl.FunctionType.InCore)
def online_update(
self,
mij: pl.Tensor[[16, 1], pl.FP32],
lij: pl.Tensor[[16, 1], pl.FP32],
oi_new: pl.Tensor[[16, 16], pl.FP32],
mi: pl.Tensor[[16, 1], pl.FP32],
li: pl.Tensor[[16, 1], pl.FP32],
oi: pl.Tensor[[16, 16], pl.FP32],
is_first: pl.Scalar[pl.BOOL],
is_last: pl.Scalar[pl.BOOL],
dst: pl.Tensor[[16, 16], pl.FP32],
) -> tuple[pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32],
pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 16], pl.FP32]]:
# Load all inputs
mij_tile = pl.load(mij, [0, 0], [16, 1], target_memory=1)
lij_tile = pl.load(lij, [0, 0], [16, 1], target_memory=1)
oi_new_tile = pl.load(oi_new, [0, 0], [16, 16], target_memory=1)
mi_tile = pl.load(mi, [0, 0], [16, 1], target_memory=1)
li_tile = pl.load(li, [0, 0], [16, 1], target_memory=1)
oi_tile = pl.load(oi, [0, 0], [16, 16], target_memory=1)

if is_first:
# First block: copy mij->mi, lij->li, oi_new->oi
mi_out = pl.store(mij_tile, [0, 0], [16, 1], mi)
li_out = pl.store(lij_tile, [0, 0], [16, 1], li)
oi_out = pl.store(oi_new_tile, [0, 0], [16, 16], oi)
if is_last:
# Single block: normalize dst = oi_new / lij
dst_tile = pl.row_expand_div(oi_new_tile, lij_tile)
dst_out = pl.store(dst_tile, [0, 0], [16, 16], dst)
else:
# First but not last: no dst output
zero_tile = pl.create_tile([16, 16], dtype=pl.FP32, target_memory=1)
dst_out = pl.store(zero_tile, [0, 0], [16, 16], dst)
else:
# Not first: full online update
# Reshape DN [16,1] -> ND [1,16] for element-wise ops
mi_tile_nd = pl.reshape(mi_tile, [1, 16])
mij_tile_nd = pl.reshape(mij_tile, [1, 16])
li_tile_nd = pl.reshape(li_tile, [1, 16])
lij_tile_nd = pl.reshape(lij_tile, [1, 16])

# mi_new = max(mi, mij): new running row maximum
mi_new = pl.maximum(mi_tile_nd, mij_tile_nd)
# alpha = exp(mi - mi_new): rescale factor for accumulated oi
mi_diff = pl.sub(mi_tile_nd, mi_new)
alpha = pl.exp(mi_diff)
# beta = exp(mij - mi_new): rescale factor for current oi_new
mij_diff = pl.sub(mij_tile_nd, mi_new)
beta = pl.exp(mij_diff)

# li_updated = alpha * li + beta * lij: updated normalizer
li_scaled = pl.mul(alpha, li_tile_nd)
lij_scaled = pl.mul(beta, lij_tile_nd)
li_updated = pl.add(li_scaled, lij_scaled)

# oi_updated = alpha * oi + beta * oi_new: updated attention output
alpha_dn = pl.reshape(alpha, [16, 1])  # Reshape [1,16] -> [16,1] DN for row_expand_mul
oi_scaled = pl.row_expand_mul(oi_tile, alpha_dn)
beta_dn = pl.reshape(beta, [16, 1])  # Reshape [1,16] -> [16,1] DN for row_expand_mul
oi_new_scaled = pl.row_expand_mul(oi_new_tile, beta_dn)
oi_updated = pl.add(oi_scaled, oi_new_scaled)

mi_new_dn = pl.reshape(mi_new, [16, 1])  # Reshape back to DN [16,1] for store
li_updated_dn = pl.reshape(li_updated, [16, 1])  # Reshape back to DN [16,1] for store

mi_out = pl.store(mi_new_dn, [0, 0], [16, 1], mi)
li_out = pl.store(li_updated_dn, [0, 0], [16, 1], li)

if is_last:
    # Last block: normalize dst = oi / li
    dst_tile = pl.row_expand_div(oi_updated, li_updated_dn)
    dst_out = pl.store(dst_tile, [0, 0], [16, 16], dst)
    oi_out = pl.store(oi_updated, [0, 0], [16, 16], oi)
else:
    # Middle block: no normalize
    oi_out = pl.store(oi_updated, [0, 0], [16, 16], oi)
    zero_tile = pl.create_tile([16, 16], dtype=pl.FP32, target_memory=1)
    dst_out = pl.store(zero_tile, [0, 0], [16, 16], dst)

return mi_out, li_out, oi_out, dst_out

@pl.function(type=pl.FunctionType.Orchestration)
def orchestrator(
    self,
    mij: pl.Tensor[[16, 1], pl.FP32],
    lij: pl.Tensor[[16, 1], pl.FP32],
    oi_new: pl.Tensor[[16, 16], pl.FP32],
    config: pl.Tensor[[2], pl.INT64],
    mi: pl.Tensor[[16, 1], pl.FP32],
    li: pl.Tensor[[16, 1], pl.FP32],
    oi: pl.Tensor[[16, 16], pl.FP32],
) -> tuple[pl.Tensor[[16, 1], pl.FP32], pl.Tensor[[16, 1], pl.FP32],
           pl.Tensor[[16, 16], pl.FP32], pl.Tensor[[16, 16], pl.FP32]]:
    # Read is_first and is_last from config tensor
    is_first: pl.Scalar[pl.INT64] = pl.tensor.read(config, [0])
    is_last: pl.Scalar[pl.INT64] = pl.tensor.read(config, [1])
    mi, li, oi, dst = self.online_update(mij, lij, oi_new, mi, li, oi, is_first, is_last)
    return mi, li, oi, dst

return OnlineUpdateProgram

def compute_expected(self, tensors, params=None):
    """Compute expected output for the last-block case (is_first=0, is_last=1).

    config[0]=is_first and config[1]=is_last are read but not used for
    branching; the computation is fixed to the last-block path matching
    the default config [0, 1].
    """
⚠️ Potential issue | 🟠 Major

OnlineUpdate expected results ignore is_first/is_last; only last-block path is tested.
compute_expected is hardcoded to the last-block path and the test config is fixed to [0, 1]. This doesn’t exercise the other three runtime branches despite the stated goal. Consider parameterizing config in the test and branching compute_expected on is_first/is_last so each path has a golden result.

Also applies to: 451-455

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 391-391: Unused method argument: params

(ARG002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/codegen/test_paged_attention.py` around lines 258 - 397: the test
currently fixes config to [0, 1], and compute_expected only implements the
last-block path. Parameterize the config TensorSpec so the harness can run all
four combinations, and make compute_expected read the runtime booleans
(is_first/is_last) and branch to the correct golden outputs for each case
(single-block first+last, first-but-not-last, not-first+last, middle), matching
the logic in online_update and using the same intermediate names (mi, li, oi,
dst) so expected outputs align with the OnlineUpdateProgram behavior.
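A branch-aware golden function along these lines could look like the following NumPy sketch. The `is_last` handling matches the kernel excerpt (normalize on the last block, zeroed `dst` otherwise); the `is_first` branch resetting the running state from the block stats is an assumption, since that branch is not shown here:

```python
import numpy as np

def expected_online_update(mij, lij, oi_new, mi, li, oi, is_first, is_last):
    """Golden reference for one online-softmax update step (sketch)."""
    if is_first:
        # Assumed: first block (re)initializes the running state from block stats.
        mi_new, li_new, oi_acc = mij, lij, oi_new
    else:
        mi_new = np.maximum(mi, mij)
        alpha = np.exp(mi - mi_new)   # rescale accumulated output
        beta = np.exp(mij - mi_new)   # rescale current block output
        li_new = alpha * li + beta * lij
        oi_acc = alpha * oi + beta * oi_new
    # Last block normalizes; middle blocks leave dst zeroed (matches zero_tile store).
    dst = oi_acc / li_new if is_last else np.zeros_like(oi_acc)
    return mi_new, li_new, oi_acc, dst

mij = np.full((16, 1), 0.5, np.float32)
lij = np.full((16, 1), 1.5, np.float32)
oi_new = np.full((16, 16), 0.3, np.float32)
mi = np.full((16, 1), 0.4, np.float32)
li = np.full((16, 1), 1.0, np.float32)   # placeholder running state
oi = np.full((16, 16), 0.2, np.float32)  # placeholder running state

# Golden outputs for all four runtime branches.
golden = {(f, l): expected_online_update(mij, lij, oi_new, mi, li, oi, f, l)
          for f in (0, 1) for l in (0, 1)}
```

Each `(is_first, is_last)` pair could then drive one parameterized test case whose config tensor is `[f, l]`.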

Comment on lines +271 to +277
def define_tensors(self) -> List[TensorSpec]:
    return [
        TensorSpec("mij", [self.num_heads, 1], DataType.FP32, init_value=0.5),
        TensorSpec("lij", [self.num_heads, 1], DataType.FP32, init_value=1.5),
        TensorSpec("oi_new", [self.num_heads, self.head_dim], DataType.FP32, init_value=0.3),
        TensorSpec("config", [2], DataType.INT64, init_value=[0, 1]),  # [is_first=0, is_last=1]
        TensorSpec("mi", [self.num_heads, 1], DataType.FP32, init_value=0.4, is_output=True),
⚠️ Potential issue | 🟠 Major

TensorSpec.init_value list will raise during tensor creation.
TensorSpec.create_array() uses torch.full for non-tensor/callable values; passing a Python list ([0, 1]) as the fill value will error. Use a tensor or a callable instead.

✅ Suggested fix
-            TensorSpec("config", [2], DataType.INT64, init_value=[0, 1]),  # [is_first=0, is_last=1]
+            TensorSpec("config", [2], DataType.INT64, init_value=lambda shape: [0, 1]),  # [is_first=0, is_last=1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/codegen/test_paged_attention.py` around lines 271 - 277: the
"config" TensorSpec in define_tensors uses init_value=[0, 1], which is forwarded
to TensorSpec.create_array and causes torch.full(fill_value=list) to raise.
Change the init_value for "config" to a tensor or a callable that returns one
(e.g., torch.tensor([0, 1]) or a lambda producing it) so create_array receives a
proper fill value; update the define_tensors entry for "config" (and any similar
specs) rather than changing the torch.full usage inside TensorSpec.create_array.
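The scalar-versus-per-element distinction the prompt draws can be illustrated with a hypothetical, NumPy-based mirror of `TensorSpec.create_array` (the real helper is not shown in this excerpt, so the dispatch order is an assumption):

```python
import numpy as np

def create_array(shape, init_value, dtype=np.float32):
    """Hypothetical sketch of a TensorSpec-style initializer dispatch."""
    if callable(init_value):
        # Callable path: the callable supplies per-element data for the shape.
        return np.asarray(init_value(shape), dtype=dtype).reshape(shape)
    if isinstance(init_value, (list, tuple, np.ndarray)):
        # Explicit per-element data: must match the requested shape.
        return np.asarray(init_value, dtype=dtype).reshape(shape)
    # Scalar fill (the torch.full-style path, which rejects list fill values).
    return np.full(shape, init_value, dtype=dtype)

# Per-element config via a callable, as the suggested fix proposes.
config = create_array([2], lambda shape: [0, 1], dtype=np.int64)  # [is_first=0, is_last=1]
```

The point is simply that `[0, 1]` is per-element data, not a scalar fill, so it needs a code path that treats it as such.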

Comment on lines +383 to +386
# Read is_first and is_last from config tensor
is_first: pl.Scalar[pl.INT64] = pl.tensor.read(config, [0])
is_last: pl.Scalar[pl.INT64] = pl.tensor.read(config, [1])
mi, li, oi, dst = self.online_update(mij, lij, oi_new, mi, li, oi, is_first, is_last)
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the test file to understand the context
cd tests && find . -name "test_paged_attention.py" -type f

Repository: hw-native-sys/pypto



🏁 Script executed:

# Search for the online_update method definition
rg -n "def online_update" --type py -C 5



🏁 Script executed:

# Check the signature and how it's defined in the class
rg -n "online_update\(" --type py -B 2 -A 5 | head -100



🏁 Script executed:

# Look for BOOL type usage in the codebase
rg -n "pl\.BOOL" --type py -C 2 | head -60



🏁 Script executed:

# Get more context around the problematic lines
sed -n '375,395p' tests/st/codegen/test_paged_attention.py



🏁 Script executed:

# Also check if there are bitcast/cast calls nearby
rg -n "bitcast|cast" tests/st/codegen/test_paged_attention.py -C 3 | head -80



🏁 Script executed:

# Check if there's any documentation or comments about implicit type conversion
rg -n "implicit.*cast|auto.*cast|INT64.*BOOL|BOOL.*INT64" tests --type py



🏁 Script executed:

# Look at the test class docstring more completely to understand the intent
sed -n '250,270p' tests/st/codegen/test_paged_attention.py



Add explicit cast for is_first/is_last from INT64 to BOOL before passing to online_update(). The InCore function signature requires pl.Scalar[pl.BOOL] (lines 297–298), but these values are read as pl.Scalar[pl.INT64] (lines 384–385) and passed without explicit conversion. The test documentation itself acknowledges this pattern, indicating the developers are aware of the mismatch but have not implemented an explicit fix. Confirm whether the DSL handles this conversion implicitly or add an explicit bitcast/cast operation if required.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/codegen/test_paged_attention.py` around lines 383 - 386: the
values read into is_first and is_last are pl.Scalar[pl.INT64], but
online_update(...) expects pl.Scalar[pl.BOOL]. Add an explicit cast/bitcast of
the results of pl.tensor.read(config, [0]) and pl.tensor.read(config, [1]) to
pl.Scalar[pl.BOOL] before passing them into self.online_update (or confirm the
DSL auto-converts and add an explicit no-op cast if required) so the types match
the online_update function signature.
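Whatever cast op the DSL exposes (its exact API is not named here), the intended decoding is just integer flags to booleans; a plain-Python sketch of the host-side equivalent:

```python
import numpy as np

config = np.array([0, 1], dtype=np.int64)  # [is_first, is_last]

# Explicit INT64 -> bool conversion before any branching, mirroring the
# cast the review asks for in the orchestrator.
is_first = bool(config[0] != 0)
is_last = bool(config[1] != 0)
```

Comparing against zero rather than relying on truthiness keeps the conversion explicit, which is the same property the review wants from the DSL-side cast.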
