Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,47 @@ def SubsetOp : PTO_Op<"subset", [
}];
}

//=============================================================================
// TileBufArray container ops
//=============================================================================

// `make_tile_buf_array`: groups its variadic tile_buf operands into one
// tile_buf_array SSA value. The op is marked Pure.
//
// Operands: $elements - zero or more values of TileBufType (ODS level); the
//                       C++ verifier enabled by `hasVerifier` applies further
//                       restrictions.
// Result:   $result   - a TileBufArrayType value.
//
// Textual form implied by the assembly format below (sketch):
//   pto.make_tile_buf_array [%a, %b] : !pto.tile_buf<...>, !pto.tile_buf<...>
//       -> !pto.tile_buf_array<2 x !pto.tile_buf<...>>
def MakeTileBufArrayOp : PTO_Op<"make_tile_buf_array", [Pure]> {
let summary = "Build a logical array of tile_buf values (same tile type, non-contiguous allowed).";

let arguments = (ins
Variadic<TileBufType>:$elements
);

let results = (outs
TileBufArrayType:$result
);

// Requests a hand-written `verify()` member on the generated op class.
let hasVerifier = 1;

// One type is printed per element operand; the result type is printed in
// its qualified (dialect-prefixed) form.
let assemblyFormat = [{
`[` $elements `]` attr-dict `:` type($elements) `->` qualified(type($result))
}];
}

// `tile_buf_array_get`: selects one tile_buf out of a tile_buf_array. The op
// is marked Pure.
//
// Operands: $array - the TileBufArrayType value to read from.
//           $index - the element position, of `index` type.
// Result:   $result - a TileBufType value.
//
// Textual form implied by the assembly format below (sketch):
//   pto.tile_buf_array_get %arr[%i] : !pto.tile_buf_array<...> -> !pto.tile_buf<...>
def TileBufArrayGetOp : PTO_Op<"tile_buf_array_get", [Pure]> {
let summary = "Get one tile_buf element from tile_buf_array by index.";

let arguments = (ins
TileBufArrayType:$array,
Index:$index
);

let results = (outs
TileBufType:$result
);

// Requests a hand-written `verify()` member on the generated op class.
let hasVerifier = 1;

// The index type is not printed; both the array type and the result type
// are printed in qualified (dialect-prefixed) form.
let assemblyFormat = [{
$array `[` $index `]` attr-dict `:` qualified(type($array)) `->` qualified(type($result))
}];
}

// ============================================================================
// SSA TileBuf Config Ops (aliasing views)
// ============================================================================
Expand Down
11 changes: 11 additions & 0 deletions include/PTO/IR/PTOTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,14 @@ def TileBufType : TypeDef<PTO_Dialect, "TileBuf"> {
int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min
}];
}

// A logical array of tile buffers that all share one tile_buf element type.
// Elements may come from unrelated/non-contiguous addresses.
//
// Parameters:
//   size        - number of elements, stored as a signed 64-bit integer.
//   elementType - the element type, declared here as a generic mlir::Type;
//                 this definition itself does not restrict it to tile_buf.
//
// Printed as `tile_buf_array<SIZE x ELEMENT_TYPE>` after the dialect prefix.
def TileBufArrayType : TypeDef<PTO_Dialect, "TileBufArray"> {
let mnemonic = "tile_buf_array";
let parameters = (ins
"int64_t":$size,
"mlir::Type":$elementType
);
let assemblyFormat = "`<` $size `x` $elementType `>`";
}
7 changes: 7 additions & 0 deletions include/pto-c/Dialect/PTO.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGet(
MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufTypeGetWithConfig(
MlirContext ctx, intptr_t rank, const int64_t *shape,
MlirType elementType, MlirAttribute memorySpace, MlirAttribute config);

// ---- !pto.tile_buf_array<N x !pto.tile_buf<...>> ----
// Returns true when `type` is a PTO tile_buf_array type.
MLIR_CAPI_EXPORTED bool mlirPTOTypeIsATileBufArrayType(MlirType type);
// Builds the tile_buf_array type holding `size` elements of `elementType`
// in context `ctx`.
MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufArrayTypeGet(
MlirContext ctx, int64_t size, MlirType elementType);
// Returns the element count of a tile_buf_array type. `type` must satisfy
// mlirPTOTypeIsATileBufArrayType; the implementation casts unconditionally.
MLIR_CAPI_EXPORTED int64_t mlirPTOTileBufArrayTypeGetSize(MlirType type);
// Returns the element type of a tile_buf_array type. `type` must satisfy
// mlirPTOTypeIsATileBufArrayType; the implementation casts unconditionally.
MLIR_CAPI_EXPORTED MlirType mlirPTOTileBufArrayTypeGetElementType(MlirType type);
// ---- Enum attrs helpers (BLayout/SLayout/PadValue in mlir::pto) ----
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsABLayoutAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOBLayoutAttrGet(MlirContext ctx, int32_t value);
Expand Down
33 changes: 31 additions & 2 deletions lib/Bindings/Python/PTOModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,35 @@ PYBIND11_MODULE(_pto, m) {
return py::none();
},
py::arg("cls"), py::arg("type"));

