From dc4cf5642b5888535c6c893ed4319bfd0e8952ad Mon Sep 17 00:00:00 2001
From: TaoTao-real
Date: Tue, 10 Mar 2026 09:43:04 +0800
Subject: [PATCH] Add tile_buf_array type and lowering support

Introduce !pto.tile_buf_array, a sized logical array of tile_buf values
sharing one element type, together with the pto.make_tile_buf_array and
pto.tile_buf_array_get ops and their verifiers.

The type is exposed through the C API and the Python bindings. Both ops
are lowered in PTOViewToMemref: a constant index is replaced by the
selected element, a dynamic index clones each direct no-result user behind
an scf.if ladder. Python helpers and a sample are added; Python integer
indices are range-checked and raise IndexError.
---
 include/PTO/IR/PTOOps.td                      |  41 +++++
 include/PTO/IR/PTOTypeDefs.td                 |  11 ++
 include/pto-c/Dialect/PTO.h                   |   7 +
 lib/Bindings/Python/PTOModule.cpp             |  33 +++-
 lib/CAPI/Dialect/PTO.cpp                      |  21 +++
 lib/PTO/IR/PTO.cpp                            |  48 ++++++
 lib/PTO/Transforms/PTOViewToMemref.cpp        | 152 ++++++++++++++++++
 python/pto/dialects/pto.py                    | 135 ++++++++++++++++
 .../TileBufArray/tile_buf_array_basic.py      |  58 +++++++
 9 files changed, 504 insertions(+), 2 deletions(-)
 create mode 100644 test/samples/TileBufArray/tile_buf_array_basic.py

diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td
index 2b57228f..4ad06df2 100644
--- a/include/PTO/IR/PTOOps.td
+++ b/include/PTO/IR/PTOOps.td
@@ -305,6 +305,47 @@ def SubsetOp : PTO_Op<"subset", [
   }];
 }
 
+//=============================================================================
+// TileBufArray container ops
+//=============================================================================
+
+def MakeTileBufArrayOp : PTO_Op<"make_tile_buf_array", [Pure]> {
+  let summary = "Build a logical array of tile_buf values (same tile type, non-contiguous allowed).";
+
+  let arguments = (ins
+    Variadic:$elements
+  );
+
+  let results = (outs
+    TileBufArrayType:$result
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    `[` $elements `]` attr-dict `:` type($elements) `->` qualified(type($result))
+  }];
+}
+
+def TileBufArrayGetOp : PTO_Op<"tile_buf_array_get", [Pure]> {
+  let summary = "Get one tile_buf element from tile_buf_array by index.";
+
+  let arguments = (ins
+    TileBufArrayType:$array,
+    Index:$index
+  );
+
+  let results = (outs
+    TileBufType:$result
+  );
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    $array `[` $index `]` attr-dict `:` qualified(type($array)) `->` qualified(type($result))
+  }];
+}
+
 // ============================================================================
 // SSA TileBuf Config Ops (aliasing views)
 // ============================================================================
diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td
index d4a79746..73fbc1ff 100644
--- a/include/PTO/IR/PTOTypeDefs.td
+++ b/include/PTO/IR/PTOTypeDefs.td
@@ -184,3 +184,14 @@ def TileBufType : TypeDef {
     int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min
   }];
 }
