Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ std::unique_ptr<Pass> createPTORemoveRedundantBarrierPass();
std::unique_ptr<Pass> createPTOViewToMemrefPass();
std::unique_ptr<mlir::Pass> createPTOInsertLoadStoreForMixCVPass();
std::unique_ptr<Pass> createInferPTOLayoutPass();
std::unique_ptr<Pass> createInferPTOTileConfigPass();
// Declare register function
void registerPTOPasses();

Expand Down
13 changes: 13 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ def InferPTOLayout : Pass<"pto-infer-layout", "func::FuncOp"> {
let dependentDialects = ["pto::PTODialect", "arith::ArithDialect"];
}

def InferPTOTileConfig : Pass<"pto-infer-tile-config", "ModuleOp"> {
let summary = "Infer arch-aware tile config for matmul memory spaces";
let description = [{
Normalizes LEFT/RIGHT/ACC tile buffer configs based on `pto.target_arch`
so users do not need to manually thread BLayout/SLayout/fractal choices.
This updates both high-level tile_buf values and pre-lowered
pto.pointer_cast/pto.bind_tile config attributes, while keeping
func.func / func.call interfaces in sync.
}];
let constructor = "mlir::pto::createInferPTOTileConfigPass()";
let dependentDialects = ["pto::PTODialect", "func::FuncDialect"];
}


def InferPTOMemScope : Pass<"pto-infer-mem-scope"> {
let summary = "Infer memory scope for PTO Ops";
Expand Down
107 changes: 73 additions & 34 deletions lib/PTO/IR/PTOTypeDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ Type TileBufType::parse(AsmParser &parser) {
Type dtype;
int64_t rows = 0, cols = 0;
int64_t vrow = -1, vcol = -1;
std::string blayoutStr, slayoutStr;
int64_t fractal = 0;
uint32_t padInt;
TileBufConfigAttr defaultConfig = TileBufConfigAttr::getDefault(ctx);
std::string blayoutStr = stringifyBLayout(
llvm::cast<BLayoutAttr>(defaultConfig.getBLayout()).getValue()).str();
std::string slayoutStr = stringifySLayout(
llvm::cast<SLayoutAttr>(defaultConfig.getSLayout()).getValue()).str();
int64_t fractal = defaultConfig.getSFractalSize().getInt();
uint32_t padInt = 0;

auto parseKeyEq = [&](StringRef expectedKey) -> LogicalResult {
if (failed(parser.parseKeyword(expectedKey)))
Expand Down Expand Up @@ -133,40 +137,70 @@ Type TileBufType::parse(AsmParser &parser) {
return Type();
}
}
if (failed(parser.parseComma())) return Type();
}

// blayout=RowMajor
{
if (failed(parseKeyEq("blayout"))) return Type();
if (failed(parser.parseKeywordOrString(&blayoutStr))) return Type();
if (failed(parser.parseComma())) return Type();
}


// slayout=NoneBox
{
if (failed(parseKeyEq("slayout"))) return Type();
if (failed(parser.parseKeywordOrString(&slayoutStr))) return Type();
if (failed(parser.parseComma())) return Type();
}

// fractal=512
{
if (failed(parseKeyEq("fractal"))) return Type();
if (failed(parser.parseInteger(fractal))) return Type();
if (failed(parser.parseComma())) return Type();
}

// pad=Null
{
if (failed(parseKeyEq("pad"))) return Type();
if (failed(parser.parseInteger(padInt))) return Type();
if (failed(parser.parseOptionalGreater())) {
if (failed(parser.parseComma()))
return Type();

bool seenBLayout = false;
bool seenSLayout = false;
bool seenFractal = false;
bool seenPad = false;

while (true) {
StringRef key;
if (failed(parser.parseKeyword(&key)))
return Type();
if (failed(parser.parseEqual()))
return Type();

if (key == "blayout") {
if (seenBLayout) {
parser.emitError(parser.getCurrentLocation(), "duplicate blayout");
return Type();
}
seenBLayout = true;
if (failed(parser.parseKeywordOrString(&blayoutStr)))
return Type();
} else if (key == "slayout") {
if (seenSLayout) {
parser.emitError(parser.getCurrentLocation(), "duplicate slayout");
return Type();
}
seenSLayout = true;
if (failed(parser.parseKeywordOrString(&slayoutStr)))
return Type();
} else if (key == "fractal") {
if (seenFractal) {
parser.emitError(parser.getCurrentLocation(), "duplicate fractal");
return Type();
}
seenFractal = true;
if (failed(parser.parseInteger(fractal)))
return Type();
} else if (key == "pad") {
if (seenPad) {
parser.emitError(parser.getCurrentLocation(), "duplicate pad");
return Type();
}
seenPad = true;
if (failed(parser.parseInteger(padInt)))
return Type();
} else {
parser.emitError(parser.getCurrentLocation(),
"unknown key in tile_buf type: ")
<< key;
return Type();
}

if (succeeded(parser.parseOptionalGreater()))
break;
if (failed(parser.parseComma()))
return Type();
}
}

if (failed(parser.parseGreater()))
return Type();

// -------- 语义校验/构造 --------
if (rows < 0 || cols < 0) {
parser.emitError(parser.getNameLoc(), "rows/cols must be non-negative");
Expand Down Expand Up @@ -281,9 +315,14 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const {
if (vcol < 0) printer << "?";
else printer << vcol;

if (cfg.isDefault()) {
printer << ">";
return;
}

printer << ", blayout=" << stringifyBLayout(blayout.getValue())
<< ", slayout=" << stringifySLayout(slayout.getValue())
<< ", fractal=" << cfg.getSFractalSize().getInt()
<< ", pad=" << stringifyLocFromPad(cfg.getPad())
<< ">";
}
}
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(PTOTransforms
PTOPlanMemory.cpp
PTORemoveRedundantBarrier.cpp
InferPTOLayout.cpp
InferPTOTileConfig.cpp
BufferizableOpInterfaceImpl.cpp
ConvertToPTOOp.cpp
PTOHighDimLowering.cpp
Expand Down
Loading
Loading