populatePTODialectSubmodule(m);

// ---- TileBufArrayType ----
// Registers the Python class `TileBufArrayType` on module `m`. The isa
// predicate delegates to the C API, and the class exposes:
//   * get(size, element_type, context=None) - class method constructor
//   * size                                  - read-only property
//   * element_type                          - read-only property
mlir_type_subclass(
m, "TileBufArrayType",
[](MlirType t) -> bool { return mlirPTOTypeIsATileBufArrayType(t); })
.def_classmethod(
"get",
[](py::object cls, int64_t size, MlirType elementType,
MlirContext context) -> py::object {
// When no usable context was passed (null handle), fall back to the
// context that owns the element type.
MlirContext ctx = context;
if (!ctx.ptr)
ctx = mlirTypeGetContext(elementType);
MlirType t = mlirPTOTileBufArrayTypeGet(ctx, size, elementType);
// A null type from the C API is surfaced to Python as None.
if (mlirTypeIsNull(t))
return py::none();
// Wrap the raw type by calling the Python class on it.
return cls.attr("__call__")(t);
},
py::arg("cls"), py::arg("size"), py::arg("element_type"),
py::arg("context") = py::none())
.def_property_readonly(
"size",
[](MlirType self) -> int64_t {
return mlirPTOTileBufArrayTypeGetSize(self);
})
.def_property_readonly(
"element_type",
[](MlirType self) -> MlirType {
return mlirPTOTileBufArrayTypeGetElementType(self);
});

populatePTODialectSubmodule(m);
}
21 changes: 21 additions & 0 deletions lib/CAPI/Dialect/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,27 @@ bool mlirPTOTypeIsATileBufType(MlirType type) {
return unwrap(type).isa<mlir::pto::TileBufType>();
}

/// Returns true when `type` is a PTO tile_buf_array type.
bool mlirPTOTypeIsATileBufArrayType(MlirType type) {
  // Use the free-function cast family (as the tile_buf_array accessors in this
  // file do with mlir::cast) rather than the deprecated Type::isa<> member.
  return mlir::isa<mlir::pto::TileBufArrayType>(unwrap(type));
}

/// Builds the tile_buf_array type with `size` elements of `elementType` in
/// context `ctx`.
///
/// Returns a null MlirType when `ctx` or `elementType` is a null handle, so
/// callers can detect the failure with mlirTypeIsNull instead of creating a
/// type around a null element.
MlirType mlirPTOTileBufArrayTypeGet(MlirContext ctx, int64_t size,
                                    MlirType elementType) {
  MLIRContext *c = unwrap(ctx);
  mlir::Type elemTy = unwrap(elementType);
  // Reject null handles up front; a default-constructed mlir::Type wraps to a
  // null MlirType.
  if (!c || !elemTy)
    return wrap(mlir::Type());
  return wrap(mlir::pto::TileBufArrayType::get(c, size, elemTy));
}

/// Returns the element count stored in a tile_buf_array type.
/// `type` must be a tile_buf_array type; the cast is unconditional.
int64_t mlirPTOTileBufArrayTypeGetSize(MlirType type) {
  return mlir::cast<mlir::pto::TileBufArrayType>(unwrap(type)).getSize();
}

/// Returns the element type stored in a tile_buf_array type.
/// `type` must be a tile_buf_array type; the cast is unconditional.
MlirType mlirPTOTileBufArrayTypeGetElementType(MlirType type) {
  return wrap(
      mlir::cast<mlir::pto::TileBufArrayType>(unwrap(type)).getElementType());
}

