diff --git a/python/pypto/ir/op/block_ops.py b/python/pypto/ir/op/block_ops.py index dc262c75..a4704a65 100644 --- a/python/pypto/ir/op/block_ops.py +++ b/python/pypto/ir/op/block_ops.py @@ -279,6 +279,20 @@ def full( return _ir_core.create_op_call("block.full", [shape_tuple, value_expr], kwargs, actual_span) +def fillpad(tile: Expr, span: Span | None = None) -> Call: + """Fill tile with padding for remaining elements. + + Args: + tile: Input tile (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression that returns the filled and padded tile + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("block.fillpad", [tile], {}, actual_span) + + # ============================================================================ # Element-wise Operations # ============================================================================ @@ -583,7 +597,7 @@ def cast( raise ValueError(f"Invalid rounding mode '{mode}'. Expected one of {list(modes.keys())}.") actual_span = _get_span_or_capture(span) - kwargs: dict[str, Any] = {"target_dtype": target_type, "mode": mode_val} + kwargs: dict[str, Any] = {"target_type": target_type, "mode": mode_val} return _ir_core.create_op_call("block.cast", [tile], kwargs, actual_span) diff --git a/python/pypto/language/op/block_ops.py b/python/pypto/language/op/block_ops.py index d8c74f6c..62d30a9e 100644 --- a/python/pypto/language/op/block_ops.py +++ b/python/pypto/language/op/block_ops.py @@ -24,6 +24,7 @@ "move", "ub_copy", "full", + "fillpad", "get_block_idx", "add", "sub", @@ -214,7 +215,6 @@ def full(shape: list[int], dtype: DataType, value: int | float) -> Tile: shape: Shape of the tile dtype: Data type of the tile value: filling scalar - span: Optional source span for debugging (auto-captured if not provided) Returns: Tile wrapping the full operation @@ -223,6 +223,19 @@ def full(shape: list[int], dtype: DataType, value: int | float) -> Tile: return Tile(expr=call_expr) +def fillpad(tile: Tile) -> Tile: + """Fill tile with padding for remaining elements. + + Args: + tile: Input tile + + Returns: + Tile wrapping the fillpad operation + """ + call_expr = _ir_ops.fillpad(tile.unwrap()) + return Tile(expr=call_expr) + + def get_block_idx() -> Scalar: """Get the current block index. diff --git a/src/backend/910B_CCE/backend_910b_cce_ops.cpp b/src/backend/910B_CCE/backend_910b_cce_ops.cpp index 99c5b527..449ef98e 100644 --- a/src/backend/910B_CCE/backend_910b_cce_ops.cpp +++ b/src/backend/910B_CCE/backend_910b_cce_ops.cpp @@ -673,6 +673,12 @@ REGISTER_BACKEND_OP(Backend910B_CCE, "block.row_expand_add") return MakeBinaryElementwiseCodegenCCE("TROWEXPANDADD", op, codegen); }); +REGISTER_BACKEND_OP(Backend910B_CCE, "block.fillpad") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeUnaryCodegenCCE("TFILLPAD", op, codegen); + }); + REGISTER_BACKEND_OP(Backend910B_CCE, "block.col_expand") .set_pipe(ir::PipeType::V) .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { diff --git a/src/backend/910B_PTO/backend_910b_pto_ops.cpp b/src/backend/910B_PTO/backend_910b_pto_ops.cpp index ec2b3e1e..f2459e0d 100644 --- a/src/backend/910B_PTO/backend_910b_pto_ops.cpp +++ b/src/backend/910B_PTO/backend_910b_pto_ops.cpp @@ -112,7 +112,19 @@ static std::string MakeTernaryTileTileCodegenPTO(const std::string& pto_op_name, return ""; } -// Helper function for binary Tile-Scalar operations +// Helper function for full op +static std::string MakeFullCodegenPTO(const std::string& pto_op_name, const CallPtr& op, + codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + CHECK(op->args_.size() == 2) << "full op requires 3 arguments." + << op->args_.size(); // Actually 2 args, two of them are conbined! + std::string scalar = codegen.GetExprAsCode(op->args_[1]); + std::string dst = codegen.GetCurrentResultTarget(); + codegen.Emit(pto_op_name + " " + "ins(" + scalar + ") outs(" + dst + ")"); + return ""; +} + +// Helper function for Binary Tile-Scalar operations static std::string MakeBinaryTileScalarCodegenPTO(const std::string& pto_op_name, const CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); @@ -175,6 +187,33 @@ static std::string MakeTernaryGEMVCodegenPTO(const std::string& pto_op_name, con return ""; } +// Helper function for padding operations +static std::string MakeFillPadCodegenPTO(const std::string& pto_op_name, const CallPtr& op, + codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + CHECK(op->args_.size() == 1) << "Fill pad op requires 1 argument."; + codegen.Emit(pto_op_name + " " + GenerateInsOutsClause(op, codegen)); + return ""; +} + +// Helper function for Ternary Data Movement/Layout operations +static std::string MakeTernaryDataMoveLayoutCodegenPTO(const std::string& pto_op_name, const CallPtr& op, + codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + CHECK(op->args_.size() == 3) << "Ternary move/layout op requires 3 arguments."; + codegen.Emit(pto_op_name + " " + GenerateInsOutsClause(op, codegen)); + return ""; +} + +// Helper function for Binary Axis Reduction/Expansion operations +static std::string MakeBinaryAxisCodegenPTO(const std::string& pto_op_name, const CallPtr& op, + codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + CHECK(op->args_.size() == 2) << "Binary Axis op requires 2 arguments."; + codegen.Emit(pto_op_name + " " + GenerateInsOutsClause(op, codegen)); + return ""; +} + // block.load: emit pto.subview + pto.tload (same format as original IR layer codegen) static std::string MakeBlockLoadCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); @@ -524,7 +563,13 @@ REGISTER_BACKEND_OP(Backend910B_PTO, "block.mins") return MakeBinaryTileScalarCodegenPTO("pto.tmins", op, codegen); }); -// Not Implemented: tlrelu tcmps taddsc tsubsc tsels texpands +REGISTER_BACKEND_OP(Backend910B_PTO, "block.full") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeFullCodegenPTO("pto.texpands", op, codegen); + }); + +// Not Implemented: tlrelu tcmps taddsc tsubsc tsels // ============================================================================ // Matrix Multiplication Operations @@ -584,6 +629,66 @@ REGISTER_BACKEND_OP(Backend910B_PTO, "block.gemv_bias") return MakeTernaryGEMVCodegenPTO("pto.tgemv.bias", op, codegen); }); +// ============================================================================ +// Data Movement/Layout Operations +// ============================================================================ + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.transpose") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeTernaryDataMoveLayoutCodegenPTO("pto.ttrans", op, codegen); + }); + +// ============================================================================ +// Axis reduction/expansion Operations +// ============================================================================ + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_sum") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeBinaryAxisCodegenPTO("pto.trowsum", op, codegen); + }); + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_max") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeBinaryAxisCodegenPTO("pto.trowmax", op, codegen); + }); + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_min") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeBinaryAxisCodegenPTO("pto.trowmin", op, codegen); + }); + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_expand_div") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeBinaryAxisCodegenPTO("pto.trowexpanddiv", op, codegen); + }); + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_expand_mul") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeBinaryAxisCodegenPTO("pto.trowexpandmul", op, codegen); + }); + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_expand_sub") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeBinaryAxisCodegenPTO("pto.trowexpandsub", op, codegen); + }); + +// ============================================================================ +// Padding Operations +// ============================================================================ + +REGISTER_BACKEND_OP(Backend910B_PTO, "block.fillpad") + .set_pipe(ir::PipeType::V) + .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeFillPadCodegenPTO("pto.tfillpad", op, codegen); + }); + // ============================================================================ // Memory Operations // ============================================================================ diff --git a/src/ir/op/block_ops/elementwise.cpp b/src/ir/op/block_ops/elementwise.cpp index 2e847e85..5c015380 100644 --- a/src/ir/op/block_ops/elementwise.cpp +++ b/src/ir/op/block_ops/elementwise.cpp @@ -268,5 +268,23 @@ REGISTER_OP("block.cmps") return DeduceBlockCmpType(args, kwargs, "block.cmps", true); }); +REGISTER_OP("block.fillpad") + .set_op_category("BlockOp") + .set_description("Fill destination tile with source tile data and pad remaining elements") + .add_argument("tile", "Input tile (TileType)") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + CHECK(args.size() == 1) << "The operator block.fillpad requires exactly 1 argument, but got " + << args.size(); + + // Argument must be TileType + auto tile_type = As(args[0]->GetType()); + CHECK(tile_type) << "The operator block.fillpad requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + + // Return same TileType + return std::make_shared(tile_type->shape_, tile_type->dtype_); + }); + } // namespace ir } // namespace pypto diff --git a/tests/ut/codegen/test_pto_codegen_paged_attn.py b/tests/ut/codegen/test_pto_codegen_paged_attn.py new file mode 100644 index 00000000..fa7d911a --- /dev/null +++ b/tests/ut/codegen/test_pto_codegen_paged_attn.py @@ -0,0 +1,182 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Unit tests for PTO backend codegen for paged attention operations.""" + +import pypto.language as pl +import pytest +from pypto import backend, ir +from pypto.backend import BackendType +from pypto.ir.pass_manager import OptimizationStrategy, PassManager +from pypto.pypto_core import codegen + + +@pl.program +class PagedAttention: + """ + Case1: + batch 256 + num_heads 16 + kv_head_num 1 + head_dim 128 + block_size 128 + max_num_blocks_per_req 256 + scale_value 1 + Q(256, 16, 128) BF16 + K(16384, 128, 1, 128) BF16 + V(16384, 128, 1, 128) BF16 + block_table(256, 256) INT32 + context_lens(256, ) INT32 + out(524288, ) FP32 + """ + + """ + orchestration config + + q_tile_size 16 + + num_head_tiles 1 + sij_size 16 * 128 float + pij_size 16 * 128 uint16 + mij_size 16 float + lij_size 16 float + oi_new_size 16 * 128 float + + mi_size 16 float + li_size 16 float + oi_size 16 * 128 float + + qi(256, 16, 128) BF16 + out() + """ + + # M K N 16 128 128 + + # AIC kernels + + @pl.function + def qk_matmul( + self, + qi: pl.Tensor[[16, 128], pl.BF16], + kj: pl.Tensor[[128, 128], pl.BF16], + s_ij: pl.Tensor[[16, 128], pl.FP32], + ) -> pl.Tensor[[16, 128], pl.FP32]: + q_tile: pl.Tile[[16, 128], pl.BF16] = pl.load(qi, [0, 0], [16, 128]) + k_tile: pl.Tile[[128, 128], pl.BF16] = pl.load(kj, [0, 0], [128, 128]) + k_tile_T: pl.Tile[[128, 128], pl.BF16] = pl.transpose(k_tile, axis1=0, axis2=1) + s_tile: pl.Tile[[16, 128], pl.FP32] = pl.block.matmul(q_tile, k_tile_T) + updated_sij: pl.Tensor[[16, 128], pl.FP32] = pl.store(s_tile, [0, 0], [16, 128], s_ij) + return updated_sij + + @pl.function + def pv_matmul( + self, + pij: pl.Tensor[[16, 128], pl.BF16], + vj: pl.Tensor[[128, 128], pl.BF16], + oij: pl.Tensor[[16, 128], pl.FP32], + ) -> pl.Tensor[[16, 128], pl.FP32]: + p_tile: pl.Tile[[16, 128], pl.BF16] = pl.load(pij, [0, 0], [16, 128]) + v_tile: pl.Tile[[128, 128], pl.BF16] = pl.load(vj, [0, 0], [128, 128]) + o_tile: pl.Tile[[16, 128], pl.FP32] = pl.block.matmul(p_tile, v_tile) + updated_oij: pl.Tensor[[16, 128], pl.FP32] = pl.store(o_tile, [0, 0], [16, 128], oij) + return updated_oij + + # AIV kernels + + @pl.function + def softmax_prepare( + self, + sij: pl.Tensor[[16, 128], pl.FP32], + pij: pl.Tensor[[16, 128], pl.BF16], + mij: pl.Tensor[[16, 1], pl.FP32], + lij: pl.Tensor[[16, 1], pl.FP32], + scale_value: pl.Scalar[pl.FP32], + ): + sij_tile: pl.Tile[[16, 128], pl.FP32] = pl.load(sij, [0, 0], [16, 128]) + # sij_dyn_tile: pl.Tile[[16, 128], pl.FP32] = pl.load( + # sij, [0, 0], [16, 128] + # ) + # TODO: + pij_tile: pl.Tile[[16, 128], pl.FP32] = pl.load(pij, [0, 0], [16, 128]) + tmp_tile: pl.Tile[[16, 128], pl.FP32] = pl.block.sub(sij_tile, sij_tile) + sij_tile = pl.block.fillpad(sij_tile) + sij_tile = pl.block.muls(sij_tile, scale_value) + max_tile: pl.Tile[[16, 1], pl.FP32] = pl.block.row_max(sij_tile, tmp_tile) + pij_tile = pl.block.row_expand_sub(sij_tile, max_tile) + pij_tile = pl.block.exp(pij_tile) + pij_bf16_tile = pl.block.cast(pij_tile, mode="round", target_type=pl.BF16) + pij_tile = pl.block.cast(pij_bf16_tile, mode="round", target_type=pl.FP16) + sum_tile: pl.Tile[[16, 1], pl.FP32] = pl.block.row_sum(pij_tile, tmp_tile) + pl.store(max_tile, [0, 0], [16, 1], mij) + pl.store(sum_tile, [0, 0], [16, 1], lij) + pl.store(pij_bf16_tile, [0, 0], [16, 128], pij) + + @pl.function + def online_update( + self, + mij: pl.Tensor[[16, 1], pl.FP32], + lij: pl.Tensor[[16, 1], pl.FP32], + oi_new: pl.Tensor[[16, 128], pl.FP32], + mi: pl.Tensor[[16, 1], pl.FP32], + li: pl.Tensor[[16, 1], pl.FP32], + oi: pl.Tensor[[16, 128], pl.FP32], + dst: pl.Tensor[[16, 128], pl.FP32], + ): + oi_new_tile: pl.Tile[[16, 128], pl.FP32] = pl.load(oi_new, [0, 0], [16, 128]) + oi_tile: pl.Tile[[16, 128], pl.FP32] = pl.load(oi, [0, 0], [16, 128]) + mij_tile: pl.Tile[[16, 1], pl.FP32] = pl.load(mij, [0, 0], [16, 1]) + lij_tile: pl.Tile[[16, 1], pl.FP32] = pl.load(lij, [0, 0], [16, 1]) + mi_tile: pl.Tile[[16, 1], pl.FP32] = pl.load(mi, [0, 0], [16, 1]) + li_tile: pl.Tile[[16, 1], pl.FP32] = pl.load(li, [0, 0], [16, 1]) + + mi_new_tile: pl.Tile[[16, 1], pl.FP32] = pl.block.maximum(mi_tile, mij_tile) + + alpha_tile: pl.Tile[[16, 1], pl.FP32] = pl.block.sub(mi_tile, mi_new_tile) + alpha_tile = pl.block.exp(alpha_tile) + + beta_tile: pl.Tile[[16, 1], pl.FP32] = pl.block.sub(mij_tile, mi_new_tile) + beta_tile = pl.block.exp(beta_tile) + + li_scaled: pl.Tile[[16, 1], pl.FP32] = pl.block.mul(alpha_tile, li_tile) + lij_scaled: pl.Tile[[16, 1], pl.FP32] = pl.block.mul(beta_tile, lij_tile) + li_new_tile: pl.Tile[[16, 1], pl.FP32] = pl.block.add(li_scaled, lij_scaled) + + oi_scaled: pl.Tile[[16, 128], pl.FP32] = pl.block.row_expand_mul(oi_tile, alpha_tile) + oi_new_scaled: pl.Tile[[16, 128], pl.FP32] = pl.block.row_expand_mul(oi_new_tile, beta_tile) + oi_updated_tile: pl.Tile[[16, 128], pl.FP32] = pl.block.add(oi_scaled, oi_new_scaled) + + dst_tile: pl.Tile[[16, 128], pl.FP32] = pl.block.row_expand_div(oi_updated_tile, li_new_tile) + + pl.store(mi_new_tile, [0, 0], [16, 1], mi) + pl.store(li_new_tile, [0, 0], [16, 1], li) + pl.store(oi_updated_tile, [0, 0], [16, 128], oi) + pl.store(dst_tile, [0, 0], [16, 128], dst) + + +def test_block_ops_codegen(): + backend.reset_for_testing() + backend.set_backend_type(BackendType.PTO) + + program = PagedAttention + pm = PassManager.get_strategy(OptimizationStrategy.PTOAS) + optimized_program = pm.run_passes(program) + codegen_instance = codegen.PTOCodegen() + + for func in optimized_program.functions.values(): + func_name = func.name + single_func_program = ir.Program([func], func_name, optimized_program.span) + mlir_code = codegen_instance.generate(single_func_program) + assert mlir_code, f"Generated MLIR code for {func_name} should not be empty" + + +if __name__ == "__main__": + pytest.main([__file__])