+
+// A logical array of tile buffers with the same tile_buf element type.
+// Elements may come from unrelated/non-contiguous addresses.
+def TileBufArrayType : TypeDef {
+  let mnemonic = "tile_buf_array";
+  let parameters = (ins
+    "int64_t":$size,
+    "mlir::Type":$elementType
+  );
+  let assemblyFormat = "`<` $size `x` $elementType `>`";
+}
diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h
index d97024ac..10f25786 100644
--- a/include/pto-c/Dialect/PTO.h
+++ b/include/pto-c/Dialect/PTO.h
@@ -60,6 +60,13 @@ MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGet(
 MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGetWithConfig(
     MlirContext ctx, intptr_t rank, const int64_t *shape, MlirType elementType,
     MlirAttribute memorySpace, MlirAttribute config);
+
+// ---- !pto.tile_buf_array> ----
+MLIR_CAPI_EXPORTED bool mlirPTOTypeIsATileBufArrayType(MlirType type);
+MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufArrayTypeGet(
+    MlirContext ctx, int64_t size, MlirType elementType);
+MLIR_CAPI_EXPORTED int64_t mlirPTOTileBufArrayTypeGetSize(MlirType type);
+MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufArrayTypeGetElementType(MlirType type);
 // ---- Enum attrs helpers (BLayout/SLayout/PadValue in mlir::pto) ----
 MLIR_CAPI_EXPORTED bool mlirPTOAttrIsABLayoutAttr(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute mlirPTOBLayoutAttrGet(MlirContext ctx, int32_t value);
diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp
index b322c17b..7f9c4625 100644
--- a/lib/Bindings/Python/PTOModule.cpp
+++ b/lib/Bindings/Python/PTOModule.cpp
@@ -618,6 +618,35 @@ PYBIND11_MODULE(_pto, m) {
             return py::none();
           },
           py::arg("cls"), py::arg("type"));
-
-    populatePTODialectSubmodule(m);
+
+  // ---- TileBufArrayType ----
+  mlir_type_subclass(
+      m, "TileBufArrayType",
+      [](MlirType t) -> bool { return mlirPTOTypeIsATileBufArrayType(t); })
+      .def_classmethod(
+          "get",
+          [](py::object cls, int64_t size, MlirType elementType,
+             MlirContext context) -> py::object {
+            MlirContext ctx = context;
+            if (!ctx.ptr)
+              ctx = mlirTypeGetContext(elementType);
+            MlirType t = mlirPTOTileBufArrayTypeGet(ctx, size, elementType);
+            if (mlirTypeIsNull(t))
+              return py::none();
+            return cls.attr("__call__")(t);
+          },
+          py::arg("cls"), py::arg("size"), py::arg("element_type"),
+          py::arg("context") = py::none())
+      .def_property_readonly(
+          "size",
+          [](MlirType self) -> int64_t {
+            return mlirPTOTileBufArrayTypeGetSize(self);
+          })
+      .def_property_readonly(
+          "element_type",
+          [](MlirType self) -> MlirType {
+            return mlirPTOTileBufArrayTypeGetElementType(self);
+          });
+
+  populatePTODialectSubmodule(m);
 }
diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp
index 42478c5e..a28a864e 100644
--- a/lib/CAPI/Dialect/PTO.cpp
+++ b/lib/CAPI/Dialect/PTO.cpp
@@ -166,6 +166,27 @@ bool mlirPTOTypeIsATileBufType(MlirType type) {
   return unwrap(type).isa();
 }
 
