diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index d4a79746..21056df4 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -115,7 +115,7 @@ def TileBufType : TypeDef { "mlir::Type":$elementType, "mlir::Attribute":$memorySpace, ArrayRefParameter<"int64_t">:$validShape, - "mlir::pto::TileBufConfigAttr":$config // TileBufConfigAttr (or null -> default) + "mlir::pto::TileBufConfigAttr":$config // nullable: null means config omitted in source ); let hasCustomAssemblyFormat = 1; @@ -170,6 +170,7 @@ def TileBufType : TypeDef { } mlir::pto::TileBufConfigAttr getConfigAttr() const; + bool hasExplicitConfig() const; bool hasNonDefaultConfig() const; // ✅ 返回强类型 Attr,而不是 enum diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 3df9390d..58aac184 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -64,6 +64,7 @@ std::unique_ptr createPTORemoveRedundantBarrierPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createPTOInsertLoadStoreForMixCVPass(); std::unique_ptr createInferPTOLayoutPass(); +std::unique_ptr createInferPTOTileConfigPass(); // Declare register function void registerPTOPasses(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index ab7f29df..c3727666 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -69,6 +69,19 @@ def InferPTOLayout : Pass<"pto-infer-layout", "func::FuncOp"> { let dependentDialects = ["pto::PTODialect", "arith::ArithDialect"]; } +def InferPTOTileConfig : Pass<"pto-infer-tile-config", "ModuleOp"> { + let summary = "Infer arch-aware tile config for matmul memory spaces"; + let description = [{ + Normalizes LEFT/RIGHT/ACC tile buffer configs based on `pto.target_arch` + so users do not need to manually thread BLayout/SLayout/fractal choices. + This updates both high-level tile_buf values and pre-lowered + pto.pointer_cast/pto.bind_tile config attributes, while keeping + func.func / func.call interfaces in sync. + }]; + let constructor = "mlir::pto::createInferPTOTileConfigPass()"; + let dependentDialects = ["pto::PTODialect", "func::FuncDialect"]; +} + def InferPTOMemScope : Pass<"pto-infer-mem-scope"> { let summary = "Infer memory scope for PTO Ops"; diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index 2106249d..25cb7d86 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -18,8 +18,15 @@ TileBufConfigAttr TileBufType::getConfigAttr() const { return cfg; } } + +bool TileBufType::hasExplicitConfig() const { + if constexpr (std::is_same_v) + return static_cast(getConfig()); + return static_cast(llvm::dyn_cast_or_null(getConfig())); +} + bool TileBufType::hasNonDefaultConfig() const { - return !getConfigAttr().isDefault(); + return hasExplicitConfig() && !getConfigAttr().isDefault(); } mlir::Attribute TileBufType::getBLayoutAttr() const { return getConfigAttr().getBLayout(); } @@ -61,9 +68,14 @@ Type TileBufType::parse(AsmParser &parser) { Type dtype; int64_t rows = 0, cols = 0; int64_t vrow = -1, vcol = -1; - std::string blayoutStr, slayoutStr; - int64_t fractal = 0; - uint32_t padInt; + TileBufConfigAttr defaultConfig = TileBufConfigAttr::getDefault(ctx); + std::string blayoutStr = stringifyBLayout( + llvm::cast(defaultConfig.getBLayout()).getValue()).str(); + std::string slayoutStr = stringifySLayout( + llvm::cast(defaultConfig.getSLayout()).getValue()).str(); + int64_t fractal = defaultConfig.getSFractalSize().getInt(); + uint32_t padInt = 0; + bool hasExplicitConfig = false; auto parseKeyEq = [&](StringRef expectedKey) -> LogicalResult { if (failed(parser.parseKeyword(expectedKey))) @@ -133,40 +145,71 @@ Type TileBufType::parse(AsmParser &parser) { return Type(); } } - if (failed(parser.parseComma())) return Type(); - } - - // blayout=RowMajor - { - if (failed(parseKeyEq("blayout"))) return Type(); - if (failed(parser.parseKeywordOrString(&blayoutStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - - // slayout=NoneBox - { - if (failed(parseKeyEq("slayout"))) return Type(); - if (failed(parser.parseKeywordOrString(&slayoutStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } - - // fractal=512 - { - if (failed(parseKeyEq("fractal"))) return Type(); - if (failed(parser.parseInteger(fractal))) return Type(); - if (failed(parser.parseComma())) return Type(); } - // pad=Null - { - if (failed(parseKeyEq("pad"))) return Type(); - if (failed(parser.parseInteger(padInt))) return Type(); + if (failed(parser.parseOptionalGreater())) { + hasExplicitConfig = true; + if (failed(parser.parseComma())) + return Type(); + + bool seenBLayout = false; + bool seenSLayout = false; + bool seenFractal = false; + bool seenPad = false; + + while (true) { + StringRef key; + if (failed(parser.parseKeyword(&key))) + return Type(); + if (failed(parser.parseEqual())) + return Type(); + + if (key == "blayout") { + if (seenBLayout) { + parser.emitError(parser.getCurrentLocation(), "duplicate blayout"); + return Type(); + } + seenBLayout = true; + if (failed(parser.parseKeywordOrString(&blayoutStr))) + return Type(); + } else if (key == "slayout") { + if (seenSLayout) { + parser.emitError(parser.getCurrentLocation(), "duplicate slayout"); + return Type(); + } + seenSLayout = true; + if (failed(parser.parseKeywordOrString(&slayoutStr))) + return Type(); + } else if (key == "fractal") { + if (seenFractal) { + parser.emitError(parser.getCurrentLocation(), "duplicate fractal"); + return Type(); + } + seenFractal = true; + if (failed(parser.parseInteger(fractal))) + return Type(); + } else if (key == "pad") { + if (seenPad) { + parser.emitError(parser.getCurrentLocation(), "duplicate pad"); + return Type(); + } + seenPad = true; + if (failed(parser.parseInteger(padInt))) + return Type(); + } else { + parser.emitError(parser.getCurrentLocation(), + "unknown key in tile_buf type: ") + << key; + return Type(); + } + + if (succeeded(parser.parseOptionalGreater())) + break; + if (failed(parser.parseComma())) + return Type(); + } } - if (failed(parser.parseGreater())) - return Type(); - // -------- 语义校验/构造 -------- if (rows < 0 || cols < 0) { parser.emitError(parser.getNameLoc(), "rows/cols must be non-negative"); @@ -209,7 +252,9 @@ Type TileBufType::parse(AsmParser &parser) { IntegerAttr::get(IntegerType::get(ctx, 32), fractal); auto padAttr = PadValueAttr::get(ctx, pv.value()); auto memorySpaceAttr = AddressSpaceAttr::get(ctx, memorySpace.value()); - auto cfg = TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr); + TileBufConfigAttr cfg; + if (hasExplicitConfig) + cfg = TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr); SmallVector shape{rows, cols}; SmallVector validShape{vrow, vcol}; @@ -250,8 +295,9 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { int64_t rows = shape.size() > 0 ? shape[0] : 0; int64_t cols = shape.size() > 1 ? shape[1] : 0; - auto cfg = getConfigAttr(); - if (!cfg) cfg = mlir::pto::TileBufConfigAttr::getDefault(getContext()); + bool hasExplicit = hasExplicitConfig(); + auto cfg = hasExplicit ? getConfigAttr() + : mlir::pto::TileBufConfigAttr::getDefault(getContext()); llvm::StringRef locStr = stringifyLocFromMemorySpace(getMemorySpace()); @@ -281,9 +327,14 @@ void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { if (vcol < 0) printer << "?"; else printer << vcol; + if (!hasExplicit) { + printer << ">"; + return; + } + printer << ", blayout=" << stringifyBLayout(blayout.getValue()) << ", slayout=" << stringifySLayout(slayout.getValue()) << ", fractal=" << cfg.getSFractalSize().getInt() << ", pad=" << stringifyLocFromPad(cfg.getPad()) << ">"; -} \ No newline at end of file +} diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index d9d013c9..65763dd0 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(PTOTransforms PTOPlanMemory.cpp PTORemoveRedundantBarrier.cpp InferPTOLayout.cpp + InferPTOTileConfig.cpp BufferizableOpInterfaceImpl.cpp ConvertToPTOOp.cpp PTOHighDimLowering.cpp diff --git a/lib/PTO/Transforms/InferPTOTileConfig.cpp b/lib/PTO/Transforms/InferPTOTileConfig.cpp new file mode 100644 index 00000000..41ed39e6 --- /dev/null +++ b/lib/PTO/Transforms/InferPTOTileConfig.cpp @@ -0,0 +1,264 @@ +//===- InferPTOTileConfig.cpp - Infer arch-aware tile config -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +#define GEN_PASS_DEF_INFERPTOTILECONFIG +#include "PTO/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static PTOArch getTargetArch(Operation *op) { + auto module = dyn_cast(op); + if (!module) + module = op->getParentOfType(); + if (!module) + return PTOArch::A3; + auto arch = module->getAttrOfType("pto.target_arch"); + if (arch && arch.getValue().equals_insensitive("a5")) + return PTOArch::A5; + return PTOArch::A3; +} + +static TileBufConfigAttr inferTileConfigForSpace(MLIRContext *ctx, + AddressSpace space, + PTOArch arch, + PadValueAttr padAttr) { + BLayout blayout = BLayout::RowMajor; + SLayout slayout = SLayout::NoneBox; + int32_t fractal = 512; + + switch (space) { + case AddressSpace::LEFT: + blayout = arch == PTOArch::A5 ? BLayout::ColMajor : BLayout::RowMajor; + slayout = SLayout::RowMajor; + fractal = 512; + break; + case AddressSpace::RIGHT: + blayout = BLayout::RowMajor; + slayout = SLayout::ColMajor; + fractal = 512; + break; + case AddressSpace::ACC: + blayout = BLayout::ColMajor; + slayout = SLayout::RowMajor; + fractal = 1024; + break; + default: + return {}; + } + + Builder builder(ctx); + if (!padAttr) + padAttr = PadValueAttr::get(ctx, PadValue::Null); + return TileBufConfigAttr::get( + ctx, BLayoutAttr::get(ctx, blayout), SLayoutAttr::get(ctx, slayout), + builder.getI32IntegerAttr(fractal), padAttr); +} + +static TileBufType normalizeTileBufType(TileBufType tileTy, PTOArch arch) { + if (tileTy.hasExplicitConfig()) + return tileTy; + + auto spaceAttr = + dyn_cast_or_null(tileTy.getMemorySpace()); + if (!spaceAttr) + return {}; + + auto currentConfig = tileTy.getConfigAttr(); + auto desiredConfig = inferTileConfigForSpace( + tileTy.getContext(), spaceAttr.getAddressSpace(), arch, + dyn_cast_or_null(currentConfig.getPad())); + if (!desiredConfig) + return {}; + + if (desiredConfig == currentConfig) + return tileTy; + + return TileBufType::get(tileTy.getContext(), tileTy.getShape(), + tileTy.getElementType(), tileTy.getMemorySpace(), + tileTy.getValidShape(), desiredConfig); +} + +static TileBufConfigAttr inferMemRefTileConfig(Type memrefLikeType, PTOArch arch, + MLIRContext *ctx, + TileBufConfigAttr currentConfig) { + auto memrefTy = dyn_cast(memrefLikeType); + if (!memrefTy) + return {}; + auto spaceAttr = dyn_cast_or_null(memrefTy.getMemorySpace()); + if (!spaceAttr) + return {}; + return inferTileConfigForSpace( + ctx, spaceAttr.getAddressSpace(), arch, + currentConfig ? dyn_cast_or_null(currentConfig.getPad()) + : PadValueAttr()); +} + +static Type normalizeType(Type type, PTOArch arch) { + auto tileTy = dyn_cast(type); + if (!tileTy) + return type; + auto normalizedTy = normalizeTileBufType(tileTy, arch); + return normalizedTy ? Type(normalizedTy) : type; +} + +static bool normalizeValue(Value value, PTOArch arch) { + Type currentType = value.getType(); + Type normalizedType = normalizeType(currentType, arch); + if (normalizedType == currentType) + return false; + value.setType(normalizedType); + return true; +} + +static LogicalResult syncFunctionSignature(func::FuncOp func, PTOArch arch) { + SmallVector newInputs; + SmallVector newResults; + + if (func.isExternal()) { + llvm::transform(func.getArgumentTypes(), std::back_inserter(newInputs), + [&](Type type) { return normalizeType(type, arch); }); + llvm::transform(func.getResultTypes(), std::back_inserter(newResults), + [&](Type type) { return normalizeType(type, arch); }); + } else { + Block &entry = func.front(); + newInputs.assign(entry.getArgumentTypes().begin(), entry.getArgumentTypes().end()); + + if (func.getNumResults() != 0) { + bool sawReturn = false; + func.walk([&](func::ReturnOp ret) { + SmallVector operandTypes(ret.getOperandTypes().begin(), + ret.getOperandTypes().end()); + if (!sawReturn) { + newResults = operandTypes; + sawReturn = true; + return; + } + if (newResults != operandTypes) { + ret.emitOpError("all return ops must agree on result types after " + "tile config inference"); + func.emitError("inconsistent function result types after tile config " + "inference"); + } + }); + if (!sawReturn) + return func.emitOpError("non-external function with results must have " + "a return op after tile config inference"); + } + } + + auto newFunctionType = FunctionType::get(func.getContext(), newInputs, newResults); + if (newFunctionType != func.getFunctionType()) + func.setFunctionType(newFunctionType); + return success(); +} + +static LogicalResult syncCallSites(ModuleOp module, func::FuncOp callee) { + auto uses = callee.getSymbolUses(module); + if (!uses) + return success(); + + for (SymbolTable::SymbolUse use : *uses) { + auto call = dyn_cast(use.getUser()); + if (!call) + continue; + + auto expectedInputs = callee.getFunctionType().getInputs(); + if (call.getNumOperands() != expectedInputs.size()) + return call.emitOpError("operand count does not match updated callee " + "signature for ") + << callee.getSymName(); + + for (auto [idx, operand] : llvm::enumerate(call.getArgOperands())) { + if (operand.getType() != expectedInputs[idx]) { + return call.emitOpError("operand type does not match updated callee " + "signature at index ") + << idx << " for " << callee.getSymName(); + } + } + + if (llvm::equal(call.getResultTypes(), callee.getResultTypes())) + continue; + + OpBuilder builder(call); + auto newCall = + builder.create(call.getLoc(), callee, call.getArgOperands()); + newCall->setAttrs(call->getAttrs()); + call.replaceAllUsesWith(newCall.getResults()); + call.erase(); + } + + return success(); +} + +struct InferPTOTileConfigPass + : public impl::InferPTOTileConfigBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + PTOArch arch = getTargetArch(module); + + auto normalizeRegion = [&](Region ®ion, auto &self) -> void { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) + (void)normalizeValue(arg, arch); + + for (Operation &op : block) { + for (Value result : op.getResults()) + (void)normalizeValue(result, arch); + + if (auto pointerCast = dyn_cast(op)) { + auto currentConfig = pointerCast.getConfig(); + if (!currentConfig) { + auto desiredConfig = inferMemRefTileConfig( + pointerCast.getResult().getType(), arch, &getContext(), + TileBufConfigAttr()); + if (desiredConfig) + pointerCast->setAttr("config", desiredConfig); + } + } + + for (Region &nested : op.getRegions()) + self(nested, self); + } + } + }; + + for (func::FuncOp func : module.getOps()) { + if (!func.isExternal()) + normalizeRegion(func.getBody(), normalizeRegion); + if (failed(syncFunctionSignature(func, arch))) { + signalPassFailure(); + return; + } + } + + for (func::FuncOp func : module.getOps()) { + if (failed(syncCallSites(module, func))) { + signalPassFailure(); + return; + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createInferPTOTileConfigPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 939287dd..a0937c66 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2082,6 +2082,11 @@ struct FuncToEmitC : public OpConversionPattern { emitcFunc.setSpecifiersAttr( rewriter.getStrArrayAttr({"__global__ AICORE"})); + if (op.isExternal()) { + rewriter.eraseOp(op); + return success(); + } + // Inline the original body, then convert region/block argument types to // match the converted signature (also covers CFG blocks introduced by // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). @@ -2108,6 +2113,26 @@ struct FuncToEmitC : public OpConversionPattern { } }; +struct CallToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert call result types"); + + auto callee = op.getCalleeAttr(); + if (!callee) + return rewriter.notifyMatchFailure(op, "expected direct callee symbol"); + + auto newCall = rewriter.create( + op.getLoc(), resultTypes, callee.getValue(), adaptor.getOperands()); + rewriter.replaceOp(op, newCall.getResults()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // SubView lowering to GlobalTensor (keep your existing code) //===----------------------------------------------------------------------=== @@ -7306,6 +7331,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 9bc0b95b..c0824498 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -407,6 +407,79 @@ static Type convertPTOTypeToMemRef(Type t) { return t; } +static FunctionType convertFunctionTypeToMemRef(FunctionType fnTy, + MLIRContext *ctx) { + SmallVector newInputs; + newInputs.reserve(fnTy.getNumInputs()); + for (Type t : fnTy.getInputs()) + newInputs.push_back(convertPTOTypeToMemRef(t)); + + SmallVector newResults; + newResults.reserve(fnTy.getNumResults()); + for (Type t : fnTy.getResults()) + newResults.push_back(convertPTOTypeToMemRef(t)); + + return FunctionType::get(ctx, newInputs, newResults); +} + +static LogicalResult rewriteFunctionInterfacesToMemRef(ModuleOp mod, + MLIRContext *ctx) { + for (auto func : mod.getOps()) { + auto newFnTy = convertFunctionTypeToMemRef(func.getFunctionType(), ctx); + + if (!func.isExternal()) { + Block &entry = func.front(); + if (entry.getNumArguments() != newFnTy.getNumInputs()) { + return func.emitOpError( + "entry block argument count does not match rewritten signature"); + } + + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (entry.getArgument(i).getType() != newFnTy.getInput(i)) + entry.getArgument(i).setType(newFnTy.getInput(i)); + } + } + + if (func.getFunctionType() != newFnTy) + func.setFunctionType(newFnTy); + } + + SmallVector callOps; + mod.walk([&](func::CallOp call) { callOps.push_back(call); }); + + for (func::CallOp call : callOps) { + auto callee = mod.lookupSymbol(call.getCallee()); + if (!callee) + continue; + + auto calleeTy = callee.getFunctionType(); + if (call.getNumOperands() != calleeTy.getNumInputs()) { + return call.emitOpError("operand count does not match rewritten callee " + "signature for ") + << callee.getSymName(); + } + + bool needsRewrite = !llvm::equal(call.getResultTypes(), calleeTy.getResults()); + for (auto [idx, operand] : llvm::enumerate(call.getOperands())) { + if (operand.getType() != calleeTy.getInput(idx)) { + needsRewrite = true; + break; + } + } + if (!needsRewrite) + continue; + + OpBuilder builder(call); + auto newCall = + builder.create(call.getLoc(), callee, call.getOperands()); + newCall->setAttrs(call->getAttrs()); + call.replaceAllUsesWith(newCall.getResults()); + call.erase(); + } + + return success(); +} + // Ensure scf.if result types follow the rewritten yield operand types. // PTOViewToMemref rewrites tile values to memref in branch bodies, but scf.if // result types are not auto-updated by those op-local rewrites. @@ -476,6 +549,11 @@ struct PTOViewToMemrefPass // Debug output before pass // dumpPretty(mod.getOperation(), llvm::errs()); + if (failed(rewriteFunctionInterfacesToMemRef(mod, ctx))) { + signalPassFailure(); + return; + } + for (auto func : mod.getOps()) { if (func.isExternal()) continue; diff --git a/test/basic/external_tile_call_a5.pto b/test/basic/external_tile_call_a5.pto new file mode 100644 index 00000000..9661b643 --- /dev/null +++ b/test/basic/external_tile_call_a5.pto @@ -0,0 +1,14 @@ +// RUN: ptoas --pto-arch=a5 %s | FileCheck %s + +module attributes {"pto.device-spec" = "Ascend910B1"} { + func.func private @ext() -> !pto.tile_buf + + func.func @caller() -> !pto.tile_buf { + %tile = func.call @ext() : () -> !pto.tile_buf + return %tile : !pto.tile_buf + } +} + +// CHECK: __global__ AICORE __ca__ float* ext(); +// CHECK: __global__ AICORE __ca__ float* caller() +// CHECK: __ca__ float* v1 = ext(); diff --git a/test/basic/external_tile_decl_a5.pto b/test/basic/external_tile_decl_a5.pto new file mode 100644 index 00000000..38e7319f --- /dev/null +++ b/test/basic/external_tile_decl_a5.pto @@ -0,0 +1,7 @@ +// RUN: ptoas --pto-arch=a5 %s | FileCheck %s + +module attributes {"pto.device-spec" = "Ascend910B1"} { + func.func private @ext() -> !pto.tile_buf +} + +// CHECK: __global__ AICORE __ca__ float* ext(); diff --git a/test/basic/matmul_tile_config_call_boundary_a5.pto b/test/basic/matmul_tile_config_call_boundary_a5.pto new file mode 100644 index 00000000..7e43b7d3 --- /dev/null +++ b/test/basic/matmul_tile_config_call_boundary_a5.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s | FileCheck %s + +module attributes {"pto.device-spec" = "Ascend910B1"} { + func.func @callee() -> !pto.tile_buf { + %tile = pto.alloc_tile : !pto.tile_buf + return %tile : !pto.tile_buf + } + + func.func @caller() -> !pto.tile_buf { + %tile = func.call @callee() : () -> !pto.tile_buf + return %tile : !pto.tile_buf + } +} + +// CHECK: __global__ AICORE __ca__ float* callee() +// CHECK: Tile +// CHECK: __global__ AICORE __ca__ float* caller() +// CHECK: __ca__ float* v1 = callee(); diff --git a/test/basic/matmul_tile_config_infer_a3.pto b/test/basic/matmul_tile_config_infer_a3.pto new file mode 100644 index 00000000..cb4dfe52 --- /dev/null +++ b/test/basic/matmul_tile_config_infer_a3.pto @@ -0,0 +1,15 @@ +// RUN: ptoas --pto-arch=a3 %s | FileCheck %s + +module attributes {"pto.device-spec" = "Ascend910B1"} { + func.func @matmul_tile_config_infer_a3() { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: Tile +// CHECK: Tile +// CHECK: Tile diff --git a/test/basic/matmul_tile_config_infer_a5.pto b/test/basic/matmul_tile_config_infer_a5.pto new file mode 100644 index 00000000..9e54e0ac --- /dev/null +++ b/test/basic/matmul_tile_config_infer_a5.pto @@ -0,0 +1,15 @@ +// RUN: ptoas --pto-arch=a5 %s | FileCheck %s + +module attributes {"pto.device-spec" = "Ascend910B1"} { + func.func @matmul_tile_config_infer_a5() { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: Tile +// CHECK: Tile +// CHECK: Tile diff --git a/test/basic/tile_config_preserve_explicit_a5.pto b/test/basic/tile_config_preserve_explicit_a5.pto new file mode 100644 index 00000000..db199dfb --- /dev/null +++ b/test/basic/tile_config_preserve_explicit_a5.pto @@ -0,0 +1,14 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 1>/dev/null | FileCheck %s + +module attributes {"pto.device-spec" = "Ascend910B1"} { + func.func @preserve_explicit_a5() { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %acc = pto.alloc_tile : !pto.tile_buf + return + } +} + +// CHECK: pto.bind_tile {{.*}}config = #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value> +// CHECK: pto.bind_tile {{.*}}config = #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value> +// CHECK: pto.bind_tile {{.*}}config = #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=1024, pad=#pto.pad_value> diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 2d90de3b..61b34e39 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -701,6 +701,7 @@ int main(int argc, char **argv) { if (!disableInferLayout) pm.addNestedPass(pto::createInferPTOLayoutPass()); + pm.addPass(pto::createInferPTOTileConfigPass()); pm.addPass(pto::createPTOViewToMemrefPass()); // bufferizationPipeline(pm); //pm.addPass(createInferPTOMemScopePass());