From 498fd4d4f12692366df8414ca7ba2c81ea264f54 Mon Sep 17 00:00:00 2001 From: wangqin1723-max Date: Wed, 25 Feb 2026 16:41:42 +0800 Subject: [PATCH] feat(op): add block-level bitwise (and/or/xor/shl/shr/not), arithmetic (rem), activation (prelu/lrelu), select, matmul variants (matmul_bias/gemv), and broadcast (row_expand) ops --- python/pypto/ir/op/block_ops.py | 548 ++++++++++++++- python/pypto/language/__init__.py | 56 ++ python/pypto/language/op/__init__.py | 56 ++ python/pypto/language/op/block_ops.py | 500 ++++++++++++++ python/pypto/language/parser/ast_parser.py | 5 + src/ir/op/block_ops/broadcast.cpp | 18 + src/ir/op/block_ops/elementwise.cpp | 486 +++++++++++++- src/ir/op/block_ops/matmul.cpp | 107 +++ src/ir/op/block_ops/unary.cpp | 18 + tests/ut/ir/operators/test_block_ops.py | 743 ++++++++++++++++++++- 10 files changed, 2530 insertions(+), 7 deletions(-) diff --git a/python/pypto/ir/op/block_ops.py b/python/pypto/ir/op/block_ops.py index a4704a65..54ca3b64 100644 --- a/python/pypto/ir/op/block_ops.py +++ b/python/pypto/ir/op/block_ops.py @@ -366,6 +366,405 @@ def sub(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: return _ir_core.create_op_call("block.sub", [lhs, rhs], {}, actual_span) +def rem(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """Element-wise remainder (modulo) of two tiles. + + Computes lhs % rhs element-wise. Maps to the TREM hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise remainder + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.rem", [lhs, rhs], {}, actual_span) + + +def rems(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: + """Element-wise remainder (modulo) of tile and scalar. + + Computes lhs % rhs element-wise. Maps to the TREMS hardware intrinsic. 
+ + Args: + lhs: Tile (TileType) + rhs: Scalar (int/float/Expr with ScalarType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise remainder with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) + if not isinstance(rhs, Expr) + else rhs + ) + return _ir_core.create_op_call("block.rems", [lhs, rhs_expr], {}, actual_span) + + +def shl(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """Element-wise bitwise left shift of two tiles. + + Computes lhs << rhs element-wise. Maps to the TSHL hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise left shift + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.shl", [lhs, rhs], {}, actual_span) + + +def shls(lhs: Expr, rhs: int | Expr, span: Span | None = None) -> Call: + """Element-wise bitwise left shift of tile and scalar. + + Computes lhs << rhs element-wise. Maps to the TSHLS hardware intrinsic. + + Note: + The scalar shift amount must be zero or positive; negative values are + not supported by the hardware and will be rejected by codegen. 
+ + Args: + lhs: Tile (TileType) + rhs: Scalar shift amount (int/Expr with INT32 ScalarType); must be >= 0 + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise left shift with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32) if not isinstance(rhs, Expr) else rhs + ) + return _ir_core.create_op_call("block.shls", [lhs, rhs_expr], {}, actual_span) + + +def shr(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """Element-wise bitwise right shift of two tiles. + + Computes lhs >> rhs element-wise. Maps to the TSHR hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise right shift + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.shr", [lhs, rhs], {}, actual_span) + + +def shrs(lhs: Expr, rhs: int | Expr, span: Span | None = None) -> Call: + """Element-wise bitwise right shift of tile and scalar. + + Computes lhs >> rhs element-wise. Maps to the TSHRS hardware intrinsic. + + Note: + The scalar shift amount must be zero or positive; negative values are + not supported by the hardware and will be rejected by codegen. 
+ + Args: + lhs: Tile (TileType) + rhs: Scalar shift amount (int/Expr with INT32 ScalarType); must be >= 0 + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise right shift with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32) if not isinstance(rhs, Expr) else rhs + ) + return _ir_core.create_op_call("block.shrs", [lhs, rhs_expr], {}, actual_span) + + +def and_(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """Element-wise bitwise AND of two tiles. + + Computes lhs & rhs element-wise. Maps to the TAND hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise AND + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.and", [lhs, rhs], {}, actual_span) + + +def ands(lhs: Expr, rhs: int | Expr, span: Span | None = None) -> Call: + """Element-wise bitwise AND of tile and scalar. + + Computes lhs & rhs element-wise. Maps to the TANDS hardware intrinsic. + + Args: + lhs: Tile (TileType) + rhs: Scalar (int/Expr with INT32 ScalarType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise AND with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32) if not isinstance(rhs, Expr) else rhs + ) + return _ir_core.create_op_call("block.ands", [lhs, rhs_expr], {}, actual_span) + + +def or_(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """Element-wise bitwise OR of two tiles. + + Computes lhs | rhs element-wise. Maps to the TOR hardware intrinsic. 
+ + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise OR + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.or", [lhs, rhs], {}, actual_span) + + +def ors(lhs: Expr, rhs: int | Expr, span: Span | None = None) -> Call: + """Element-wise bitwise OR of tile and scalar. + + Computes lhs | rhs element-wise. Maps to the TORS hardware intrinsic. + + Args: + lhs: Tile (TileType) + rhs: Scalar (int/Expr with INT32 ScalarType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise OR with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32) if not isinstance(rhs, Expr) else rhs + ) + return _ir_core.create_op_call("block.ors", [lhs, rhs_expr], {}, actual_span) + + +def xor(lhs: Expr, rhs: Expr, tmp: Expr, span: Span | None = None) -> Call: + """Element-wise bitwise XOR of two tiles. + + Computes lhs ^ rhs element-wise. Maps to the TXOR hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + tmp: Temporary tile (TileType) required by the hardware + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise XOR + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.xor", [lhs, rhs, tmp], {}, actual_span) + + +def xors(lhs: Expr, rhs: int | Expr, tmp: Expr, span: Span | None = None) -> Call: + """Element-wise bitwise XOR of tile and scalar. + + Computes lhs ^ rhs element-wise. Maps to the TXORS hardware intrinsic. 
+ + Args: + lhs: Tile (TileType) + rhs: Scalar (int/Expr with INT32 ScalarType) + tmp: Temporary tile (TileType) required by the hardware + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise XOR with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32) if not isinstance(rhs, Expr) else rhs + ) + return _ir_core.create_op_call("block.xors", [lhs, rhs_expr, tmp], {}, actual_span) + + +def prelu(tile: Expr, slope: Expr, tmp: Expr, span: Span | None = None) -> Call: + """Element-wise parametric ReLU of a tile. + + Computes prelu(tile, slope) element-wise. Maps to the TPRELU hardware intrinsic. + + Args: + tile: Input tile (TileType) + slope: Slope tile (TileType) used for negative values + tmp: Temporary tile (TileType) required by the hardware + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise parametric ReLU + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.prelu", [tile, slope, tmp], {}, actual_span) + + +def addc(lhs: Expr, rhs: Expr, rhs2: Expr, span: Span | None = None) -> Call: + """Element-wise addition of three tiles. + + Computes lhs + rhs + rhs2 element-wise. Maps to the TADDC hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + rhs2: Third tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise ternary addition + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.addc", [lhs, rhs, rhs2], {}, actual_span) + + +def subc(lhs: Expr, rhs: Expr, rhs2: Expr, span: Span | None = None) -> Call: + """Element-wise subtraction of three tiles. + + Computes lhs - rhs - rhs2 element-wise. Maps to the TSUBC hardware intrinsic. 
+ + Args: + lhs: Left-hand side tile (TileType) + rhs: Right-hand side tile (TileType) + rhs2: Third tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise ternary subtraction + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.subc", [lhs, rhs, rhs2], {}, actual_span) + + +def addsc(lhs: Expr, rhs: int | float | Expr, rhs2: Expr, span: Span | None = None) -> Call: + """Element-wise addition of tile, scalar, and tile. + + Computes lhs + rhs + rhs2 element-wise. Maps to the TADDSC hardware intrinsic. + + Args: + lhs: Left-hand side tile (TileType) + rhs: Scalar (int/float/Expr with ScalarType) + rhs2: Third tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise tile-scalar-tile addition + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) + if not isinstance(rhs, Expr) + else rhs + ) + return _ir_core.create_op_call("block.addsc", [lhs, rhs_expr, rhs2], {}, actual_span) + + +def subsc(lhs: Expr, rhs: int | float | Expr, rhs2: Expr, span: Span | None = None) -> Call: + """Element-wise subtraction of tile, scalar, and tile. + + Computes lhs - rhs - rhs2 element-wise. Maps to the TSUBSC hardware intrinsic. 
+ + Args: + lhs: Left-hand side tile (TileType) + rhs: Scalar (int/float/Expr with ScalarType) + rhs2: Third tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise tile-scalar-tile subtraction + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) + if not isinstance(rhs, Expr) + else rhs + ) + return _ir_core.create_op_call("block.subsc", [lhs, rhs_expr, rhs2], {}, actual_span) + + +def lrelu(tile: Expr, slope: int | float | Expr, span: Span | None = None) -> Call: + """Element-wise leaky ReLU of a tile with scalar slope. + + Computes max(x, slope * x) element-wise. Maps to the TLRELU hardware intrinsic. + + Args: + tile: Input tile (TileType) + slope: Scalar slope for negative values (int/float/Expr with ScalarType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise leaky ReLU + """ + actual_span = _get_span_or_capture(span) + slope_expr = ( + _normalize_expr(slope, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32) + if not isinstance(slope, Expr) + else slope + ) + return _ir_core.create_op_call("block.lrelu", [tile, slope_expr], {}, actual_span) + + +def sel(mask: Expr, lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """Per-element selection between two tiles using a predicate mask tile. + + For each element (i, j): dst[i,j] = lhs[i,j] if mask[i,j] is true, else rhs[i,j]. + Maps to the TSEL hardware intrinsic. The mask encoding is target-defined. 
+ + Args: + mask: Predicate mask tile (TileType); encoding is target-defined + lhs: Source tile 0, selected where mask is true (TileType) + rhs: Source tile 1, selected where mask is false (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for per-element tile selection + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.sel", [mask, lhs, rhs], {}, actual_span) + + +def sels(lhs: Expr, rhs: Expr, select_mode: int | float | Expr, span: Span | None = None) -> Call: + """Select between two tiles based on a scalar mode. + + Maps to the TSELS hardware intrinsic. The interpretation of select_mode values + is target-dependent and enforced by codegen. + + Args: + lhs: Source tile 0 (TileType) + rhs: Source tile 1 (TileType) + select_mode: Scalar select mode + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for tile select + """ + actual_span = _get_span_or_capture(span) + select_mode_expr = ( + _normalize_expr(select_mode, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) + if not isinstance(select_mode, Expr) + else select_mode + ) + return _ir_core.create_op_call("block.sels", [lhs, rhs, select_mode_expr], {}, actual_span) + + def muls(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: """Element-wise multiplication of tile and scalar. 
@@ -379,7 +778,7 @@ def muls(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: """ actual_span = _get_span_or_capture(span) rhs_expr = ( - _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32) + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) if not isinstance(rhs, Expr) else rhs ) @@ -399,7 +798,7 @@ def adds(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: """ actual_span = _get_span_or_capture(span) rhs_expr = ( - _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32) + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) if not isinstance(rhs, Expr) else rhs ) @@ -419,7 +818,7 @@ def divs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: """ actual_span = _get_span_or_capture(span) rhs_expr = ( - _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32) + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) if not isinstance(rhs, Expr) else rhs ) @@ -439,7 +838,7 @@ def subs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: """ actual_span = _get_span_or_capture(span) rhs_expr = ( - _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32) + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) if not isinstance(rhs, Expr) else rhs ) @@ -487,7 +886,7 @@ def cmps( """ actual_span = _get_span_or_capture(span) rhs_expr = ( - _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32) + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) if not isinstance(rhs, Expr) else rhs ) @@ -643,6 +1042,22 @@ def relu(tile: Expr, span: Span | None = None) -> Call: return _ir_core.create_op_call("block.relu", [tile], {}, actual_span) +def not_(tile: Expr, span: Span | None = 
None) -> Call: + """Element-wise bitwise NOT of a tile. + + Computes ~tile element-wise. Maps to the TNOT hardware intrinsic. + + Args: + tile: Input tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise bitwise NOT + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.not", [tile], {}, actual_span) + + # ============================================================================ # Matrix Operations # ============================================================================ @@ -683,11 +1098,90 @@ def matmul_acc(acc: Expr, lhs: Expr, rhs: Expr, span: Span | None = None) -> Cal return _ir_core.create_op_call("block.matmul_acc", [acc, lhs, rhs], {}, actual_span) +def matmul_bias(lhs: Expr, rhs: Expr, bias: Expr, span: Span | None = None) -> Call: + """Matrix multiplication with bias add: C = lhs @ rhs + bias. + + Args: + lhs: Left-hand side tile (TileType [M, K]) + rhs: Right-hand side tile (TileType [K, N]) + bias: Bias tile (TileType [1, N]) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for matrix multiplication with bias + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.matmul_bias", [lhs, rhs, bias], {}, actual_span) + + +def gemv(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """General Matrix-Vector multiplication: C[1,N] = A[1,K] @ B[K,N]. + + Args: + lhs: Row vector tile (TileType [1, K]) + rhs: Right-hand side tile (TileType [K, N]) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for GEMV + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.gemv", [lhs, rhs], {}, actual_span) + + +def gemv_acc(acc: Expr, lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: + """GEMV with accumulation: C[1,N] += A[1,K] @ B[K,N]. 
+ + Args: + acc: Accumulator tile (TileType [1, N]) + lhs: Row vector tile (TileType [1, K]) + rhs: Right-hand side tile (TileType [K, N]) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for GEMV with accumulation + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.gemv_acc", [acc, lhs, rhs], {}, actual_span) + + +def gemv_bias(lhs: Expr, rhs: Expr, bias: Expr, span: Span | None = None) -> Call: + """GEMV with bias add: C[1,N] = A[1,K] @ B[K,N] + bias[1,N]. + + Args: + lhs: Row vector tile (TileType [1, K]) + rhs: Right-hand side tile (TileType [K, N]) + bias: Bias tile (TileType [1, N]) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for GEMV with bias + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.gemv_bias", [lhs, rhs, bias], {}, actual_span) + + # ============================================================================ # Row Broadcast Operations # ============================================================================ +def row_expand(src: Expr, span: Span | None = None) -> Call: + """Broadcast the first element of each source row across the destination row. + + For each element (i, j) in the valid region: dst[i, j] = src[i, 0]. + + Args: + src: Input tile (TileType [M, N]) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for row-wise first-element broadcast + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.row_expand", [src], {}, actual_span) + + def row_expand_sub(tile: Expr, row_vec: Expr, span: Span | None = None) -> Call: """Row-wise broadcast subtraction. 
@@ -885,6 +1379,50 @@ def minimum(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call: return _ir_core.create_op_call("block.minimum", [lhs, rhs], {}, actual_span) +def maxs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: + """Element-wise maximum of tile and scalar. + + Computes max(lhs, rhs) element-wise. Maps to the TMAXS hardware intrinsic. + + Args: + lhs: Tile (TileType) + rhs: Scalar (int/float/Expr with ScalarType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise maximum with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) + if not isinstance(rhs, Expr) + else rhs + ) + return _ir_core.create_op_call("block.maxs", [lhs, rhs_expr], {}, actual_span) + + +def mins(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call: + """Element-wise minimum of tile and scalar. + + Computes min(lhs, rhs) element-wise. Maps to the TMINS hardware intrinsic. 
+ + Args: + lhs: Tile (TileType) + rhs: Scalar (int/float/Expr with ScalarType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression for element-wise minimum with scalar + """ + actual_span = _get_span_or_capture(span) + rhs_expr = ( + _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32) + if not isinstance(rhs, Expr) + else rhs + ) + return _ir_core.create_op_call("block.mins", [lhs, rhs_expr], {}, actual_span) + + # ============================================================================ # Reduction Operations # ============================================================================ diff --git a/python/pypto/language/__init__.py b/python/pypto/language/__init__.py index bb8f0002..5347b65a 100644 --- a/python/pypto/language/__init__.py +++ b/python/pypto/language/__init__.py @@ -47,6 +47,10 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: from .op import tensor_ops as tensor from .op.block_ops import ( abs, + addc, + addsc, + and_, + ands, cmp, cmps, col_expand, @@ -55,27 +59,51 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: col_expand_sub, create_tile, expands, + gemv, + gemv_acc, + gemv_bias, l0c_store, load, log, + lrelu, matmul_acc, + matmul_bias, max, + maxs, min, minimum, + mins, move, neg, + not_, + or_, + ors, + prelu, recip, relu, + rem, + rems, + row_expand, row_expand_add, row_expand_div, row_expand_mul, row_expand_sub, row_min, rsqrt, + sel, + sels, + shl, + shls, + shr, + shrs, sqrt, store, + subc, + subsc, sum, ub_copy, + xor, + xors, ) from .op.tensor_ops import assemble, create_tensor, dim from .op.unified_ops import ( @@ -170,6 +198,10 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: "abs", "relu", "matmul_acc", + "matmul_bias", + "gemv", + "gemv_acc", + "gemv_bias", "minimum", "min", "sum", @@ -177,6 +209,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: "cmp", "cmps", "row_min", + 
"row_expand", "row_expand_add", "row_expand_sub", "row_expand_mul", @@ -186,6 +219,29 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: "col_expand_div", "col_expand_sub", "expands", + "rem", + "rems", + "and_", + "ands", + "or_", + "ors", + "xor", + "xors", + "shl", + "shls", + "shr", + "shrs", + "maxs", + "mins", + "prelu", + "not_", + "addc", + "subc", + "addsc", + "subsc", + "lrelu", + "sel", + "sels", # Promoted tensor-only "create_tensor", "assemble", diff --git a/python/pypto/language/op/__init__.py b/python/pypto/language/op/__init__.py index 2833c48f..163f7905 100644 --- a/python/pypto/language/op/__init__.py +++ b/python/pypto/language/op/__init__.py @@ -26,6 +26,10 @@ # Promoted block-only ops (accessible as pl.load, etc.) from .block_ops import ( abs, + addc, + addsc, + and_, + ands, cmp, cmps, col_expand, @@ -34,27 +38,51 @@ col_expand_sub, create_tile, expands, + gemv, + gemv_acc, + gemv_bias, l0c_store, load, log, + lrelu, matmul_acc, + matmul_bias, max, + maxs, min, minimum, + mins, move, neg, + not_, + or_, + ors, + prelu, recip, relu, + rem, + rems, + row_expand, row_expand_add, row_expand_div, row_expand_mul, row_expand_sub, row_min, rsqrt, + sel, + sels, + shl, + shls, + shr, + shrs, sqrt, store, + subc, + subsc, sum, ub_copy, + xor, + xors, ) # Promoted tensor-only ops (accessible as pl.create_tensor, etc.) 
@@ -112,10 +140,15 @@ "abs", "relu", "matmul_acc", + "matmul_bias", + "gemv", + "gemv_acc", + "gemv_bias", "minimum", "cmp", "cmps", "row_min", + "row_expand", "row_expand_add", "row_expand_sub", "row_expand_mul", @@ -125,6 +158,29 @@ "col_expand_div", "col_expand_sub", "expands", + "rem", + "rems", + "and_", + "ands", + "or_", + "ors", + "xor", + "xors", + "shl", + "shls", + "shr", + "shrs", + "maxs", + "mins", + "prelu", + "not_", + "addc", + "subc", + "addsc", + "subsc", + "lrelu", + "sel", + "sels", # Promoted tensor-only "create_tensor", "assemble", diff --git a/python/pypto/language/op/block_ops.py b/python/pypto/language/op/block_ops.py index e8fb0ecc..f5f24757 100644 --- a/python/pypto/language/op/block_ops.py +++ b/python/pypto/language/op/block_ops.py @@ -45,10 +45,15 @@ "cast", "matmul", "matmul_acc", + "matmul_bias", + "gemv", + "gemv_acc", + "gemv_bias", "row_max", "row_sum", "row_min", "maximum", + "row_expand", "row_expand_sub", "row_expand_div", "row_expand_mul", @@ -67,6 +72,29 @@ "view", "reshape", "transpose", + "rem", + "rems", + "and_", + "ands", + "or_", + "ors", + "xor", + "xors", + "shl", + "shls", + "shr", + "shrs", + "maxs", + "mins", + "prelu", + "not_", + "addc", + "subc", + "addsc", + "subsc", + "lrelu", + "sel", + "sels", ] from pypto.ir.op import block_ops as _ir_ops @@ -524,6 +552,65 @@ def matmul_acc(acc: Tile, lhs: Tile, rhs: Tile) -> Tile: return Tile(expr=call_expr) +def matmul_bias(lhs: Tile, rhs: Tile, bias: Tile) -> Tile: + """Matrix multiplication with bias add: C = lhs @ rhs + bias. + + Args: + lhs: Left-hand side tile [M, K] + rhs: Right-hand side tile [K, N] + bias: Bias tile [1, N] + + Returns: + Tile wrapping the matmul_bias operation + """ + call_expr = _ir_ops.matmul_bias(lhs.unwrap(), rhs.unwrap(), bias.unwrap()) + return Tile(expr=call_expr) + + +def gemv(lhs: Tile, rhs: Tile) -> Tile: + """General Matrix-Vector multiplication: C[1,N] = A[1,K] @ B[K,N]. 
+ + Args: + lhs: Row vector tile [1, K] + rhs: Right-hand side tile [K, N] + + Returns: + Tile wrapping the gemv operation + """ + call_expr = _ir_ops.gemv(lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def gemv_acc(acc: Tile, lhs: Tile, rhs: Tile) -> Tile: + """GEMV with accumulation: C[1,N] += A[1,K] @ B[K,N]. + + Args: + acc: Accumulator tile [1, N] + lhs: Row vector tile [1, K] + rhs: Right-hand side tile [K, N] + + Returns: + Tile wrapping the gemv_acc operation + """ + call_expr = _ir_ops.gemv_acc(acc.unwrap(), lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def gemv_bias(lhs: Tile, rhs: Tile, bias: Tile) -> Tile: + """GEMV with bias add: C[1,N] = A[1,K] @ B[K,N] + bias[1,N]. + + Args: + lhs: Row vector tile [1, K] + rhs: Right-hand side tile [K, N] + bias: Bias tile [1, N] + + Returns: + Tile wrapping the gemv_bias operation + """ + call_expr = _ir_ops.gemv_bias(lhs.unwrap(), rhs.unwrap(), bias.unwrap()) + return Tile(expr=call_expr) + + def row_max(tile: Tile, tmp_tile: Tile) -> Tile: """Row-wise max reduction. @@ -580,6 +667,21 @@ def maximum(lhs: Tile, rhs: Tile) -> Tile: return Tile(expr=call_expr) +def row_expand(src: Tile) -> Tile: + """Broadcast the first element of each source row across the destination row. + + For each element (i, j): dst[i, j] = src[i, 0]. + + Args: + src: Input tile [M, N] + + Returns: + Tile wrapping the row_expand operation + """ + call_expr = _ir_ops.row_expand(src.unwrap()) + return Tile(expr=call_expr) + + def row_expand_sub(tile: Tile, row_vec: Tile) -> Tile: """Row-wise broadcast subtraction. @@ -874,3 +976,401 @@ def transpose(tile: Tile, axis1: int, axis2: int) -> Tile: tile_expr = tile.unwrap() call_expr = _ir_ops.transpose(tile_expr, axis1, axis2) return Tile(expr=call_expr) + + +def rem(lhs: Tile, rhs: Tile) -> Tile: + """Element-wise remainder (modulo) of two tiles. + + Computes lhs % rhs element-wise. Maps to the TREM hardware intrinsic. 
+ + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + + Returns: + Tile wrapping the rem operation + """ + call_expr = _ir_ops.rem(lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def rems(lhs: Tile, rhs: int | float | Expr | Scalar) -> Tile: + """Element-wise remainder (modulo) of tile and scalar. + + Computes lhs % rhs element-wise. Maps to the TREMS hardware intrinsic. + + Args: + lhs: Tile + rhs: Scalar value + + Returns: + Tile wrapping the rems operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.rems(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def and_(lhs: Tile, rhs: Tile) -> Tile: + """Element-wise bitwise AND of two tiles. + + Computes lhs & rhs element-wise. Maps to the TAND hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + + Returns: + Tile wrapping the and operation + """ + call_expr = _ir_ops.and_(lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def ands(lhs: Tile, rhs: int | Expr | Scalar) -> Tile: + """Element-wise bitwise AND of tile and scalar. + + Computes lhs & rhs element-wise. Maps to the TANDS hardware intrinsic. + + Args: + lhs: Tile + rhs: Scalar value + + Returns: + Tile wrapping the ands operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.ands(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def or_(lhs: Tile, rhs: Tile) -> Tile: + """Element-wise bitwise OR of two tiles. + + Computes lhs | rhs element-wise. Maps to the TOR hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + + Returns: + Tile wrapping the or operation + """ + call_expr = _ir_ops.or_(lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def ors(lhs: Tile, rhs: int | Expr | Scalar) -> Tile: + """Element-wise bitwise OR of tile and scalar. + + Computes lhs | rhs element-wise. Maps to the TORS hardware intrinsic. 
+ + Args: + lhs: Tile + rhs: Scalar value + + Returns: + Tile wrapping the ors operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.ors(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def xor(lhs: Tile, rhs: Tile, tmp: Tile) -> Tile: + """Element-wise bitwise XOR of two tiles. + + Computes lhs ^ rhs element-wise. Maps to the TXOR hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + tmp: Temporary tile required by the hardware + + Returns: + Tile wrapping the xor operation + """ + call_expr = _ir_ops.xor(lhs.unwrap(), rhs.unwrap(), tmp.unwrap()) + return Tile(expr=call_expr) + + +def xors(lhs: Tile, rhs: int | Expr | Scalar, tmp: Tile) -> Tile: + """Element-wise bitwise XOR of tile and scalar. + + Computes lhs ^ rhs element-wise. Maps to the TXORS hardware intrinsic. + + Args: + lhs: Tile + rhs: Scalar value + tmp: Temporary tile required by the hardware + + Returns: + Tile wrapping the xors operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.xors(lhs.unwrap(), rhs_expr, tmp.unwrap()) + return Tile(expr=call_expr) + + +def shl(lhs: Tile, rhs: Tile) -> Tile: + """Element-wise bitwise left shift of two tiles. + + Computes lhs << rhs element-wise. Maps to the TSHL hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + + Returns: + Tile wrapping the shl operation + """ + call_expr = _ir_ops.shl(lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def shls(lhs: Tile, rhs: int | Expr | Scalar) -> Tile: + """Element-wise bitwise left shift of tile and scalar. + + Computes lhs << rhs element-wise. Maps to the TSHLS hardware intrinsic. + + Note: + The scalar shift amount must be zero or positive; negative values are + not supported by the hardware and will be rejected by codegen. 
+ + Args: + lhs: Tile + rhs: Scalar shift amount; must be >= 0 + + Returns: + Tile wrapping the shls operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.shls(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def shr(lhs: Tile, rhs: Tile) -> Tile: + """Element-wise bitwise right shift of two tiles. + + Computes lhs >> rhs element-wise. Maps to the TSHR hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + + Returns: + Tile wrapping the shr operation + """ + call_expr = _ir_ops.shr(lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def shrs(lhs: Tile, rhs: int | Expr | Scalar) -> Tile: + """Element-wise bitwise right shift of tile and scalar. + + Computes lhs >> rhs element-wise. Maps to the TSHRS hardware intrinsic. + + Note: + The scalar shift amount must be zero or positive; negative values are + not supported by the hardware and will be rejected by codegen. + + Args: + lhs: Tile + rhs: Scalar shift amount; must be >= 0 + + Returns: + Tile wrapping the shrs operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.shrs(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def maxs(lhs: Tile, rhs: int | float | Expr | Scalar) -> Tile: + """Element-wise maximum of tile and scalar. + + Computes max(lhs, rhs) element-wise. Maps to the TMAXS hardware intrinsic. + + Args: + lhs: Tile + rhs: Scalar value + + Returns: + Tile wrapping the maxs operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.maxs(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def mins(lhs: Tile, rhs: int | float | Expr | Scalar) -> Tile: + """Element-wise minimum of tile and scalar. + + Computes min(lhs, rhs) element-wise. Maps to the TMINS hardware intrinsic. 
+ + Args: + lhs: Tile + rhs: Scalar value + + Returns: + Tile wrapping the mins operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.mins(lhs.unwrap(), rhs_expr) + return Tile(expr=call_expr) + + +def prelu(tile: Tile, slope: Tile, tmp: Tile) -> Tile: + """Element-wise parametric ReLU of a tile. + + Computes prelu(tile, slope) element-wise. Maps to the TPRELU hardware intrinsic. + + Args: + tile: Input tile + slope: Slope tile used for negative values + tmp: Temporary tile required by the hardware + + Returns: + Tile wrapping the prelu operation + """ + call_expr = _ir_ops.prelu(tile.unwrap(), slope.unwrap(), tmp.unwrap()) + return Tile(expr=call_expr) + + +def not_(tile: Tile) -> Tile: + """Element-wise bitwise NOT of a tile. + + Computes ~tile element-wise. Maps to the TNOT hardware intrinsic. + + Args: + tile: Input tile + + Returns: + Tile wrapping the not operation + """ + call_expr = _ir_ops.not_(tile.unwrap()) + return Tile(expr=call_expr) + + +def addc(lhs: Tile, rhs: Tile, rhs2: Tile) -> Tile: + """Element-wise addition of three tiles. + + Computes lhs + rhs + rhs2 element-wise. Maps to the TADDC hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + rhs2: Third tile + + Returns: + Tile wrapping the addc operation + """ + call_expr = _ir_ops.addc(lhs.unwrap(), rhs.unwrap(), rhs2.unwrap()) + return Tile(expr=call_expr) + + +def subc(lhs: Tile, rhs: Tile, rhs2: Tile) -> Tile: + """Element-wise subtraction of three tiles. + + Computes lhs - rhs - rhs2 element-wise. Maps to the TSUBC hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Right-hand side tile + rhs2: Third tile + + Returns: + Tile wrapping the subc operation + """ + call_expr = _ir_ops.subc(lhs.unwrap(), rhs.unwrap(), rhs2.unwrap()) + return Tile(expr=call_expr) + + +def addsc(lhs: Tile, rhs: int | float | Expr | Scalar, rhs2: Tile) -> Tile: + """Element-wise addition of tile, scalar, and tile. 
+ + Computes lhs + rhs + rhs2 element-wise. Maps to the TADDSC hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Scalar value + rhs2: Third tile + + Returns: + Tile wrapping the addsc operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.addsc(lhs.unwrap(), rhs_expr, rhs2.unwrap()) + return Tile(expr=call_expr) + + +def subsc(lhs: Tile, rhs: int | float | Expr | Scalar, rhs2: Tile) -> Tile: + """Element-wise subtraction of tile, scalar, and tile. + + Computes lhs - rhs - rhs2 element-wise. Maps to the TSUBSC hardware intrinsic. + + Args: + lhs: Left-hand side tile + rhs: Scalar value + rhs2: Third tile + + Returns: + Tile wrapping the subsc operation + """ + rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs + call_expr = _ir_ops.subsc(lhs.unwrap(), rhs_expr, rhs2.unwrap()) + return Tile(expr=call_expr) + + +def lrelu(tile: Tile, slope: int | float | Expr | Scalar) -> Tile: + """Element-wise leaky ReLU with scalar slope. + + Computes max(tile, slope * tile) element-wise. Maps to the TLRELU hardware intrinsic. + + Args: + tile: Input tile + slope: Scalar slope for negative values + + Returns: + Tile wrapping the lrelu operation + """ + slope_expr = slope.unwrap() if isinstance(slope, Scalar) else slope + call_expr = _ir_ops.lrelu(tile.unwrap(), slope_expr) + return Tile(expr=call_expr) + + +def sel(mask: Tile, lhs: Tile, rhs: Tile) -> Tile: + """Per-element selection between two tiles using a predicate mask tile. + + For each element (i, j): dst[i,j] = lhs[i,j] if mask[i,j] is true, else rhs[i,j]. + Maps to the TSEL hardware intrinsic. The mask encoding is target-defined. 
+ + Args: + mask: Predicate mask tile; encoding is target-defined + lhs: Source tile 0, selected where mask is true + rhs: Source tile 1, selected where mask is false + + Returns: + Tile wrapping the sel operation + """ + call_expr = _ir_ops.sel(mask.unwrap(), lhs.unwrap(), rhs.unwrap()) + return Tile(expr=call_expr) + + +def sels(lhs: Tile, rhs: Tile, select_mode: int | float | Expr | Scalar) -> Tile: + """Select between two tiles based on a scalar mode. + + Maps to the TSELS hardware intrinsic. The interpretation of select_mode values + is target-dependent and enforced by codegen. + + Args: + lhs: Source tile 0 + rhs: Source tile 1 + select_mode: Scalar select mode + + Returns: + Tile wrapping the sels operation + """ + select_mode_expr = select_mode.unwrap() if isinstance(select_mode, Scalar) else select_mode + call_expr = _ir_ops.sels(lhs.unwrap(), rhs.unwrap(), select_mode_expr) + return Tile(expr=call_expr) diff --git a/python/pypto/language/parser/ast_parser.py b/python/pypto/language/parser/ast_parser.py index bd63bf3b..904a85f9 100644 --- a/python/pypto/language/parser/ast_parser.py +++ b/python/pypto/language/parser/ast_parser.py @@ -1512,6 +1512,7 @@ def _parse_block_op(self, op_name: str, call: ast.Call) -> ir.Expr: "divs", "sum", "row_min", + "row_expand", "row_expand_add", "row_expand_sub", "row_expand_mul", @@ -1521,6 +1522,10 @@ def _parse_block_op(self, op_name: str, call: ast.Call) -> ir.Expr: "col_expand_div", "col_expand_sub", "expands", + "matmul_bias", + "gemv", + "gemv_acc", + "gemv_bias", "abs", "create_tile", } diff --git a/src/ir/op/block_ops/broadcast.cpp b/src/ir/op/block_ops/broadcast.cpp index 3a56d692..a822b79f 100644 --- a/src/ir/op/block_ops/broadcast.cpp +++ b/src/ir/op/block_ops/broadcast.cpp @@ -139,6 +139,24 @@ TypePtr DeduceBlockExpandScalarType(const std::vector& args, // Registration Function for Block Row Broadcast Operations // ============================================================================ 
+REGISTER_OP("block.row_expand") + .set_op_category("BlockOp") + .set_description( + "Broadcast first element of each source row across the destination row: dst[i,j] = src[i,0]") + .add_argument("src", "Input tile (TileType, 2D [M, N])") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + CHECK(args.size() == 1) << "The operator block.row_expand requires exactly 1 argument, but got " + << args.size(); + auto tile_type = As(args[0]->GetType()); + CHECK(tile_type) << "The operator block.row_expand requires argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type->shape_.size() >= 2) + << "The operator block.row_expand requires input tile" + << " to have at least 2 dimensions, but got " << tile_type->shape_.size() << " dimensions"; + return std::make_shared(tile_type->shape_, tile_type->dtype_); + }); + REGISTER_OP("block.row_expand_sub") .set_op_category("BlockOp") .set_description("Row-wise broadcast subtraction: tile - row_vec (broadcasted)") diff --git a/src/ir/op/block_ops/elementwise.cpp b/src/ir/op/block_ops/elementwise.cpp index 5c015380..bf7cf59d 100644 --- a/src/ir/op/block_ops/elementwise.cpp +++ b/src/ir/op/block_ops/elementwise.cpp @@ -37,7 +37,7 @@ namespace ir { TypePtr DeduceBlockOpElementwiseBinaryType(const std::vector& args, const std::vector>& kwargs, - const std::string& op_name) { + const std::string& op_name, bool require_int = false) { CHECK(args.size() == 2) << "The operator " << op_name << " requires exactly 2 arguments, but got " << args.size(); @@ -50,6 +50,15 @@ TypePtr DeduceBlockOpElementwiseBinaryType(const std::vector& args, CHECK(tile_type2) << "The operator " << op_name << " requires second argument to be a TileType, but got " << args[1]->GetType()->TypeName(); + if (require_int) { + CHECK(tile_type1->dtype_.IsInt()) + << "The operator " << op_name << " requires integer tile dtype, but got " + << tile_type1->dtype_.ToString(); + CHECK(tile_type2->dtype_.IsInt()) + << "The 
operator " << op_name << " requires integer tile dtype, but got " + << tile_type2->dtype_.ToString(); + } + // Use broadcasting auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_); CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types, but got " @@ -63,6 +72,32 @@ TypePtr DeduceBlockOpElementwiseBinaryType(const std::vector& args, return std::make_shared(broadcast_result.shape, *result_dtype); } +// Tile-tile shift ops (shl, shr): RHS is the shift amount, result type equals LHS tile type, +// consistent with scalar variants (shls/shrs) which preserve the LHS tile dtype. +TypePtr DeduceBlockOpShiftBinaryType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 2) << "The operator " << op_name << " requires exactly 2 arguments, but got " + << args.size(); + + auto tile_type1 = As(args[0]->GetType()); + auto tile_type2 = As(args[1]->GetType()); + CHECK(tile_type1) << "The operator " << op_name << " requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type2) << "The operator " << op_name << " requires second argument to be a TileType, but got " + << args[1]->GetType()->TypeName(); + CHECK(tile_type1->dtype_.IsInt()) << "The operator " << op_name << " requires integer tile dtype, but got " + << tile_type1->dtype_.ToString(); + CHECK(tile_type2->dtype_.IsInt()) << "The operator " << op_name + << " requires integer shift tile dtype, but got " + << tile_type2->dtype_.ToString(); + + auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_); + CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes"; + + return std::make_shared(broadcast_result.shape, tile_type1->dtype_); +} + TypePtr DeduceBlockOpScalarBinaryType(const std::vector& args, const std::vector>& kwargs, const std::string& op_name) { @@ -87,6 +122,33 @@ TypePtr 
DeduceBlockOpScalarBinaryType(const std::vector& args, return std::make_shared(tile_type->shape_, *result_dtype); } +TypePtr DeduceBlockOpIntScalarBinaryType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 2) << "The operator " << op_name << " requires exactly 2 arguments, but got " + << args.size(); + + // First argument must be TileType with integer dtype. + auto tile_type = As(args[0]->GetType()); + CHECK(tile_type) << "The operator " << op_name << " requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type->dtype_.IsInt()) << "The operator " << op_name << " requires integer tile dtype, but got " + << tile_type->dtype_.ToString(); + + // Second argument must be ScalarType with an integer dtype per ISA spec: + // %dst = tshls/tshrs/tands/tors %src, %scalar : !pto.tile<...>, i32 + // The IR allows any integer width (INT8/16/32/64, UINT variants); codegen casts to i32. + auto scalar_type = As(args[1]->GetType()); + CHECK(scalar_type) << "The operator " << op_name << " requires second argument to be a ScalarType, but got " + << args[1]->GetType()->TypeName(); + CHECK(scalar_type->dtype_.IsInt()) << "The operator " << op_name + << " requires shift/bitwise scalar to be an integer type, but got " + << scalar_type->dtype_.ToString(); + + // Result has the same shape and dtype as the input tile; the shift amount does not change element type. 
+ return std::make_shared(tile_type->shape_, tile_type->dtype_); +} + // ============================================================================ // Op Registration // ============================================================================ @@ -151,6 +213,16 @@ REGISTER_OP("block.minimum") return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.minimum"); }); +REGISTER_OP("block.rem") + .set_op_category("BlockOp") + .set_description("Element-wise remainder (modulo) of two tiles with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.rem"); + }); + REGISTER_OP("block.muls") .set_op_category("BlockOp") .set_description("Element-wise multiplication of tile and scalar") @@ -191,6 +263,418 @@ REGISTER_OP("block.subs") return DeduceBlockOpScalarBinaryType(args, kwargs, "block.subs"); }); +REGISTER_OP("block.rems") + .set_op_category("BlockOp") + .set_description("Element-wise remainder (modulo) of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpScalarBinaryType(args, kwargs, "block.rems"); + }); + +REGISTER_OP("block.shl") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise left shift of two tiles with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpShiftBinaryType(args, kwargs, "block.shl"); + }); + +REGISTER_OP("block.shls") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise left shift of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + 
.add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpIntScalarBinaryType(args, kwargs, "block.shls"); + }); + +REGISTER_OP("block.shr") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise right shift of two tiles with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpShiftBinaryType(args, kwargs, "block.shr"); + }); + +REGISTER_OP("block.shrs") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise right shift of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpIntScalarBinaryType(args, kwargs, "block.shrs"); + }); + +REGISTER_OP("block.maxs") + .set_op_category("BlockOp") + .set_description("Element-wise maximum of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpScalarBinaryType(args, kwargs, "block.maxs"); + }); + +REGISTER_OP("block.mins") + .set_op_category("BlockOp") + .set_description("Element-wise minimum of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpScalarBinaryType(args, kwargs, "block.mins"); + }); + +REGISTER_OP("block.and") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise AND of two tiles with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .f_deduce_type([](const std::vector& args, + 
const std::vector>& kwargs) { + return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.and", true); + }); + +REGISTER_OP("block.ands") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise AND of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpIntScalarBinaryType(args, kwargs, "block.ands"); + }); + +REGISTER_OP("block.or") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise OR of two tiles with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.or", true); + }); + +REGISTER_OP("block.ors") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise OR of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpIntScalarBinaryType(args, kwargs, "block.ors"); + }); + +// Tile-tile ternary ops with a tmp buffer as the third argument. +// When require_int is true (bitwise ops like xor), both tile dtypes must be integer. 
+TypePtr DeduceBlockOpTernaryType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name, bool require_int = false) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + auto tile_type1 = As(args[0]->GetType()); + auto tile_type2 = As(args[1]->GetType()); + CHECK(tile_type1) << "The operator " << op_name << " requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type2) << "The operator " << op_name << " requires second argument to be a TileType, but got " + << args[1]->GetType()->TypeName(); + CHECK(As(args[2]->GetType())) + << "The operator " << op_name << " requires third argument (tmp) to be a TileType, but got " + << args[2]->GetType()->TypeName(); + + if (require_int) { + CHECK(tile_type1->dtype_.IsInt()) + << "The operator " << op_name << " requires integer tile dtype, but got " + << tile_type1->dtype_.ToString(); + CHECK(tile_type2->dtype_.IsInt()) + << "The operator " << op_name << " requires integer tile dtype, but got " + << tile_type2->dtype_.ToString(); + } + + auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_); + CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types"; + auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_); + CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes"; + + return std::make_shared(broadcast_result.shape, *result_dtype); +} + +// All three tiles are real inputs (addc, subc): promote dtype and broadcast shape across all three. 
+TypePtr DeduceBlockOpTriTileType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + auto tile_type1 = As(args[0]->GetType()); + auto tile_type2 = As(args[1]->GetType()); + auto tile_type3 = As(args[2]->GetType()); + CHECK(tile_type1) << "The operator " << op_name << " requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type2) << "The operator " << op_name << " requires second argument to be a TileType, but got " + << args[1]->GetType()->TypeName(); + CHECK(tile_type3) << "The operator " << op_name << " requires third argument to be a TileType, but got " + << args[2]->GetType()->TypeName(); + + auto result_dtype12 = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_); + CHECK(result_dtype12) << "The operator " << op_name << " requires compatible data types"; + auto result_dtype = PromoteDataTypes(*result_dtype12, tile_type3->dtype_); + CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types"; + + auto broadcast12 = BroadcastShapes(tile_type1->shape_, tile_type2->shape_); + CHECK(broadcast12.success) << "The operator " << op_name << " requires compatible shapes"; + auto broadcast_result = BroadcastShapes(broadcast12.shape, tile_type3->shape_); + CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes"; + + return std::make_shared(broadcast_result.shape, *result_dtype); +} + +// (Tile, Scalar, Tile) pattern (addsc, subsc): any scalar type, promote output from all three inputs. 
+TypePtr DeduceBlockOpTileScalarTileType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + auto tile_type1 = As(args[0]->GetType()); + CHECK(tile_type1) << "The operator " << op_name << " requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + + auto scalar_type = As(args[1]->GetType()); + CHECK(scalar_type) << "The operator " << op_name << " requires second argument to be a ScalarType, but got " + << args[1]->GetType()->TypeName(); + + auto tile_type2 = As(args[2]->GetType()); + CHECK(tile_type2) << "The operator " << op_name << " requires third argument to be a TileType, but got " + << args[2]->GetType()->TypeName(); + + auto result_dtype12 = PromoteDataTypes(tile_type1->dtype_, scalar_type->dtype_); + CHECK(result_dtype12) << "The operator " << op_name << " requires compatible data types"; + auto result_dtype = PromoteDataTypes(*result_dtype12, tile_type2->dtype_); + CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types"; + + auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_); + CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes"; + + return std::make_shared(broadcast_result.shape, *result_dtype); +} + +TypePtr DeduceBlockOpXorScalarType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + auto tile_type = As(args[0]->GetType()); + CHECK(tile_type) << "The operator " << op_name << " requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type->dtype_.IsInt()) << "The operator " << op_name << " requires integer tile dtype, but got " + << tile_type->dtype_.ToString(); + + 
// Second argument must be ScalarType with an integer dtype per ISA spec: + // %dst = txors %src, %scalar : !pto.tile<...>, i32 + // The IR allows any integer width (INT8/16/32/64, UINT variants); codegen casts to i32. + auto scalar_type = As(args[1]->GetType()); + CHECK(scalar_type) << "The operator " << op_name << " requires second argument to be a ScalarType, but got " + << args[1]->GetType()->TypeName(); + CHECK(scalar_type->dtype_.IsInt()) << "The operator " << op_name + << " requires scalar to be an integer type, but got " + << scalar_type->dtype_.ToString(); + + CHECK(As(args[2]->GetType())) + << "The operator " << op_name << " requires third argument to be a TileType, but got " + << args[2]->GetType()->TypeName(); + + // Result has the same shape and dtype as the input tile; bitwise ops do not change element type. + return std::make_shared(tile_type->shape_, tile_type->dtype_); +} + +REGISTER_OP("block.xor") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise XOR of two tiles with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .add_argument("tmp", "Temporary tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpTernaryType(args, kwargs, "block.xor", true); + }); + +REGISTER_OP("block.xors") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise XOR of tile and scalar") + .add_argument("lhs", "Tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .add_argument("tmp", "Temporary tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpXorScalarType(args, kwargs, "block.xors"); + }); + +REGISTER_OP("block.prelu") + .set_op_category("BlockOp") + .set_description("Element-wise parametric ReLU of a tile with slope tile and temporary buffer") + .add_argument("tile", "Input tile (TileType)") + 
.add_argument("slope", "Slope tile (TileType)") + .add_argument("tmp", "Temporary tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpTernaryType(args, kwargs, "block.prelu"); + }); + +REGISTER_OP("block.addc") + .set_op_category("BlockOp") + .set_description("Element-wise addition of three tiles (lhs + rhs + rhs2) with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .add_argument("rhs2", "Third tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpTriTileType(args, kwargs, "block.addc"); + }); + +REGISTER_OP("block.subc") + .set_op_category("BlockOp") + .set_description("Element-wise subtraction of three tiles (lhs - rhs - rhs2) with broadcasting") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Right-hand side tile (TileType)") + .add_argument("rhs2", "Third tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpTriTileType(args, kwargs, "block.subc"); + }); + +REGISTER_OP("block.addsc") + .set_op_category("BlockOp") + .set_description("Element-wise addition of tile, scalar, and tile (lhs + scalar + rhs2)") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .add_argument("rhs2", "Third tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpTileScalarTileType(args, kwargs, "block.addsc"); + }); + +REGISTER_OP("block.subsc") + .set_op_category("BlockOp") + .set_description("Element-wise subtraction of tile, scalar, and tile (lhs - scalar - rhs2)") + .add_argument("lhs", "Left-hand side tile (TileType)") + .add_argument("rhs", "Scalar (ScalarType)") + .add_argument("rhs2", "Third tile (TileType)") + .f_deduce_type([](const std::vector& args, + const 
std::vector>& kwargs) { + return DeduceBlockOpTileScalarTileType(args, kwargs, "block.subsc"); + }); + +REGISTER_OP("block.lrelu") + .set_op_category("BlockOp") + .set_description("Element-wise leaky ReLU of a tile with scalar slope (max(x, slope*x))") + .add_argument("tile", "Input tile (TileType)") + .add_argument("slope", "Scalar slope for negative values (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockOpScalarBinaryType(args, kwargs, "block.lrelu"); + }); + +// Type deduction for block.sel (MaskTile x Tile x Tile -> Tile) +// The mask tile encodes per-element predicates in a target-defined layout; its dtype/shape +// do not influence the output type. Output type is derived from lhs and rhs only. +TypePtr DeduceBlockSelType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + CHECK(As(args[0]->GetType())) + << "The operator " << op_name << " requires first argument (mask) to be a TileType, but got " + << args[0]->GetType()->TypeName(); + + auto tile_type1 = As(args[1]->GetType()); + auto tile_type2 = As(args[2]->GetType()); + CHECK(tile_type1) << "The operator " << op_name + << " requires second argument (lhs) to be a TileType, but got " + << args[1]->GetType()->TypeName(); + CHECK(tile_type2) << "The operator " << op_name + << " requires third argument (rhs) to be a TileType, but got " + << args[2]->GetType()->TypeName(); + + auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_); + CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types, but got " + << tile_type1->dtype_.ToString() << " and " << tile_type2->dtype_.ToString(); + + auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_); + CHECK(broadcast_result.success) << "The operator " << op_name << " requires 
compatible shapes, but got " + << FormatShape(tile_type1->shape_) << " and " + << FormatShape(tile_type2->shape_); + + return std::make_shared(broadcast_result.shape, *result_dtype); +} + +REGISTER_OP("block.sel") + .set_op_category("BlockOp") + .set_description( + "Per-element selection between two tiles using a predicate mask tile. " + "Maps to the TSEL hardware intrinsic.") + .add_argument("mask", "Predicate mask tile; encoding is target-defined (TileType)") + .add_argument("lhs", "Source tile 0, selected where mask is true (TileType)") + .add_argument("rhs", "Source tile 1, selected where mask is false (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockSelType(args, kwargs, "block.sel"); + }); + +// Type deduction for block.sels (Tile x Tile x Scalar -> Tile) +TypePtr DeduceBlockSelScalarType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + auto tile_type1 = As(args[0]->GetType()); + auto tile_type2 = As(args[1]->GetType()); + CHECK(tile_type1) << "The operator " << op_name + << " requires first argument (lhs) to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type2) << "The operator " << op_name + << " requires second argument (rhs) to be a TileType, but got " + << args[1]->GetType()->TypeName(); + + CHECK(As(args[2]->GetType())) + << "The operator " << op_name << " requires third argument (select_mode) to be a ScalarType, but got " + << args[2]->GetType()->TypeName(); + + auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_); + CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types, but got " + << tile_type1->dtype_.ToString() << " and " << tile_type2->dtype_.ToString(); + + auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_); + 
CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes, but got " + << FormatShape(tile_type1->shape_) << " and " + << FormatShape(tile_type2->shape_); + + return std::make_shared(broadcast_result.shape, *result_dtype); +} + +REGISTER_OP("block.sels") + .set_op_category("BlockOp") + .set_description("Select between two tiles based on a scalar mode. Maps to the TSELS hardware intrinsic.") + .add_argument("lhs", "Source tile 0 (TileType)") + .add_argument("rhs", "Source tile 1 (TileType)") + .add_argument("select_mode", "Scalar select mode (ScalarType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockSelScalarType(args, kwargs, "block.sels"); + }); + // Type deduction for block.cmp and block.cmps (comparison operations) TypePtr DeduceBlockCmpType(const std::vector& args, const std::vector>& kwargs, diff --git a/src/ir/op/block_ops/matmul.cpp b/src/ir/op/block_ops/matmul.cpp index 6c6d7533..224bac69 100644 --- a/src/ir/op/block_ops/matmul.cpp +++ b/src/ir/op/block_ops/matmul.cpp @@ -171,6 +171,70 @@ TypePtr DeduceBlockMatMulAccType(const std::vector& args, return std::make_shared(output_shape, *result_dtype); } +TypePtr DeduceBlockMatMulBiasType(const std::vector& args, + const std::vector>& kwargs, + const std::string& op_name) { + CHECK(args.size() == 3) << "The operator " << op_name << " requires exactly 3 arguments, but got " + << args.size(); + + auto lhs_type = As(args[0]->GetType()); + auto rhs_type = As(args[1]->GetType()); + auto bias_type = As(args[2]->GetType()); + + CHECK(lhs_type) << "The operator " << op_name << " requires first argument (lhs) to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(rhs_type) << "The operator " << op_name + << " requires second argument (rhs) to be a TileType, but got " + << args[1]->GetType()->TypeName(); + CHECK(bias_type) << "The operator " << op_name + << " requires third argument (bias) to be a TileType, but got 
" + << args[2]->GetType()->TypeName(); + + const auto& lhs_shape = lhs_type->shape_; + const auto& rhs_shape = rhs_type->shape_; + const auto& bias_shape = bias_type->shape_; + + CHECK(lhs_shape.size() == 2) << "The operator " << op_name << " requires lhs to be 2D, but got " + << lhs_shape.size() << " dimensions"; + CHECK(rhs_shape.size() == 2) << "The operator " << op_name << " requires rhs to be 2D, but got " + << rhs_shape.size() << " dimensions"; + CHECK(bias_shape.size() == 2) << "The operator " << op_name << " requires bias to be 2D, but got " + << bias_shape.size() << " dimensions"; + + auto k_lhs_const = As(lhs_shape[1]); + auto k_rhs_const = As(rhs_shape[0]); + if (k_lhs_const && k_rhs_const) { + CHECK(k_lhs_const->value_ == k_rhs_const->value_) + << "The operator " << op_name + << " requires matching inner dimensions, but got lhs K=" << k_lhs_const->value_ + << " and rhs K=" << k_rhs_const->value_; + } + + std::vector output_shape = {lhs_shape[0], rhs_shape[1]}; + + // Hardware requires bias to be [1, N] + auto bias_row_const = As(bias_shape[0]); + CHECK(bias_row_const && bias_row_const->value_ == 1) + << "The operator " << op_name << " requires bias to have shape [1, N], but got " + << FormatShape(bias_shape); + auto bias_n_const = As(bias_shape[1]); + auto rhs_n_const = As(rhs_shape[1]); + if (bias_n_const && rhs_n_const) { + CHECK(bias_n_const->value_ == rhs_n_const->value_) + << "The operator " << op_name + << " requires bias N dimension to match output N=" << rhs_n_const->value_ + << ", but got bias N=" << bias_n_const->value_; + } + + auto lhs_rhs_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_); + CHECK(lhs_rhs_dtype) << "The operator " << op_name << " requires compatible lhs/rhs data types, but got " + << lhs_type->dtype_.ToString() << " and " << rhs_type->dtype_.ToString(); + auto result_dtype = PromoteDataTypes(*lhs_rhs_dtype, bias_type->dtype_); + CHECK(result_dtype) << "The operator " << op_name << " requires compatible bias data 
type, but got " + << lhs_rhs_dtype->ToString() << " and " << bias_type->dtype_.ToString(); + return std::make_shared(output_shape, *result_dtype); +} + // ============================================================================ // Registration Function for Block Matrix Multiplication Operations // ============================================================================ @@ -196,5 +260,48 @@ REGISTER_OP("block.matmul_acc") return DeduceBlockMatMulAccType(args, kwargs, "block.matmul_acc"); }); +REGISTER_OP("block.matmul_bias") + .set_op_category("BlockOp") + .set_description("Matrix multiplication with bias add: C = lhs @ rhs + bias") + .add_argument("lhs", "Left-hand side tile (TileType, 2D)") + .add_argument("rhs", "Right-hand side tile (TileType, 2D)") + .add_argument("bias", "Bias tile (TileType, [1, N])") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockMatMulBiasType(args, kwargs, "block.matmul_bias"); + }); + +REGISTER_OP("block.gemv") + .set_op_category("BlockOp") + .set_description("General Matrix-Vector multiplication: C[1,N] = A[1,K] @ B[K,N]") + .add_argument("lhs", "Row vector tile (TileType, 2D [1, K])") + .add_argument("rhs", "Right-hand side tile (TileType, 2D [K, N])") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockMatMulType(args, kwargs, "block.gemv"); + }); + +REGISTER_OP("block.gemv_acc") + .set_op_category("BlockOp") + .set_description("GEMV with accumulation: C[1,N] += A[1,K] @ B[K,N]") + .add_argument("acc", "Accumulator tile (TileType, 2D [1, N])") + .add_argument("lhs", "Row vector tile (TileType, 2D [1, K])") + .add_argument("rhs", "Right-hand side tile (TileType, 2D [K, N])") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockMatMulAccType(args, kwargs, "block.gemv_acc"); + }); + +REGISTER_OP("block.gemv_bias") + .set_op_category("BlockOp") + .set_description("GEMV with bias add: C[1,N] = 
A[1,K] @ B[K,N] + bias[1,N]") + .add_argument("lhs", "Row vector tile (TileType, 2D [1, K])") + .add_argument("rhs", "Right-hand side tile (TileType, 2D [K, N])") + .add_argument("bias", "Bias tile (TileType, [1, N])") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceBlockMatMulBiasType(args, kwargs, "block.gemv_bias"); + }); + } // namespace ir } // namespace pypto diff --git a/src/ir/op/block_ops/unary.cpp b/src/ir/op/block_ops/unary.cpp index bcf18b76..4e3423fc 100644 --- a/src/ir/op/block_ops/unary.cpp +++ b/src/ir/op/block_ops/unary.cpp @@ -170,5 +170,23 @@ REGISTER_OP("block.relu") return DeduceBlockUnaryType(args, kwargs, "block.relu"); }); +REGISTER_OP("block.not") + .set_op_category("BlockOp") + .set_description("Element-wise bitwise NOT of a tile") + .add_argument("tile", "Input tile (TileType) with int16 or uint16 dtype") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + const std::string op_name = "block.not"; + CHECK(args.size() == 1) << "The operator " << op_name << " requires exactly 1 argument, but got " + << args.size(); + auto tile_type = As(args[0]->GetType()); + CHECK(tile_type) << "The operator " << op_name << " requires argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + CHECK(tile_type->dtype_ == DataType::INT16 || tile_type->dtype_ == DataType::UINT16) + << "The operator " << op_name << " requires int16 or uint16 tile dtype, but got " + << tile_type->dtype_.ToString(); + return std::make_shared(tile_type->shape_, tile_type->dtype_); + }); + } // namespace ir } // namespace pypto diff --git a/tests/ut/ir/operators/test_block_ops.py b/tests/ut/ir/operators/test_block_ops.py index 859544b4..d279904d 100644 --- a/tests/ut/ir/operators/test_block_ops.py +++ b/tests/ut/ir/operators/test_block_ops.py @@ -11,7 +11,8 @@ import pypto.language as pl import pytest -from pypto import DataType, ir +from pypto import DataType, backend, ir +from pypto.backend 
import BackendType from pypto.ir.op import block from pypto.ir.pass_manager import PassManager @@ -378,6 +379,8 @@ def row_max_kernel( return result program = RowMaxKernel + backend.reset_for_testing() + backend.set_backend_type(BackendType.CCE) pm = PassManager.get_strategy() optimized_program = pm.run_passes(program) @@ -402,6 +405,8 @@ def row_sum_kernel( return result program = RowSumKernel + backend.reset_for_testing() + backend.set_backend_type(BackendType.CCE) pm = PassManager.get_strategy() optimized_program = pm.run_passes(program) @@ -577,6 +582,88 @@ def main( ir_str = str(Program) assert "block.row_expand_add" in ir_str + def test_block_row_expand_sub(self): + """Test block.row_expand_sub operator - subtract row vector from each tile row.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + tile: pl.Tensor[[128, 128], pl.FP32], + row: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(tile, [0, 0], [32, 32]) + tile_row: pl.Tile[[32, 1], pl.FP32] = pl.load(row, [0, 0], [32, 1]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.row_expand_sub(tile_a, tile_row) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.row_expand_sub" in ir_str + + def test_block_row_expand_div(self): + """Test block.row_expand_div operator - divide each tile row by row vector.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + tile: pl.Tensor[[128, 128], pl.FP32], + row: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(tile, [0, 0], [32, 32]) + tile_row: pl.Tile[[32, 1], pl.FP32] = pl.load(row, [0, 0], [32, 1]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.row_expand_div(tile_a, 
tile_row) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.row_expand_div" in ir_str + + def test_block_row_expand_mul(self): + """Test block.row_expand_mul operator - multiply each tile row by row vector.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + tile: pl.Tensor[[128, 128], pl.FP32], + row: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(tile, [0, 0], [32, 32]) + tile_row: pl.Tile[[32, 1], pl.FP32] = pl.load(row, [0, 0], [32, 1]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.row_expand_mul(tile_a, tile_row) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.row_expand_mul" in ir_str + + def test_block_row_expand(self): + """Test block.row_expand operator - broadcast first element of each row across the row.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + tile: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(tile, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.row_expand(tile_a) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.row_expand" in ir_str + def test_block_expands(self): """Test block.expands operator - expand scalar to tile shape.""" @@ -621,6 +708,122 @@ def main( ir_str = str(Program) assert "block.matmul" in ir_str + def test_block_matmul_acc(self): + """Test block.matmul_acc operator - matrix multiplication with accumulation (TMATMUL_ACC). 
+ + Computes: acc_out = acc_in + lhs @ rhs + """ + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + acc_in: pl.Tensor[[128, 128], pl.FP32], + a: pl.Tensor[[128, 64], pl.FP32], + b: pl.Tensor[[64, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_acc: pl.Tile[[32, 32], pl.FP32] = pl.load(acc_in, [0, 0], [32, 32]) + tile_a: pl.Tile[[32, 16], pl.FP32] = pl.load(a, [0, 0], [32, 16]) + tile_b: pl.Tile[[16, 32], pl.FP32] = pl.load(b, [0, 0], [16, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.matmul_acc(tile_acc, tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.matmul_acc" in ir_str + + def test_block_matmul_bias(self): + """Test block.matmul_bias operator - matrix multiplication with bias add.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 64], pl.FP32], + b: pl.Tensor[[64, 128], pl.FP32], + bias: pl.Tensor[[1, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 16], pl.FP32] = pl.load(a, [0, 0], [32, 16]) + tile_b: pl.Tile[[16, 32], pl.FP32] = pl.load(b, [0, 0], [16, 32]) + tile_bias: pl.Tile[[1, 32], pl.FP32] = pl.load(bias, [0, 0], [1, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.matmul_bias(tile_a, tile_b, tile_bias) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.matmul_bias" in ir_str + + def test_block_gemv(self): + """Test block.gemv operator - general matrix-vector multiplication.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[1, 64], pl.FP32], + b: pl.Tensor[[64, 128], pl.FP32], + output: pl.Tensor[[1, 128], pl.FP32], + ) -> pl.Tensor[[1, 128], 
pl.FP32]: + tile_a: pl.Tile[[1, 16], pl.FP32] = pl.load(a, [0, 0], [1, 16]) + tile_b: pl.Tile[[16, 32], pl.FP32] = pl.load(b, [0, 0], [16, 32]) + tile_c: pl.Tile[[1, 32], pl.FP32] = pl.gemv(tile_a, tile_b) + result: pl.Tensor[[1, 128], pl.FP32] = pl.store(tile_c, [0, 0], [1, 32], output) + return result + + ir_str = str(Program) + assert "block.gemv" in ir_str + + def test_block_gemv_acc(self): + """Test block.gemv_acc operator - GEMV with accumulation.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + acc_in: pl.Tensor[[1, 128], pl.FP32], + a: pl.Tensor[[1, 64], pl.FP32], + b: pl.Tensor[[64, 128], pl.FP32], + output: pl.Tensor[[1, 128], pl.FP32], + ) -> pl.Tensor[[1, 128], pl.FP32]: + tile_acc: pl.Tile[[1, 32], pl.FP32] = pl.load(acc_in, [0, 0], [1, 32]) + tile_a: pl.Tile[[1, 16], pl.FP32] = pl.load(a, [0, 0], [1, 16]) + tile_b: pl.Tile[[16, 32], pl.FP32] = pl.load(b, [0, 0], [16, 32]) + tile_c: pl.Tile[[1, 32], pl.FP32] = pl.gemv_acc(tile_acc, tile_a, tile_b) + result: pl.Tensor[[1, 128], pl.FP32] = pl.store(tile_c, [0, 0], [1, 32], output) + return result + + ir_str = str(Program) + assert "block.gemv_acc" in ir_str + + def test_block_gemv_bias(self): + """Test block.gemv_bias operator - GEMV with bias add.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[1, 64], pl.FP32], + b: pl.Tensor[[64, 128], pl.FP32], + bias: pl.Tensor[[1, 128], pl.FP32], + output: pl.Tensor[[1, 128], pl.FP32], + ) -> pl.Tensor[[1, 128], pl.FP32]: + tile_a: pl.Tile[[1, 16], pl.FP32] = pl.load(a, [0, 0], [1, 16]) + tile_b: pl.Tile[[16, 32], pl.FP32] = pl.load(b, [0, 0], [16, 32]) + tile_bias: pl.Tile[[1, 32], pl.FP32] = pl.load(bias, [0, 0], [1, 32]) + tile_c: pl.Tile[[1, 32], pl.FP32] = pl.gemv_bias(tile_a, tile_b, tile_bias) + result: pl.Tensor[[1, 128], pl.FP32] = pl.store(tile_c, [0, 0], [1, 32], output) + return result + + ir_str = str(Program) + assert 
"block.gemv_bias" in ir_str + class TestBlockTransformOps: """Test suite for block-level transform operators.""" @@ -914,5 +1117,543 @@ def test_view_3d(self): assert len(result_type.shape) == 3 +class TestBlockBitwiseArithmeticOps: + """Test suite for newly added block-level arithmetic, bitwise, activation, and select ops.""" + + def test_block_rem(self): + """Test block.rem operator - element-wise remainder of two tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.rem(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.rem" in ir_str + + def test_block_rems(self): + """Test block.rems operator - element-wise remainder of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.rems(tile_a, 3.0) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.rems" in ir_str + + def test_block_and(self): + """Test block.and operator - element-wise bitwise AND of two tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT32], + b: pl.Tensor[[128, 128], pl.INT32], + output: pl.Tensor[[128, 128], pl.INT32], + ) -> pl.Tensor[[128, 128], 
pl.INT32]: + tile_a: pl.Tile[[32, 32], pl.INT32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.INT32] = pl.load(b, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.INT32] = pl.and_(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.INT32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.and" in ir_str + + def test_block_ands(self): + """Test block.ands operator - element-wise bitwise AND of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT32], + scalar: pl.Scalar[pl.INT32], + output: pl.Tensor[[128, 128], pl.INT32], + ) -> pl.Tensor[[128, 128], pl.INT32]: + tile_a: pl.Tile[[32, 32], pl.INT32] = pl.load(a, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.INT32] = pl.ands(tile_a, scalar) + result: pl.Tensor[[128, 128], pl.INT32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.ands" in ir_str + + def test_block_or(self): + """Test block.or operator - element-wise bitwise OR of two tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT32], + b: pl.Tensor[[128, 128], pl.INT32], + output: pl.Tensor[[128, 128], pl.INT32], + ) -> pl.Tensor[[128, 128], pl.INT32]: + tile_a: pl.Tile[[32, 32], pl.INT32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.INT32] = pl.load(b, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.INT32] = pl.or_(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.INT32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.or" in ir_str + + def test_block_ors(self): + """Test block.ors operator - element-wise bitwise OR of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT32], + scalar: 
pl.Scalar[pl.INT32], + output: pl.Tensor[[128, 128], pl.INT32], + ) -> pl.Tensor[[128, 128], pl.INT32]: + tile_a: pl.Tile[[32, 32], pl.INT32] = pl.load(a, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.INT32] = pl.ors(tile_a, scalar) + result: pl.Tensor[[128, 128], pl.INT32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.ors" in ir_str + + def test_block_xor(self): + """Test block.xor operator - element-wise bitwise XOR of two tiles with tmp buffer.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT32], + b: pl.Tensor[[128, 128], pl.INT32], + output: pl.Tensor[[128, 128], pl.INT32], + ) -> pl.Tensor[[128, 128], pl.INT32]: + tile_a: pl.Tile[[32, 32], pl.INT32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.INT32] = pl.load(b, [0, 0], [32, 32]) + tmp: pl.Tile[[32, 32], pl.INT32] = pl.block.create_tile( + [32, 32], dtype=pl.INT32, target_memory=pl.MemorySpace.UB + ) + tile_c: pl.Tile[[32, 32], pl.INT32] = pl.xor(tile_a, tile_b, tmp) + result: pl.Tensor[[128, 128], pl.INT32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.xor" in ir_str + + def test_block_xors(self): + """Test block.xors operator - element-wise bitwise XOR of tile and scalar with tmp buffer.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT32], + scalar: pl.Scalar[pl.INT32], + output: pl.Tensor[[128, 128], pl.INT32], + ) -> pl.Tensor[[128, 128], pl.INT32]: + tile_a: pl.Tile[[32, 32], pl.INT32] = pl.load(a, [0, 0], [32, 32]) + tmp: pl.Tile[[32, 32], pl.INT32] = pl.block.create_tile( + [32, 32], dtype=pl.INT32, target_memory=pl.MemorySpace.UB + ) + tile_c: pl.Tile[[32, 32], pl.INT32] = pl.xors(tile_a, scalar, tmp) + result: pl.Tensor[[128, 128], pl.INT32] = pl.store(tile_c, [0, 0], [32, 32], output) + return result 
+ + ir_str = str(Program) + assert "block.xors" in ir_str + + def test_block_shl(self): + """Test block.shl operator - element-wise bitwise left shift of two tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.UINT32], + b: pl.Tensor[[128, 128], pl.UINT32], + output: pl.Tensor[[128, 128], pl.UINT32], + ) -> pl.Tensor[[128, 128], pl.UINT32]: + tile_a: pl.Tile[[16, 16], pl.UINT32] = pl.load(a, [0, 0], [16, 16]) + tile_b: pl.Tile[[16, 16], pl.UINT32] = pl.load(b, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.UINT32] = pl.shl(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.UINT32] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.shl" in ir_str + + def test_block_shls(self): + """Test block.shls operator - element-wise bitwise left shift of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.UINT32], + scalar: pl.Scalar[pl.INT32], + output: pl.Tensor[[128, 128], pl.UINT32], + ) -> pl.Tensor[[128, 128], pl.UINT32]: + tile_a: pl.Tile[[16, 16], pl.UINT32] = pl.load(a, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.UINT32] = pl.shls(tile_a, scalar) + result: pl.Tensor[[128, 128], pl.UINT32] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.shls" in ir_str + + def test_block_maxs(self): + """Test block.maxs operator - element-wise maximum of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[16, 16], pl.FP32] = pl.load(a, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.FP32] = pl.maxs(tile_a, 0.0) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [16, 16], output) + 
return result + + ir_str = str(Program) + assert "block.maxs" in ir_str + + def test_block_mins(self): + """Test block.mins operator - element-wise minimum of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[16, 16], pl.FP32] = pl.load(a, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.FP32] = pl.mins(tile_a, 0.0) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.mins" in ir_str + + def test_block_shr(self): + """Test block.shr operator - element-wise bitwise right shift of two tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.UINT32], + b: pl.Tensor[[128, 128], pl.UINT32], + output: pl.Tensor[[128, 128], pl.UINT32], + ) -> pl.Tensor[[128, 128], pl.UINT32]: + tile_a: pl.Tile[[16, 16], pl.UINT32] = pl.load(a, [0, 0], [16, 16]) + tile_b: pl.Tile[[16, 16], pl.UINT32] = pl.load(b, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.UINT32] = pl.shr(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.UINT32] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.shr" in ir_str + + def test_block_shrs(self): + """Test block.shrs operator - element-wise bitwise right shift of tile and scalar.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.UINT32], + scalar: pl.Scalar[pl.INT32], + output: pl.Tensor[[128, 128], pl.UINT32], + ) -> pl.Tensor[[128, 128], pl.UINT32]: + tile_a: pl.Tile[[16, 16], pl.UINT32] = pl.load(a, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.UINT32] = pl.shrs(tile_a, scalar) + result: pl.Tensor[[128, 128], pl.UINT32] = pl.store(tile_c, [0, 0], [16, 
16], output) + return result + + ir_str = str(Program) + assert "block.shrs" in ir_str + + def test_block_shl_preserves_lhs_dtype(self): + """Regression: block.shl result dtype must match LHS dtype, not the promoted type. + + When lhs is UINT16 and rhs is UINT32, the result must be UINT16 (LHS dtype), + consistent with the scalar variant block.shls which preserves the LHS tile dtype. + """ + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.UINT16], + b: pl.Tensor[[128, 128], pl.UINT32], + output: pl.Tensor[[128, 128], pl.UINT16], + ) -> pl.Tensor[[128, 128], pl.UINT16]: + tile_a: pl.Tile[[16, 16], pl.UINT16] = pl.load(a, [0, 0], [16, 16]) + tile_b: pl.Tile[[16, 16], pl.UINT32] = pl.load(b, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.UINT16] = pl.shl(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.UINT16] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.shl" in ir_str + + def test_block_shr_preserves_lhs_dtype(self): + """Regression: block.shr result dtype must match LHS dtype, not the promoted type. + + When lhs is UINT16 and rhs is UINT32, the result must be UINT16 (LHS dtype), + consistent with the scalar variant block.shrs which preserves the LHS tile dtype. 
+ """ + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.UINT16], + b: pl.Tensor[[128, 128], pl.UINT32], + output: pl.Tensor[[128, 128], pl.UINT16], + ) -> pl.Tensor[[128, 128], pl.UINT16]: + tile_a: pl.Tile[[16, 16], pl.UINT16] = pl.load(a, [0, 0], [16, 16]) + tile_b: pl.Tile[[16, 16], pl.UINT32] = pl.load(b, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.UINT16] = pl.shr(tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.UINT16] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.shr" in ir_str + + def test_block_prelu(self): + """Test block.prelu operator - element-wise parametric ReLU with slope and tmp buffer.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_x: pl.Tile[[16, 16], pl.FP32] = pl.load(a, [0, 0], [16, 16]) + slope: pl.Tile[[16, 16], pl.FP32] = pl.block.create_tile( + [16, 16], dtype=pl.FP32, target_memory=pl.MemorySpace.UB + ) + tmp: pl.Tile[[16, 16], pl.FP32] = pl.block.create_tile( + [16, 16], dtype=pl.FP32, target_memory=pl.MemorySpace.UB + ) + tile_c: pl.Tile[[16, 16], pl.FP32] = pl.prelu(tile_x, slope, tmp) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.prelu" in ir_str + + def test_block_not(self): + """Test block.not operator - element-wise bitwise NOT of a tile (int16/uint16 only).""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.INT16], + output: pl.Tensor[[128, 128], pl.INT16], + ) -> pl.Tensor[[128, 128], pl.INT16]: + tile_a: pl.Tile[[16, 16], pl.INT16] = pl.load(a, [0, 0], [16, 16]) + tile_c: pl.Tile[[16, 16], pl.INT16] = pl.not_(tile_a) + result: pl.Tensor[[128, 
128], pl.INT16] = pl.store(tile_c, [0, 0], [16, 16], output) + return result + + ir_str = str(Program) + assert "block.not" in ir_str + + def test_block_addc(self): + """Test block.addc operator - element-wise addition of three tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + c: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.load(c, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.addc(tile_a, tile_b, tile_c) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.addc" in ir_str + + def test_block_subc(self): + """Test block.subc operator - element-wise subtraction of three tiles.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + c: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_c: pl.Tile[[32, 32], pl.FP32] = pl.load(c, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.subc(tile_a, tile_b, tile_c) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.subc" in ir_str + + def test_block_addsc(self): + """Test block.addsc operator - element-wise addition of tile, scalar, and tile.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + 
def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.addsc(tile_a, 2.0, tile_b) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.addsc" in ir_str + + def test_block_subsc(self): + """Test block.subsc operator - element-wise subtraction of tile, scalar, and tile.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.subsc(tile_a, 2.0, tile_b) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.subsc" in ir_str + + def test_block_lrelu(self): + """Test block.lrelu operator - element-wise leaky ReLU with scalar slope.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.lrelu(tile_a, 0.1) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.lrelu" in ir_str + + def test_block_sels(self): + """Test block.sels operator - select between two tiles via 
integer scalar mode.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.sels(tile_a, tile_b, 1) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.sels" in ir_str + + def test_block_sel(self): + """Test block.sel operator - per-element selection between two tiles via mask tile.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + m: pl.Tensor[[128, 128], pl.FP32], + output: pl.Tensor[[128, 128], pl.FP32], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[32, 32], pl.FP32] = pl.load(a, [0, 0], [32, 32]) + tile_b: pl.Tile[[32, 32], pl.FP32] = pl.load(b, [0, 0], [32, 32]) + tile_m: pl.Tile[[32, 32], pl.FP32] = pl.load(m, [0, 0], [32, 32]) + tile_out: pl.Tile[[32, 32], pl.FP32] = pl.sel(tile_m, tile_a, tile_b) + result: pl.Tensor[[128, 128], pl.FP32] = pl.store(tile_out, [0, 0], [32, 32], output) + return result + + ir_str = str(Program) + assert "block.sel" in ir_str + + if __name__ == "__main__": pytest.main([__file__, "-v"])