+bool mlirPTOTypeIsATileBufArrayType(MlirType type) {
+  return unwrap(type).isa();
+}
+
+MlirType mlirPTOTileBufArrayTypeGet(MlirContext ctx, int64_t size,
+                                    MlirType elementType) {
+  MLIRContext *c = unwrap(ctx);
+  auto ty = mlir::pto::TileBufArrayType::get(c, size, unwrap(elementType));
+  return wrap(ty);
+}
+
+int64_t mlirPTOTileBufArrayTypeGetSize(MlirType type) {
+  auto t = mlir::cast(unwrap(type));
+  return t.getSize();
+}
+
+MlirType mlirPTOTileBufArrayTypeGetElementType(MlirType type) {
+  auto t = mlir::cast(unwrap(type));
+  return wrap(t.getElementType());
+}
+
 MlirType mlirPTOTileBufTypeGet(MlirContext ctx, intptr_t rank,
                                const int64_t *shape, MlirType elementType,
                                MlirAttribute memorySpace) {
diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp
index 73ff0cb0..a234183e 100644
--- a/lib/PTO/IR/PTO.cpp
+++ b/lib/PTO/IR/PTO.cpp
@@ -4026,6 +4026,54 @@ mlir::LogicalResult mlir::pto::SubsetOp::verify() {
   return success();
 }
 
+LogicalResult mlir::pto::MakeTileBufArrayOp::verify() {
+  auto elems = getElements();
+  if (elems.empty())
+    return emitOpError("expects at least one tile_buf element");
+
+  auto firstTy = llvm::dyn_cast(elems.front().getType());
+  if (!firstTy)
+    return emitOpError("expects tile_buf elements");
+
+  for (auto v : elems) {
+    auto ty = llvm::dyn_cast(v.getType());
+    if (!ty)
+      return emitOpError("expects tile_buf elements");
+    if (ty != firstTy)
+      return emitOpError("all elements must have the same tile_buf type");
+  }
+
+  auto arrTy = llvm::dyn_cast(getResult().getType());
+  if (!arrTy)
+    return emitOpError("result must be tile_buf_array type");
+  if (arrTy.getSize() != static_cast(elems.size()))
+    return emitOpError("result size must equal number of elements");
+  if (arrTy.getElementType() != firstTy)
+    return emitOpError("result element type must match input tile_buf type");
+
+  return success();
+}
+
+LogicalResult mlir::pto::TileBufArrayGetOp::verify() {
+  auto arrTy = llvm::dyn_cast(getArray().getType());
+  auto resTy = llvm::dyn_cast(getResult().getType());
+  if (!arrTy || !resTy)
+    return emitOpError("expects tile_buf_array input and tile_buf result");
+
+  auto elemTy = llvm::dyn_cast(arrTy.getElementType());
+  if (!elemTy)
+    return emitOpError("array element type must be tile_buf");
+  if (elemTy != resTy)
+    return emitOpError("result type must equal array element type");
+
+  int64_t idx = 0;
+  if (getConstIndex(getIndex(), idx)) {
+    if (idx < 0 || idx >= arrTy.getSize())
+      return emitOpError("constant index out of range for tile_buf_array");
+  }
+  return success();
+}
+
 } // namespace pto
 } // namespace mlir
 
diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp
index 385965d1..fd7c5f90 100644
--- a/lib/PTO/Transforms/PTOViewToMemref.cpp
+++ b/lib/PTO/Transforms/PTOViewToMemref.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -21,11 +22,13 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
 #include "Utils.h" // 假设包含一些通用的工具函数
 
 #include
+#include
 
 using namespace mlir;
 
@@ -556,6 +559,155 @@ struct PTOViewToMemrefPass
         rewriter.replaceOp(op, bindOp.getResult());
       }
 