MlirType mlirPTOTileBufTypeGet(MlirContext ctx, intptr_t rank,
const int64_t *shape, MlirType elementType,
MlirAttribute memorySpace) {
Expand Down
48 changes: 48 additions & 0 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4026,6 +4026,54 @@ mlir::LogicalResult mlir::pto::SubsetOp::verify() {
return success();
}

/// Verifier for `make_tile_buf_array`.
///
/// Checks, in this order:
///   1. at least one element operand is present;
///   2. every operand is a tile_buf and all operands share one type;
///   3. the result is a tile_buf_array whose size equals the operand count
///      and whose element type equals the operands' type.
LogicalResult mlir::pto::MakeTileBufArrayOp::verify() {
  auto operands = getElements();
  if (operands.empty())
    return emitOpError("expects at least one tile_buf element");

  // The leading operand provides the reference type for the homogeneity check.
  auto referenceTy = llvm::dyn_cast<TileBufType>(operands.front().getType());
  if (!referenceTy)
    return emitOpError("expects tile_buf elements");

  for (auto operandType : operands.getTypes()) {
    auto tileTy = llvm::dyn_cast<TileBufType>(operandType);
    if (!tileTy)
      return emitOpError("expects tile_buf elements");
    if (tileTy != referenceTy)
      return emitOpError("all elements must have the same tile_buf type");
  }

  auto resultTy = llvm::dyn_cast<TileBufArrayType>(getResult().getType());
  if (!resultTy)
    return emitOpError("result must be tile_buf_array type");

  const auto operandCount = static_cast<int64_t>(operands.size());
  if (resultTy.getSize() != operandCount)
    return emitOpError("result size must equal number of elements");
  if (resultTy.getElementType() != referenceTy)
    return emitOpError("result element type must match input tile_buf type");

  return success();
}

/// Verifier for `tile_buf_array_get`.
///
/// Requires a tile_buf_array operand whose element type is a tile_buf equal to
/// the result type. When the index is a compile-time constant it must lie in
/// [0, array size); non-constant indices are not range-checked here.
LogicalResult mlir::pto::TileBufArrayGetOp::verify() {
  auto arrayTy = llvm::dyn_cast<TileBufArrayType>(getArray().getType());
  auto resultTy = llvm::dyn_cast<TileBufType>(getResult().getType());
  if (!(arrayTy && resultTy))
    return emitOpError("expects tile_buf_array input and tile_buf result");

  auto elementTy = llvm::dyn_cast<TileBufType>(arrayTy.getElementType());
  if (!elementTy)
    return emitOpError("array element type must be tile_buf");
  if (elementTy != resultTy)
    return emitOpError("result type must equal array element type");

  // Bounds are only checkable when the index folds to a constant.
  int64_t constantIndex = 0;
  if (getConstIndex(getIndex(), constantIndex) &&
      (constantIndex < 0 || constantIndex >= arrayTy.getSize()))
    return emitOpError("constant index out of range for tile_buf_array");

  return success();
}

} // namespace pto
} // namespace mlir

Expand Down
152 changes: 152 additions & 0 deletions lib/PTO/Transforms/PTOViewToMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand All @@ -21,11 +22,13 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "Utils.h" // Assumed to provide some general-purpose utility functions.

#include <algorithm>
#include <functional>

using namespace mlir;

Expand Down Expand Up @@ -556,6 +559,155 @@ struct PTOViewToMemrefPass
rewriter.replaceOp(op, bindOp.getResult());
}

// ------------------------------------------------------------------
// Stage 0.75: Lower tile_buf_array container ops (MVP)
// ------------------------------------------------------------------
// Every pto.tile_buf_array_get is eliminated here:
//   * constant index -> the get is replaced by the matching operand of the
//     defining pto.make_tile_buf_array;
//   * dynamic index  -> each user of the get is cloned once per array
//     element behind a ladder of scf.if ops, then the user and the get are
//     erased.
// Any unsupported situation emits an error, fails the pass and returns.
//
// The gets are collected first; all rewriting happens after the walk.
SmallVector<mlir::pto::TileBufArrayGetOp, 8> arrayGets;
func.walk([&](mlir::pto::TileBufArrayGetOp op) { arrayGets.push_back(op); });

