feat(op): add block-level bitwise (and/or/xor/shl/shr/not), arithmeti… #260

wangqin1723-max wants to merge 1 commit into hw-native-sys:main
Conversation
📝 Walkthrough

This PR expands the block-level operations API by adding 30+ new element-wise, arithmetic, bitwise, and matrix/vector operations across the IR and language layers. Changes span IR registration (C++), Python IR bindings, language-level Tile wrappers, module exports, parser updates, and comprehensive test coverage.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (1 warning)
Summary of Changes

Hello @wangqin1723-max, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the set of available block-level operations within the system. It introduces a wide array of new functionalities ranging from bitwise and arithmetic operations to specialized matrix multiplication and activation functions. These additions enhance the expressiveness and capability of the intermediate representation and Python DSL, enabling more complex computations directly at the block level. The changes are thoroughly integrated with updated parser support and comprehensive unit tests to ensure correctness and stability.
Code Review
This pull request introduces a significant number of new block-level operations across various categories, including element-wise remainder (rem, rems), bitwise shifts (shl, shls, shr, shrs), bitwise logical operations (and_, ands, or_, ors, xor, xors, not_), activation functions (prelu, lrelu), ternary arithmetic operations (addc, subc, addsc, subsc), selection operations (sel, sels), and matrix operations (matmul_bias, gemv, gemv_acc, gemv_bias). Additionally, a row_expand operation for broadcasting and maxs/mins for scalar comparisons were added.

These operations are implemented in the Python IR layer (python/pypto/ir/op/block_ops.py), exposed through the pypto.language API, and registered in the C++ IR backend with corresponding type deduction logic. Comprehensive unit tests for all new operations have been added to tests/ut/ir/operators/test_block_ops.py to ensure their functionality and correct IR generation.

Review comments highlight an issue where _normalize_expr in scalar bitwise and remainder operations incorrectly uses DataType.FP32 for int_dtype instead of DataType.INT32, which needs correction. Another comment points out a potentially confusing error message in DeduceBlockOpXorBinaryType due to its generic use for operations beyond XOR, suggesting a more general error message or a dedicated deduction function. Finally, the reviewer noted the high level of code duplication in the new unit tests and suggested refactoring them using parameterized tests or helper methods for better maintainability.
```python
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
    if not isinstance(rhs, Expr)
    else rhs
)
```
The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise shift operation, the scalar shift amount should be an integer type, not a floating-point type. Using FP32 could lead to incorrect behavior or type errors during compilation. Please use an integer DataType like DataType.INT32.
This issue also exists for shrs, ands, ors, and xors.
```diff
 rhs_expr = (
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
     if not isinstance(rhs, Expr)
     else rhs
 )
```
```python
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
    if not isinstance(rhs, Expr)
    else rhs
)
```
The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise shift operation, the scalar shift amount should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.
```diff
 rhs_expr = (
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
     if not isinstance(rhs, Expr)
     else rhs
 )
```
```python
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
    if not isinstance(rhs, Expr)
    else rhs
)
```
The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise AND operation, the scalar operand should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.
```diff
 rhs_expr = (
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
     if not isinstance(rhs, Expr)
     else rhs
 )
```
```python
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
    if not isinstance(rhs, Expr)
    else rhs
)
```
The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise OR operation, the scalar operand should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.
```diff
 rhs_expr = (
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
     if not isinstance(rhs, Expr)
     else rhs
 )
```
```python
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
    if not isinstance(rhs, Expr)
    else rhs
)
```
The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise XOR operation, the scalar operand should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.
```diff
 rhs_expr = (
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
     if not isinstance(rhs, Expr)
     else rhs
 )
```
```cpp
CHECK(As<TileType>(args[2]->GetType()))
    << "The operator " << op_name << " requires third argument (tmp) to be a TileType, but got "
    << args[2]->GetType()->TypeName();
```
The error message in this CHECK is specific to xor-like operations that use a temporary buffer (tmp), but this function DeduceBlockOpXorBinaryType is also used for addc and subc, where the third argument is a tile operand, not a temporary buffer. This can lead to confusing error messages.
Consider renaming this function to something more generic like DeduceBlockOpTernaryType and making the error message more general, or creating a separate deduction function for addc and subc.
A quick fix would be to generalize the error message.
```diff
 CHECK(As<TileType>(args[2]->GetType()))
-    << "The operator " << op_name << " requires third argument (tmp) to be a TileType, but got "
+    << "The operator " << op_name << " requires third argument to be a TileType, but got "
     << args[2]->GetType()->TypeName();
```
src/ir/op/block_ops/broadcast.cpp (outdated)

```cpp
// ============================================================================
// Registration Function for Block Row Broadcast Operations
// ============================================================================
```
```python
class TestBlockBitwiseArithmeticOps:
    """Test suite for newly added block-level bitwise and arithmetic ops (rem, and, or, xor)."""
```
This test class, along with others in this file, contains a lot of duplicated code for testing each new operator. This makes the test suite harder to maintain. Consider refactoring these tests to use a parameterized test function or a helper method to reduce boilerplate. A single test function could iterate over a list of operator configurations (op name, arguments, expected IR string).
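One way to act on this suggestion is a table-driven test. The sketch below uses plain `operator`-module functions as stand-ins for the project's block-op wrappers, so the case table and expected values here are illustrative only; in the real suite the callable would build IR and the expected value would be an IR string:

```python
import operator

# Hypothetical case table: (op name, callable, args, expected result).
CASES = [
    ("and", operator.and_, (0b1100, 0b1010), 0b1000),
    ("or", operator.or_, (0b1100, 0b1010), 0b1110),
    ("xor", operator.xor, (0b1100, 0b1010), 0b0110),
    ("shl", operator.lshift, (1, 4), 16),
    ("shr", operator.rshift, (16, 2), 4),
]

def check_all(cases):
    """Run every case; under pytest this loop would instead become
    `@pytest.mark.parametrize("name,fn,args,expected", CASES)` on a
    single test function, replacing one near-identical method per op."""
    for name, fn, args, expected in cases:
        result = fn(*args)
        assert result == expected, f"{name}: {result} != {expected}"
    return len(cases)

print(check_all(CASES))  # 5
```

Adding an operator then means adding one row to the table rather than copying a whole test method.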
Actionable comments posted: 7
🧹 Nitpick comments (2)

tests/ut/ir/operators/test_block_ops.py (1)

382-383: Isolate backend overrides to avoid cross-test state leakage. These lines mutate global backend state to BackendType.CCE but do not restore the prior state. This can make later tests order-dependent.

💡 Suggested pytest fixture pattern

```diff
+@pytest.fixture
+def cce_backend():
+    backend.reset_for_testing()
+    backend.set_backend_type(BackendType.CCE)
+    yield
+    backend.reset_for_testing()
+
 class TestBlockReductionOps:
@@
-    def test_block_row_max(self):
+    def test_block_row_max(self, cce_backend):
@@
-        backend.reset_for_testing()
-        backend.set_backend_type(BackendType.CCE)
@@
-    def test_block_row_sum(self):
+    def test_block_row_sum(self, cce_backend):
@@
-        backend.reset_for_testing()
-        backend.set_backend_type(BackendType.CCE)
```

Also applies to: 408-409

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ut/ir/operators/test_block_ops.py` around lines 382-383, The test directly mutates global backend state by calling backend.reset_for_testing() and backend.set_backend_type(BackendType.CCE) without restoring it; change this to use a pytest fixture or try/finally that captures the current backend type (e.g., prev = backend.get_backend_type() or equivalent), calls backend.reset_for_testing() and backend.set_backend_type(BackendType.CCE) for the test, then restores the original value with backend.set_backend_type(prev) in teardown; apply the same pattern to the other occurrence that sets BackendType.CCE so tests don't leak state across runs.

python/pypto/language/op/block_ops.py (1)

1030-1186: Tighten bitwise/shift scalar APIs to integer-only at the DSL layer. Line 1030, Line 1063, Line 1097, Line 1131, and Line 1168 currently accept float for bitwise/shift scalar variants. Restricting these to integer-like inputs gives earlier and clearer failures.

♻️ Suggested API tightening pattern

```diff
-def ands(lhs: Tile, rhs: int | float | Expr | Scalar) -> Tile:
+def ands(lhs: Tile, rhs: int | Expr | Scalar) -> Tile:
@@
-    rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs
+    if isinstance(rhs, float):
+        raise TypeError("ands rhs must be an integer scalar")
+    rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs
```

Apply the same pattern to ors, xors, shls, and shrs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/pypto/language/op/block_ops.py` around lines 1030-1186, The bitwise/shift scalar helper functions (ands, ors, xors, shls, shrs) currently allow floats in their type signatures; tighten them to integer-only at the DSL layer by removing float from the union types (change parameters from int | float | Expr | Scalar to int | Expr | Scalar), update the corresponding docstring "Scalar value" / "Scalar shift amount" notes as needed, and ensure any calls to rhs.unwrap() logic remains unchanged; modify the function signatures for ands, ors, xors, shls, and shrs accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@python/pypto/ir/op/block_ops.py`:
- Around line 425-611: The scalar normalization in shls, shrs, ands, ors, xors
is incorrectly passing int_dtype=DataType.FP32 which will coerce integer
literals to FP32; update the _normalize_expr calls inside the functions shls,
shrs, ands, ors, and xors to use int_dtype=DataType.INT32 (keep
float_dtype=DataType.FP32 or remove float_dtype if not needed) so scalar
literals remain integer-typed for bitwise/shift intrinsics.
In `@src/ir/op/block_ops/broadcast.cpp`:
- Around line 146-159: The type-deduction lambda for
REGISTER_OP("block.row_expand") currently checks only that the argument is a
TileType but not its rank; update the f_deduce_type lambda (the block using the
tile_type variable) to validate that tile_type->shape_.size() == 2 and emit a
CHECK/failure message that the input must be 2D [M, N] (include the actual rank
in the error string). Return the same
std::make_shared<TileType>(tile_type->shape_, tile_type->dtype_) only after this
rank check passes.
In `@src/ir/op/block_ops/elementwise.cpp`:
- Around line 214-313: The bitwise ops (REGISTER_OP entries for "block.shl",
"block.shls", "block.shr", "block.shrs", "block.and", "block.ands", "block.or",
"block.ors") currently call generic deduction helpers and thus allow float
dtypes; change their .f_deduce_type lambdas to call bitwise-specific deduction
helpers (e.g., replace calls to DeduceBlockOpElementwiseBinaryType and
DeduceBlockOpScalarBinaryType with names like
DeduceBlockOpElementwiseBitwiseBinaryType and
DeduceBlockOpScalarBitwiseBinaryType or an equivalent helper that enforces
integer/bitwise-compatible dtypes), and implement/update those helper functions
to validate/require integer or bitwise-compatible ScalarType/TileType during
type inference. Ensure the op-name string passed (e.g., "block.shl",
"block.ands", etc.) is preserved when invoking the new deducer so error messages
remain clear.
- Around line 391-433: The registered ops block.addc, block.subc, block.addsc
and block.subsc currently call DeduceBlockOpXorBinaryType /
DeduceBlockOpXorScalarType which ignore the third operand rhs2; create new
ternary deduction helpers (e.g. DeduceBlockOpTernaryBinaryType and
DeduceBlockOpTernaryScalarType) and update the f_deduce_type lambdas for
"block.addc", "block.subc", "block.addsc", and "block.subsc" to call them;
ensure the new helpers include rhs2 in type validation, broadcast compatibility
checks, and dtype promotion logic (use the same validation points as
DeduceBlockOpXor* but include args[2]/rhs2 in shape and dtype resolution and
error messages).
- Around line 338-356: DeduceBlockOpXorScalarType currently never validates the
third argument (args[2]); add a check that args[2] is a TileType and has the
same shape as the first tile: obtain auto other_tile =
As<TileType>(args[2]->GetType()), CHECK(other_tile) with an error message
referencing op_name, then CHECK(other_tile->shape_ == tile_type->shape_) (and
optionally that other_tile->dtype_ is compatible with tile_type->dtype_ or with
result_dtype if required) before returning the resulting TileType; this ensures
the third operand is a tile with a matching shape (and compatible dtype) in
DeduceBlockOpXorScalarType.
In `@src/ir/op/block_ops/matmul.cpp`:
- Around line 174-215: DeduceBlockMatMulBiasType currently ignores bias shape
and dtype when inferring output: after computing output_shape and result_dtype
from lhs_type and rhs_type, validate that bias_type->shape_ matches output_shape
(same rank and same ConstInt dimensions where present) and include
bias_type->dtype_ in the dtype promotion (e.g., promote lhs and rhs then promote
with bias, or call a 3-way PromoteDataTypes if available) before constructing
and returning the TileType; update checks to use bias_type and output_shape
variables and ensure error messages reference op_name and the offending
dimension or dtype.
In `@src/ir/op/block_ops/unary.cpp`:
- Around line 173-180: The registered op block.not currently uses
DeduceBlockUnaryType which allows floating-point tiles; change the
type-deduction lambda to enforce bitwise-compatible dtypes by checking the input
tile's dtype and rejecting non-integer/non-bool types (or call a helper like
DeduceBlockUnaryBitwiseType if available). Specifically, inside the
REGISTER_OP("block.not") f_deduce_type lambda, inspect the first arg's dtype via
the existing type-deduction utilities and return an error/invalid type when the
dtype is floating-point, ensuring only integer or boolean tile dtypes are
accepted; keep the operator name "block.not" and reuse existing deduction
patterns from other bitwise ops for consistency.
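The integer-only validation these prompts describe can be prototyped in a few lines. This toy model uses a made-up `TileTy` dataclass and dtype strings, not the project's C++ `TileType`, to show the shape of a bitwise-specific deduction helper that rejects float operands up front:

```python
from dataclasses import dataclass

INT_DTYPES = {"int8", "int16", "int32", "int64", "uint8", "bool"}

@dataclass
class TileTy:
    """Toy stand-in for the IR's TileType."""
    shape: tuple
    dtype: str

def deduce_bitwise_binary(op_name: str, lhs: TileTy, rhs: TileTy) -> TileTy:
    """Reject non-integer operands during type deduction, then return the
    (unchanged) tile type — the check the review asks the real
    DeduceBlockOpElementwiseBitwiseBinaryType-style helpers to perform."""
    for side, t in (("lhs", lhs), ("rhs", rhs)):
        if t.dtype not in INT_DTYPES:
            raise TypeError(f"{op_name}: {side} must be an integer tile, got {t.dtype}")
    if lhs.shape != rhs.shape:
        raise TypeError(f"{op_name}: shape mismatch {lhs.shape} vs {rhs.shape}")
    return TileTy(lhs.shape, lhs.dtype)

out = deduce_bitwise_binary("block.and", TileTy((8, 16), "int32"), TileTy((8, 16), "int32"))
print(out.dtype, out.shape)
try:
    deduce_bitwise_binary("block.shl", TileTy((8, 16), "fp32"), TileTy((8, 16), "int32"))
except TypeError as e:
    print(e)  # rejected at deduction time, before codegen
```

Failing at deduction time gives the user an error pointing at the op name rather than an opaque codegen failure later.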
---
Nitpick comments:
In `@python/pypto/language/op/block_ops.py`:
- Around line 1030-1186: The bitwise/shift scalar helper functions (ands, ors,
xors, shls, shrs) currently allow floats in their type signatures; tighten them
to integer-only at the DSL layer by removing float from the union types (change
parameters from int | float | Expr | Scalar to int | Expr | Scalar), update the
corresponding docstring "Scalar value" / "Scalar shift amount" notes as needed,
and ensure any calls to rhs.unwrap() logic remains unchanged; modify the
function signatures for ands, ors, xors, shls, and shrs accordingly.
In `@tests/ut/ir/operators/test_block_ops.py`:
- Around line 382-383: The test directly mutates global backend state by calling
backend.reset_for_testing() and backend.set_backend_type(BackendType.CCE)
without restoring it; change this to use a pytest fixture or try/finally that
captures the current backend type (e.g., prev = backend.get_backend_type() or
equivalent), calls backend.reset_for_testing() and
backend.set_backend_type(BackendType.CCE) for the test, then restores the
original value with backend.set_backend_type(prev) in teardown; apply the same
pattern to the other occurrence that sets BackendType.CCE so tests don't leak
state across runs.
ℹ️ Review info
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- python/pypto/ir/op/block_ops.py
- python/pypto/language/__init__.py
- python/pypto/language/op/__init__.py
- python/pypto/language/op/block_ops.py
- python/pypto/language/parser/ast_parser.py
- src/ir/op/block_ops/broadcast.cpp
- src/ir/op/block_ops/elementwise.cpp
- src/ir/op/block_ops/matmul.cpp
- src/ir/op/block_ops/unary.cpp
- tests/ut/ir/operators/test_block_ops.py
python/pypto/ir/op/block_ops.py (outdated)
```python
def shls(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise left shift of tile and scalar.

    Computes lhs << rhs element-wise. Maps to the TSHLS hardware intrinsic.

    Note:
        The scalar shift amount must be zero or positive; negative values are
        not supported by the hardware and will be rejected by codegen.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar shift amount (int/float/Expr with ScalarType); must be >= 0
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise left shift with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.shls", [lhs, rhs_expr], {}, actual_span)


def shr(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise right shift of two tiles.

    Computes lhs >> rhs element-wise. Maps to the TSHR hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise right shift
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.shr", [lhs, rhs], {}, actual_span)


def shrs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise right shift of tile and scalar.

    Computes lhs >> rhs element-wise. Maps to the TSHRS hardware intrinsic.

    Note:
        The scalar shift amount must be zero or positive; negative values are
        not supported by the hardware and will be rejected by codegen.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar shift amount (int/float/Expr with ScalarType); must be >= 0
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise right shift with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.shrs", [lhs, rhs_expr], {}, actual_span)


def and_(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise AND of two tiles.

    Computes lhs & rhs element-wise. Maps to the TAND hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise AND
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.and", [lhs, rhs], {}, actual_span)


def ands(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise AND of tile and scalar.

    Computes lhs & rhs element-wise. Maps to the TANDS hardware intrinsic.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar (int/float/Expr with ScalarType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise AND with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.ands", [lhs, rhs_expr], {}, actual_span)


def or_(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise OR of two tiles.

    Computes lhs | rhs element-wise. Maps to the TOR hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise OR
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.or", [lhs, rhs], {}, actual_span)


def ors(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise OR of tile and scalar.

    Computes lhs | rhs element-wise. Maps to the TORS hardware intrinsic.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar (int/float/Expr with ScalarType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise OR with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.ors", [lhs, rhs_expr], {}, actual_span)


def xor(lhs: Expr, rhs: Expr, tmp: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise XOR of two tiles.

    Computes lhs ^ rhs element-wise. Maps to the TXOR hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        tmp: Temporary tile (TileType) required by the hardware
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise XOR
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.xor", [lhs, rhs, tmp], {}, actual_span)


def xors(lhs: Expr, rhs: int | float | Expr, tmp: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise XOR of tile and scalar.

    Computes lhs ^ rhs element-wise. Maps to the TXORS hardware intrinsic.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar (int/float/Expr with ScalarType)
        tmp: Temporary tile (TileType) required by the hardware
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise XOR with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.xors", [lhs, rhs_expr, tmp], {}, actual_span)
```
Bitwise/shift scalar normalization is using FP32 and can break operand typing.
Line 444, Line 487, Line 526, Line 565, and Line 606 normalize scalar literals with int_dtype=DataType.FP32. For shls/shrs/ands/ors/xors, this can coerce integer literals into FP32 constants before emitting integer bitwise/shift ops.
🐛 Proposed fix pattern
```diff
+def _normalize_int_scalar(rhs: int | float | Expr, span: Span, arg_name: str = "rhs") -> Expr:
+    if isinstance(rhs, float):
+        raise TypeError(f"{arg_name} must be an integer scalar for bitwise/shift ops")
+    if isinstance(rhs, Expr):
+        return rhs
+    return _normalize_expr(rhs, span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
@@
 def shls(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def shrs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def ands(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def ors(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def xors(lhs: Expr, rhs: int | float | Expr, tmp: Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
```

🤖 Prompt for AI Agents
+ rhs_expr = _normalize_int_scalar(rhs, actual_span)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@python/pypto/ir/op/block_ops.py` around lines 425 - 611, The scalar
normalization in shls, shrs, ands, ors, xors is incorrectly passing
int_dtype=DataType.FP32 which will coerce integer literals to FP32; update the
_normalize_expr calls inside the functions shls, shrs, ands, ors, and xors to
use int_dtype=DataType.INT32 (keep float_dtype=DataType.FP32 or remove
float_dtype if not needed) so scalar literals remain integer-typed for
bitwise/shift intrinsics.
```cpp
REGISTER_OP("block.shl")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise left shift of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.shl");
    });

REGISTER_OP("block.shls")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise left shift of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.shls");
    });

REGISTER_OP("block.shr")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise right shift of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.shr");
    });

REGISTER_OP("block.shrs")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise right shift of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.shrs");
    });

REGISTER_OP("block.maxs")
    .set_op_category("BlockOp")
    .set_description("Element-wise maximum of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.maxs");
    });

REGISTER_OP("block.mins")
    .set_op_category("BlockOp")
    .set_description("Element-wise minimum of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.mins");
    });

REGISTER_OP("block.and")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise AND of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.and");
    });

REGISTER_OP("block.ands")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise AND of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.ands");
    });

REGISTER_OP("block.or")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise OR of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
```
| .f_deduce_type([](const std::vector<ExprPtr>& args, | ||
| const std::vector<std::pair<std::string, std::any>>& kwargs) { | ||
| return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.or"); | ||
| }); | ||
|
|
||
| REGISTER_OP("block.ors") | ||
| .set_op_category("BlockOp") | ||
| .set_description("Element-wise bitwise OR of tile and scalar") | ||
| .add_argument("lhs", "Tile (TileType)") | ||
| .add_argument("rhs", "Scalar (ScalarType)") | ||
| .f_deduce_type([](const std::vector<ExprPtr>& args, | ||
| const std::vector<std::pair<std::string, std::any>>& kwargs) { | ||
| return DeduceBlockOpScalarBinaryType(args, kwargs, "block.ors"); | ||
| }); | ||
|
|
Bitwise op registrations should enforce bitwise-compatible dtypes.
block.shl/shls/shr/shrs/and/ands/or/ors currently use generic arithmetic deduction, so float operands are accepted. That allows invalid bitwise IR combinations at type-inference time.
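A minimal sketch of the kind of dtype gate this finding asks for (all names here are illustrative, not the project's actual API): reject float dtypes up front, before falling back to the generic elementwise deduction.

```python
# Hypothetical gate for bitwise ops; FLOAT_DTYPES and check_bitwise_dtype
# are illustrative names, not part of the real codebase.
FLOAT_DTYPES = {"fp16", "bf16", "fp32"}

def check_bitwise_dtype(op_name: str, dtype: str) -> None:
    # Bitwise ops (shl/shr/and/or/xor/not) are only defined on integer bits.
    if dtype in FLOAT_DTYPES:
        raise TypeError(f"{op_name} requires an integer dtype, but got {dtype}")

check_bitwise_dtype("block.shl", "int32")   # accepted
try:
    check_bitwise_dtype("block.and", "fp32")
except TypeError as e:
    print(e)  # block.and requires an integer dtype, but got fp32
```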
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/ir/op/block_ops/elementwise.cpp` around lines 214 - 313, The bitwise ops
(REGISTER_OP entries for "block.shl", "block.shls", "block.shr", "block.shrs",
"block.and", "block.ands", "block.or", "block.ors") currently call generic
deduction helpers and thus allow float dtypes; change their .f_deduce_type
lambdas to call bitwise-specific deduction helpers (e.g., replace calls to
DeduceBlockOpElementwiseBinaryType and DeduceBlockOpScalarBinaryType with names
like DeduceBlockOpElementwiseBitwiseBinaryType and
DeduceBlockOpScalarBitwiseBinaryType or an equivalent helper that enforces
integer/bitwise-compatible dtypes), and implement/update those helper functions
to validate/require integer or bitwise-compatible ScalarType/TileType during
type inference. Ensure the op-name string passed (e.g., "block.shl",
"block.ands", etc.) is preserved when invoking the new deducer so error messages
remain clear.
REGISTER_OP("block.not")
  .set_op_category("BlockOp")
  .set_description("Element-wise bitwise NOT of a tile")
  .add_argument("tile", "Input tile (TileType)")
  .f_deduce_type([](const std::vector<ExprPtr>& args,
                    const std::vector<std::pair<std::string, std::any>>& kwargs) {
    return DeduceBlockUnaryType(args, kwargs, "block.not");
  });
Restrict block.not to bitwise-compatible dtypes.
Line 179 routes block.not through DeduceBlockUnaryType, so float tiles are accepted. For a bitwise NOT op, this should reject non-bitwise dtypes during type deduction.
💡 Proposed fix

+TypePtr DeduceBlockBitwiseUnaryType(const std::vector<ExprPtr>& args,
+                                    const std::vector<std::pair<std::string, std::any>>& kwargs,
+                                    const std::string& op_name) {
+  CHECK(args.size() == 1) << "The operator " << op_name << " requires exactly 1 argument, but got "
+                          << args.size();
+  auto tile_type = As<TileType>(args[0]->GetType());
+  CHECK(tile_type) << "The operator " << op_name << " requires argument to be a TileType, but got "
+                   << args[0]->GetType()->TypeName();
+  CHECK(!tile_type->dtype_.IsFloat())
+      << "The operator " << op_name << " requires bitwise-compatible dtype, but got "
+      << tile_type->dtype_.ToString();
+  return std::make_shared<TileType>(tile_type->shape_, tile_type->dtype_);
+}
+
 REGISTER_OP("block.not")
 @@
   .f_deduce_type([](const std::vector<ExprPtr>& args,
                     const std::vector<std::pair<std::string, std::any>>& kwargs) {
-    return DeduceBlockUnaryType(args, kwargs, "block.not");
+    return DeduceBlockBitwiseUnaryType(args, kwargs, "block.not");
   });
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/ir/op/block_ops/unary.cpp` around lines 173 - 180, The registered op
block.not currently uses DeduceBlockUnaryType which allows floating-point tiles;
change the type-deduction lambda to enforce bitwise-compatible dtypes by
checking the input tile's dtype and rejecting non-integer/non-bool types (or
call a helper like DeduceBlockUnaryBitwiseType if available). Specifically,
inside the REGISTER_OP("block.not") f_deduce_type lambda, inspect the first
arg's dtype via the existing type-deduction utilities and return an
error/invalid type when the dtype is floating-point, ensuring only integer or
boolean tile dtypes are accepted; keep the operator name "block.not" and reuse
existing deduction patterns from other bitwise ops for consistency.
/gemini review
Code Review
This pull request introduces several new block-level operations, including rem, rems, bitwise operations (and_, ands, or_, ors, xor, xors, shl, shls, shr, shrs), maxs, mins, prelu, not_, ternary arithmetic operations (addc, subc, addsc, subsc), lrelu, selection operations (sel, sels), matrix-vector multiplication (gemv, gemv_acc, gemv_bias), and row_expand.

The code review highlights several issues in the type deduction logic for ternary operations (addc, subc, addsc, subsc, prelu), matmul_bias, and rems. Specifically, the type deduction for ternary operations and matmul_bias incorrectly handles the third argument's data type promotion and shape broadcasting. The rems function also incorrectly normalizes integer rhs values to DataType.FP32 instead of DataType.INT32. Additionally, the addsc and subsc operations use an incorrectly named type deduction function (DeduceBlockOpXorScalarType) that enforces an integer type for the scalar argument and does not correctly deduce the output type.
auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_);
CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_);
CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes";
The DeduceBlockOpTernaryType function only promotes the data types and broadcasts the shapes of the first two arguments, ignoring the third argument. For operations like addc(lhs, rhs, rhs2), this is incorrect as the type of the third argument should also be part of the type promotion and shape broadcasting.
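The point can be reproduced with NumPy's promotion and broadcasting rules as a stand-in for the IR's PromoteDataTypes/BroadcastShapes: folding only the first two operands silently drops the third operand's wider dtype.

```python
import numpy as np

# NumPy here is only an illustration of the promotion semantics, not the
# project's actual type system.
lhs  = np.empty((4, 1), dtype=np.float16)
rhs  = np.empty((1, 8), dtype=np.float16)
rhs2 = np.empty((4, 8), dtype=np.float32)

two_way   = np.result_type(lhs, rhs)          # float16 -- misses rhs2
three_way = np.result_type(lhs, rhs, rhs2)    # float32
shape     = np.broadcast_shapes(lhs.shape, rhs.shape, rhs2.shape)
print(two_way, three_way, shape)  # float16 float32 (4, 8)
```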
REGISTER_OP("block.addsc")
  .set_op_category("BlockOp")
  .set_description("Element-wise addition of tile, scalar, and tile (lhs + scalar + rhs2)")
  .add_argument("lhs", "Left-hand side tile (TileType)")
  .add_argument("rhs", "Scalar (ScalarType)")
  .add_argument("rhs2", "Third tile (TileType)")
  .f_deduce_type([](const std::vector<ExprPtr>& args,
                    const std::vector<std::pair<std::string, std::any>>& kwargs) {
    return DeduceBlockOpXorScalarType(args, kwargs, "block.addsc");
  });

REGISTER_OP("block.subsc")
  .set_op_category("BlockOp")
  .set_description("Element-wise subtraction of tile, scalar, and tile (lhs - scalar - rhs2)")
  .add_argument("lhs", "Left-hand side tile (TileType)")
  .add_argument("rhs", "Scalar (ScalarType)")
  .add_argument("rhs2", "Third tile (TileType)")
  .f_deduce_type([](const std::vector<ExprPtr>& args,
                    const std::vector<std::pair<std::string, std::any>>& kwargs) {
    return DeduceBlockOpXorScalarType(args, kwargs, "block.subsc");
  });
The addsc and subsc operations are using DeduceBlockOpXorScalarType for type deduction. This is incorrect for several reasons:
- The function is poorly named for this purpose, leading to confusion.
- It enforces that the scalar argument is an integer (CHECK(scalar_type->dtype_.IsInt())), which is too restrictive for addsc/subsc, which should support floating-point types.
- It does not correctly deduce the output type. It only considers the first argument's type (lhs) and ignores the types of the scalar and the third tile argument (rhs2). The output type should be a promotion of all three inputs.

A new type deduction helper function should be created for (Tile, Scalar, Tile) operations like addsc and subsc.
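A hedged sketch of what such a (Tile, Scalar, Tile) deducer could compute, again with NumPy dtypes standing in for the IR's DataType (the helper name is hypothetical):

```python
import numpy as np

# Hypothetical deducer: promote all three operand dtypes and broadcast the
# two tile shapes. Names and signature are illustrative only.
def deduce_tile_scalar_tile(lhs_dtype, scalar_dtype, rhs2_dtype,
                            lhs_shape, rhs2_shape):
    out_dtype = np.result_type(lhs_dtype, scalar_dtype, rhs2_dtype)
    out_shape = np.broadcast_shapes(lhs_shape, rhs2_shape)
    return out_dtype, out_shape

# int16 tile + float32 scalar + int16 tile -> float32 output
dtype, shape = deduce_tile_scalar_tile(np.int16, np.float32, np.int16,
                                       (4, 8), (4, 1))
print(dtype, shape)  # float32 (4, 8)
```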
"""
actual_span = _get_span_or_capture(span)
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
The rems function normalizes integer rhs values to DataType.FP32. Remainder is typically an integer operation, and this implicit conversion of an integer to a float can lead to precision issues or unexpected behavior. It should be normalized to an integer type, such as DataType.INT32, to preserve the integer nature of the operation when an integer literal is provided.
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
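A quick demonstration of why the normalization dtype matters: once the operand exceeds FP32's 24-bit exact-integer range, a remainder routed through float32 diverges from the exact integer remainder.

```python
import numpy as np

# 2**24 + 1 is the smallest positive integer not representable in float32,
# so routing it through FP32 changes the remainder.
a = 2**24 + 1
int_rem = a % 3                          # exact integer remainder
f32_rem = np.float32(a) % np.float32(3)  # a rounds to 16777216.0 first
print(int_rem, f32_rem)  # 2 1.0
```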
src/ir/op/block_ops/matmul.cpp
Outdated
auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
The type deduction for matmul_bias only promotes the data types of lhs and rhs, ignoring the bias argument's type. For an operation like C = lhs @ rhs + bias, the final result type should be the promotion of all three inputs: promote(promote(lhs, rhs), bias).
-auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
-CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
+auto matmul_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
+CHECK(matmul_dtype) << "The operator " << op_name << " requires compatible data types for lhs and rhs";
+auto result_dtype = PromoteDataTypes(*matmul_dtype, bias_type->dtype_);
+CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types for matmul result and bias";
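The chained-promotion rule can be checked numerically with NumPy as a stand-in for the IR's type system: the bias dtype must participate, or the deduced type understates what the computation actually produces.

```python
import numpy as np

# C = lhs @ rhs + bias: a float32 bias widens the float16 matmul result.
lhs  = np.ones((2, 3), dtype=np.float16)
rhs  = np.ones((3, 4), dtype=np.float16)
bias = np.ones((4,),   dtype=np.float32)

matmul_dtype = np.result_type(lhs, rhs)            # float16
result_dtype = np.result_type(matmul_dtype, bias)  # float32
out = lhs @ rhs + bias
print(matmul_dtype, result_dtype, out.dtype)  # float16 float32 float32
```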
/gemini review
Code Review
The pull request introduces a comprehensive set of new block-level operations, including bitwise operations (AND, OR, XOR, SHL, SHR, NOT), arithmetic remainder (REM), activation functions (PReLU, LReLU), selection operations (SEL, SELS), and matrix multiplication variants (MATMUL_BIAS, GEMV). The changes span across the IR definition, Python DSL wrappers, and unit tests. The implementation appears robust, with appropriate type deduction logic in C++ and consistent Python bindings. The new test cases cover the basic functionality of each added operation, ensuring correct IR generation.
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
    if not isinstance(rhs, Expr)
    else rhs
)
The rems function normalizes integer scalar inputs to DataType.FP32 via _normalize_expr. This might lead to unexpected floating-point remainder behavior if the user intends integer remainder for integer inputs, especially since other bitwise scalar operations (shls, shrs, ands, ors, xors, addsc, subsc) normalize integer scalars to DataType.INT32. If the underlying TREMS hardware intrinsic supports both integer and floating-point remainder, or if integer remainder is expected for integer inputs, this implicit conversion could be problematic. Please clarify if TREMS is exclusively floating-point, or if integer remainder is also supported and intended for integer scalar inputs.
-    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+    _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
     if not isinstance(rhs, Expr)
     else rhs
 )
/gemini review
Code Review
This pull request introduces a significant number of new block-level operations, including bitwise, arithmetic, activation, and matrix multiplication variants. The implementation spans the Python DSL, IR bindings, and C++ operator registration, following the existing architecture. The changes are well-structured and include corresponding tests for the new operations. My review identified a high-severity issue in the type deduction logic for matmul_bias and gemv_bias which could lead to incorrect type promotion, and a minor style issue with a duplicated comment. Overall, this is a great addition to the operator library.
src/ir/op/block_ops/matmul.cpp
Outdated
auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";

std::vector<ExprPtr> output_shape = {lhs_shape[0], rhs_shape[1]};
return std::make_shared<TileType>(output_shape, *result_dtype);
The type deduction for matmul_bias (and gemv_bias) doesn't seem to consider the bias tensor's data type for promotion. The operation is C = (lhs @ rhs) + bias, so the final result type should be the promotion of (lhs @ rhs) and bias. Currently, it only promotes lhs and rhs. This could lead to incorrect type inference and potential precision loss if the bias has a higher precision type.
-auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
-CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
-std::vector<ExprPtr> output_shape = {lhs_shape[0], rhs_shape[1]};
-return std::make_shared<TileType>(output_shape, *result_dtype);
+auto matmul_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
+CHECK(matmul_dtype) << "The operator " << op_name << " requires compatible lhs and rhs data types";
+auto result_dtype = PromoteDataTypes(*matmul_dtype, bias_type->dtype_);
+CHECK(result_dtype) << "The operator " << op_name << " requires a compatible bias data type";
+return std::make_shared<TileType>(std::vector<ExprPtr>{lhs_shape[0], rhs_shape[1]}, *result_dtype);
src/ir/op/block_ops/broadcast.cpp
Outdated
// ============================================================================
// Registration Function for Block Row Broadcast Operations
// ============================================================================
feat(op): add block-level bitwise (and/or/xor/shl/shr/not), arithmetic (rem), activation (prelu/lrelu), select, matmul variants (matmul_bias/gemv), and broadcast (row_expand) ops
force-pushed from 38af807 to 59680cf (Compare)
"""
actual_span = _get_span_or_capture(span)
rhs_expr = (
    _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
int_dtype should be INT32?
"matmul",
"matmul_acc",
"matmul_bias",
"gemv",
return Tile(expr=call_expr)

def and_(lhs: Tile, rhs: Tile) -> Tile:
    return DeduceBlockOpScalarBinaryType(args, kwargs, "block.mins");
});

REGISTER_OP("block.and")
auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_);
CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_);
CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes";
}

// All three tiles are real inputs (addc, subc): promote dtype and broadcast shape across all three.
TypePtr DeduceBlockOpTriTileType(const std::vector<ExprPtr>& args,
Isn't this the same as DeduceBlockOpTernaryType?
    return DeduceBlockMatMulAccType(args, kwargs, "block.matmul_acc");
});

REGISTER_OP("block.matmul_bias")