+      // ------------------------------------------------------------------
+      // Stage 0.75: Lower tile_buf_array container ops (MVP)
+      // ------------------------------------------------------------------
+      SmallVector arrayGets;
+      func.walk([&](mlir::pto::TileBufArrayGetOp op) { arrayGets.push_back(op); });
+
+      for (auto op : arrayGets) {
+        IRRewriter rewriter(ctx);
+        rewriter.setInsertionPoint(op);
+        Location loc = op.getLoc();
+
+        auto make = op.getArray().getDefiningOp();
+        if (!make) {
+          op.emitError("tile_buf_array_get currently requires array from pto.make_tile_buf_array");
+          signalPassFailure();
+          return;
+        }
+
+        auto elems = make.getElements();
+        if (elems.empty()) {
+          op.emitError("tile_buf_array_get requires non-empty make_tile_buf_array");
+          signalPassFailure();
+          return;
+        }
+
+        int64_t idx = 0;
+        if (!getConstIndexValue(op.getIndex(), idx)) {
+          // Dynamic index lowering:
+          // Avoid creating memref-typed scf.if results (which later EmitC
+          // lowering cannot always reconcile). Instead, clone each direct
+          // no-result user behind an if-ladder and bind the selected element
+          // per branch.
+          Value indexValue = ensureIndex(rewriter, loc, op.getIndex(), op);
+          if (!indexValue) {
+            signalPassFailure();
+            return;
+          }
+          Value dynGetValue = op.getResult();
+          SmallVector users;
+          SmallPtrSet seen;
+          for (OpOperand &use : dynGetValue.getUses()) {
+            Operation *user = use.getOwner();
+            if (seen.insert(user).second)
+              users.push_back(user);
+          }
+
+          auto cloneUserReplacingDynGet = [&](Operation *user, Value replacement) {
+            Operation *cloned = rewriter.clone(*user);
+            for (OpOperand &operand : cloned->getOpOperands()) {
+              if (operand.get() == dynGetValue)
+                operand.set(replacement);
+            }
+          };
+
+          for (Operation *user : users) {
+            if (user == op.getOperation())
+              continue;
+            if (user->getNumRegions() != 0 || user->getNumResults() != 0) {
+              user->emitError(
+                  "dynamic tile_buf_array_get currently supports only direct no-result users");
+              signalPassFailure();
+              return;
+            }
+
+            // Single element array: trivial replacement.
+            if (elems.size() == 1) {
+              rewriter.setInsertionPoint(user);
+              cloneUserReplacingDynGet(user, elems.front());
+              user->erase();
+              continue;
+            }
+
+            std::function buildElseLadder;
+            buildElseLadder = [&](Block *block, size_t elemIdx) {
+              OpBuilder::InsertionGuard guard(rewriter);
+              rewriter.setInsertionPoint(block->getTerminator());
+
+              if (elemIdx + 1 == elems.size()) {
+                cloneUserReplacingDynGet(user, elems[elemIdx]);
+                return;
+              }
+
+              Value cIdx = rewriter.create(
+                  loc, static_cast(elemIdx));
+              Value cond = rewriter.create(
+                  loc, arith::CmpIPredicate::eq, indexValue, cIdx);
+              auto ifOp = rewriter.create(
+                  loc, TypeRange{}, cond, /*withElseRegion=*/true);
+
+              {
+                OpBuilder::InsertionGuard thenGuard(rewriter);
+                rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator());
+                cloneUserReplacingDynGet(user, elems[elemIdx]);
+              }
+
+              buildElseLadder(ifOp.elseBlock(), elemIdx + 1);
+            };
+
+            {
+              OpBuilder::InsertionGuard guard(rewriter);
+              rewriter.setInsertionPoint(user);
+              Value cIdx0 = rewriter.create(loc, 0);
+              Value cond0 = rewriter.create(
+                  loc, arith::CmpIPredicate::eq, indexValue, cIdx0);
+              auto ifOp = rewriter.create(
+                  loc, TypeRange{}, cond0, /*withElseRegion=*/true);
+
+              {
+                OpBuilder::InsertionGuard thenGuard(rewriter);
+                rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator());
+                cloneUserReplacingDynGet(user, elems[0]);
+              }
+
+              buildElseLadder(ifOp.elseBlock(), 1);
+            }
+
+            user->erase();
+          }
+
+          if (!op->use_empty()) {
+            op.emitError("dynamic tile_buf_array_get still has users after lowering");
+            signalPassFailure();
+            return;
+          }
+          op.erase();
+          continue;
+        }
+
+        if (idx < 0 || idx >= static_cast(elems.size())) {
+          op.emitError("tile_buf_array_get index out of range in lowering");
+          signalPassFailure();
+          return;
+        }
+
+        rewriter.replaceOp(op, elems[static_cast(idx)]);
+      }
+
+      SmallVector makeArrays;
+      func.walk([&](mlir::pto::MakeTileBufArrayOp op) { makeArrays.push_back(op); });
+
+      for (auto op : makeArrays) {
+        if (!op->use_empty()) {
+          op.emitError("unlowered tile_buf_array value remains after tile_buf_array_get rewrite");
+          signalPassFailure();
+          return;
+        }
+        op.erase();
+      }
+
       // ------------------------------------------------------------------
       // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast
       // ------------------------------------------------------------------
diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py
index cee9be78..cd357cf1 100644
--- a/python/pto/dialects/pto.py
+++ b/python/pto/dialects/pto.py
@@ -30,6 +30,7 @@ def _load_local_pto_ext():
 PartitionTensorViewType = _pto_mod.PartitionTensorViewType
 TileType = _pto_mod.TileType
 TileBufType = _pto_mod.TileBufType