for (auto op : arrayGets) {
IRRewriter rewriter(ctx);
rewriter.setInsertionPoint(op);
Location loc = op.getLoc();

// Only arrays produced directly by make_tile_buf_array are supported; an
// array coming from anywhere else (e.g. a block argument) is rejected.
auto make = op.getArray().getDefiningOp<mlir::pto::MakeTileBufArrayOp>();
if (!make) {
op.emitError("tile_buf_array_get currently requires array from pto.make_tile_buf_array");
signalPassFailure();
return;
}

auto elems = make.getElements();
if (elems.empty()) {
op.emitError("tile_buf_array_get requires non-empty make_tile_buf_array");
signalPassFailure();
return;
}

// NOTE(review): getConstIndexValue is defined outside this view; it is used
// here as "returns true and fills idx when the index is a constant" -
// confirm against its definition.
int64_t idx = 0;
if (!getConstIndexValue(op.getIndex(), idx)) {
// Dynamic index lowering:
// Avoid creating memref-typed scf.if results (which later EmitC
// lowering cannot always reconcile). Instead, clone each direct
// no-result user behind an if-ladder and bind the selected element
// per branch.
// NOTE(review): ensureIndex is defined outside this view; presumably it
// yields an index-typed value and reports its own error on failure, since a
// null result fails the pass without emitting a diagnostic here - confirm.
Value indexValue = ensureIndex(rewriter, loc, op.getIndex(), op);
if (!indexValue) {
signalPassFailure();
return;
}
Value dynGetValue = op.getResult();
// Snapshot the distinct users up front: an op that uses the value in
// several operands appears once, and the list stays stable while users are
// erased below.
SmallVector<Operation *, 8> users;
SmallPtrSet<Operation *, 8> seen;
for (OpOperand &use : dynGetValue.getUses()) {
Operation *user = use.getOwner();
if (seen.insert(user).second)
users.push_back(user);
}

// Clones `user` at the rewriter's current insertion point and rewires every
// operand of the clone that referred to the dynamic get to `replacement`.
auto cloneUserReplacingDynGet = [&](Operation *user, Value replacement) {
Operation *cloned = rewriter.clone(*user);
for (OpOperand &operand : cloned->getOpOperands()) {
if (operand.get() == dynGetValue)
operand.set(replacement);
}
};

for (Operation *user : users) {
if (user == op.getOperation())
continue;
// Users that produce results or own regions are not handled by this
// lowering; the pass fails on the first such user.
if (user->getNumRegions() != 0 || user->getNumResults() != 0) {
user->emitError(
"dynamic tile_buf_array_get currently supports only direct no-result users");
signalPassFailure();
return;
}

// Single element array: trivial replacement.
if (elems.size() == 1) {
rewriter.setInsertionPoint(user);
cloneUserReplacingDynGet(user, elems.front());
user->erase();
continue;
}

// Recursively fills `block` with the remainder of the ladder, starting at
// element `elemIdx`. The last element is emitted without a comparison, so
// at runtime any index that matched no earlier element (including an
// out-of-range one) selects the last element.
std::function<void(Block *, size_t)> buildElseLadder;
buildElseLadder = [&](Block *block, size_t elemIdx) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(block->getTerminator());

if (elemIdx + 1 == elems.size()) {
cloneUserReplacingDynGet(user, elems[elemIdx]);
return;
}

Value cIdx = rewriter.create<arith::ConstantIndexOp>(
loc, static_cast<int64_t>(elemIdx));
Value cond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, indexValue, cIdx);
auto ifOp = rewriter.create<scf::IfOp>(
loc, TypeRange{}, cond, /*withElseRegion=*/true);

{
OpBuilder::InsertionGuard thenGuard(rewriter);
rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator());
cloneUserReplacingDynGet(user, elems[elemIdx]);
}

buildElseLadder(ifOp.elseBlock(), elemIdx + 1);
};

// First rung of the ladder: it is inserted right before the original user
// (not before a block terminator), which is why it is built here instead of
// through buildElseLadder.
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(user);
Value cIdx0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value cond0 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, indexValue, cIdx0);
auto ifOp = rewriter.create<scf::IfOp>(
loc, TypeRange{}, cond0, /*withElseRegion=*/true);

{
OpBuilder::InsertionGuard thenGuard(rewriter);
rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator());
cloneUserReplacingDynGet(user, elems[0]);
}

buildElseLadder(ifOp.elseBlock(), 1);
}

// The per-branch clones now stand in for the original user.
user->erase();
}

// Sanity check before erasing the get itself.
if (!op->use_empty()) {
op.emitError("dynamic tile_buf_array_get still has users after lowering");
signalPassFailure();
return;
}
op.erase();
continue;
}

// Constant index: re-check the bounds against the actual operand list
// before indexing into it.
if (idx < 0 || idx >= static_cast<int64_t>(elems.size())) {
op.emitError("tile_buf_array_get index out of range in lowering");
signalPassFailure();
return;
}

rewriter.replaceOp(op, elems[static_cast<size_t>(idx)]);
}

// Remove every pto.make_tile_buf_array. The ops are gathered first and erased
// afterwards; one that still has uses cannot be dropped, so it is reported and
// the pass fails.
SmallVector<mlir::pto::MakeTileBufArrayOp, 8> makeArrays;
func.walk([&](mlir::pto::MakeTileBufArrayOp arrayOp) {
  makeArrays.push_back(arrayOp);
});

for (auto arrayOp : makeArrays) {
  if (arrayOp->use_empty()) {
    arrayOp.erase();
    continue;
  }
  arrayOp.emitError("unlowered tile_buf_array value remains after tile_buf_array_get rewrite");
  signalPassFailure();
  return;
}

// ------------------------------------------------------------------
// Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast
// ------------------------------------------------------------------
Expand Down
Loading