+TileBufArrayType = _pto_mod.TileBufArrayType
 AddressSpace = _pto_mod.AddressSpace
 AddressSpaceAttr = _pto_mod.AddressSpaceAttr
 TileBufConfigAttr = _pto_mod.TileBufConfigAttr
@@ -64,6 +65,10 @@ def _load_local_pto_ext():
     "PartitionTensorViewType",
     "TileType",
     "TileBufType",
+    "TileBufArrayType",
+    "TileBufArray",
+    "make_tile_buf_array",
+    "tile_buf_array_get",
     "AddressSpace",
     "AddressSpaceAttr",
     "BLayout","BLayoutAttr", "SLayout","SLayoutAttr",
@@ -175,6 +180,136 @@ def store_scalar(ptr, offset, value, *, loc=None, ip=None):
         ip=ip,
     )
 
+
+# -----------------------------------------------------------------------------
+# TileBufArray helpers (array-like ergonomics)
+# -----------------------------------------------------------------------------
+def _to_index_value(index, *, loc=None, ip=None):
+    """Return ``index`` as an SSA value.
+
+    A Python ``int`` is materialized as an index-typed ``arith.constant``;
+    anything else is unwrapped with ``_get_op_result_or_value``.
+    """
+    if isinstance(index, int):
+        idx_ty = _ods_ir.IndexType.get()
+        idx_attr = _ods_ir.IntegerAttr.get(idx_ty, index)
+        cst = _ods_ir.Operation.create(
+            "arith.constant",
+            results=[idx_ty],
+            attributes={"value": idx_attr},
+            loc=loc,
+            ip=ip,
+        )
+        return cst.results[0]
+    return _pto_ops_gen._get_op_result_or_value(index)
+
+
+def _to_tile_buf_array_value(array):
+    """Unwrap a ``TileBufArray`` wrapper, op, or value to the array value."""
+    if isinstance(array, TileBufArray):
+        return array.value
+    return _pto_ops_gen._get_op_result_or_value(array)
+
+
+def make_tile_buf_array(elements, *, loc=None, ip=None):
+    """Create ``pto.make_tile_buf_array`` from ``elements``.
+
+    ``elements`` may be any iterable of values/ops; it is consumed once.
+    Raises ``ValueError`` if it is empty or the element types differ.
+    """
+    # Materialize first: a generator is always truthy, so ``not elements``
+    # cannot detect an empty one.
+    ops = [_pto_ops_gen._get_op_result_or_value(v) for v in (elements or ())]
+    if not ops:
+        raise ValueError("make_tile_buf_array expects at least one element")
+
+    elem_ty = ops[0].type
+    for i, v in enumerate(ops[1:], start=1):
+        if v.type != elem_ty:
+            raise ValueError(
+                f"all elements must have the same type, got index 0={elem_ty} and index {i}={v.type}"
+            )
+
+    arr_ty = TileBufArrayType.get(len(ops), elem_ty, elem_ty.context)
+    op = _ods_ir.Operation.create(
+        "pto.make_tile_buf_array",
+        results=[arr_ty],
+        operands=ops,
+        loc=loc,
+        ip=ip,
+    )
+    return op.results[0]
+
+
+def tile_buf_array_get(array, index, *, loc=None, ip=None):
+    """Create ``pto.tile_buf_array_get`` selecting ``array[index]``.
+
+    A Python ``int`` index follows sequence semantics: negative values count
+    from the end and out-of-range values raise ``IndexError``. Any other
+    index is passed through as an SSA value without a range check here.
+    """
+    arr_val = _to_tile_buf_array_value(array)
+    arr_ty = TileBufArrayType(arr_val.type)
+
+    idx = index
+    if isinstance(index, int):
+        size = int(arr_ty.size)
+        if index < 0:
+            idx = size + index
+        # Raising IndexError is also what ends ``for x in arr`` iteration
+        # through ``TileBufArray.__getitem__``.
+        if not 0 <= idx < size:
+            raise IndexError(
+                f"tile_buf_array index {index} out of range for size {size}"
+            )
+    idx_val = _to_index_value(idx, loc=loc, ip=ip)
+
+    op = _ods_ir.Operation.create(
+        "pto.tile_buf_array_get",
+        results=[arr_ty.element_type],
+        operands=[arr_val, idx_val],
+        loc=loc,
+        ip=ip,
+    )
+    return op.results[0]
+
+
+class TileBufArray:
+    """Sequence-style wrapper around a ``!pto.tile_buf_array`` value."""
+
+    def __init__(self, value):
+        v = _pto_ops_gen._get_op_result_or_value(value)
+        self._value = v
+        self._type = TileBufArrayType(v.type)
+
+    @classmethod
+    def from_elements(cls, elements, *, loc=None, ip=None):
+        """Build the array with ``make_tile_buf_array`` and wrap it."""
+        return cls(make_tile_buf_array(elements, loc=loc, ip=ip))
+
+    @property
+    def value(self):
+        return self._value
+
+    @property
+    def type(self):
+        return self._type
+
+    @property
+    def size(self):
+        return int(self._type.size)
+
+    @property
+    def element_type(self):
+        return self._type.element_type
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, index):
+        """Emit a ``tile_buf_array_get``; int indices are range-checked."""
+        return tile_buf_array_get(self._value, index)
+
 # -----------------------------------------------------------------------------
 # Export enum aliases for terse calls: pto.record_event(TLOAD, TLOAD, EVENT_ID0)
 # -----------------------------------------------------------------------------
diff --git a/test/samples/TileBufArray/tile_buf_array_basic.py b/test/samples/TileBufArray/tile_buf_array_basic.py
new file mode 100644
index 00000000..38194592
--- /dev/null
+++ b/test/samples/TileBufArray/tile_buf_array_basic.py
@@ -0,0 +1,58 @@
+from mlir.ir import Context, Location, Module, InsertionPoint
+from mlir.dialects import func, arith, pto, scf
+from mlir.ir import F32Type, IndexType
+
+
+def build():
+    with Context() as ctx:
+        pto.register_dialect(ctx, load=True)
+
+        with Location.unknown(ctx):
+            m = Module.create()
+
+            f32 = F32Type.get(ctx)
+            idx = IndexType.get(ctx)
+
+            vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx)
+            bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx)
+            sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx)
+            pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx)
+            fractal_ab_size = pto.TileConfig.fractalABSize
+            cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx)
+
+            tile_ty = pto.TileBufType.get([16, 16], f32, vec, [16, 16], cfg, ctx)
+            fn_ty = func.FunctionType.get([], [])
+            with InsertionPoint(m.body):
+                fn = func.FuncOp("tile_buf_array_basic", fn_ty)
+                entry = fn.add_entry_block()
+
+            with InsertionPoint(entry):
+                c0 = arith.ConstantOp(idx, 0).result
+                c1 = arith.ConstantOp(idx, 1).result
+                c2 = arith.ConstantOp(idx, 2).result
+
+                tile0 = pto.AllocTileOp(tile_ty).result
+                tile1 = pto.AllocTileOp(tile_ty).result
+
+                # Array-style API:
+                # arr[0] -> constant-index get
+                # arr[iv] -> dynamic-index get
+                arr = pto.TileBufArray.from_elements([tile0, tile1])
+                slot0 = arr[0]
+                slot1 = arr[c1]
+                pto.TAddOp(slot0, slot1, slot0)
+
+                loop = scf.ForOp(c0, c2, c1, [])
+                with InsertionPoint(loop.body):
+                    dyn_slot = arr[loop.induction_variable]
+                    pto.TAddOp(dyn_slot, dyn_slot, dyn_slot)
+                    scf.YieldOp([])
+
+                func.ReturnOp([])
+
+            m.operation.verify()
+            return m
+
+
+if __name__ == "__main__":
+    print(build())