diff --git a/docs/ir/PTO-IR-vf-vops-design.md b/docs/ir/PTO-IR-vf-vops-design.md new file mode 100644 index 00000000..f4fc5f37 --- /dev/null +++ b/docs/ir/PTO-IR-vf-vops-design.md @@ -0,0 +1,43 @@ +# PTO IR: VF/VOPS (vtile) design notes + +This document is **restored from OpenClaw session logs** (jsonl) and recent chat decisions. +It records the intended IR surface syntax and canonicalization expectations for the VF/VOPS layer. + +## Types + +- `!pto.vtile` + - example: `!pto.vtile<64xf32>` +- `!pto.uscalar` + - example: `!pto.uscalar` +- `!pto.preg` + +## Target config + +Attach `pto.target_config` on module or function: + +```mlir +module attributes { + pto.target_config = #pto.target_config +} { + func.func @k() { return } +} +``` + +## Core ops + +- `pto.vf.scope { ... }` +- Predication: + - `pto.vpred.all` + - `pto.vpred.tail %count` +- Loads/stores: + - `pto.vload %tile, %row, %col, %pred` + - `pto.vstore %tile, %row, %col, %value, %pred` + - `pto.vload_tail %tile, %row, %col, %count` + - `pto.vstore_tail %tile, %row, %col, %count, %value` + +## Canonicalization (pass: `-pto-canonicalize-vops`) + +- `vload/vstore` with `vpred.tail(count)` should be rewritten to `vload_tail/vstore_tail`. +- If an operand is produced by `vload_tail(count)`, downstream binops/stores should use `vpred.tail(count)` / `vstore_tail`. +- If `count == lanes` (constant), tail ops may be simplified to non-tail ops. +- Conservative loop-invariant hoisting may move pure pto ops that do not depend on the induction variable out of a `scf.for`. diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 343a3b76..b4ebff2d 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -437,4 +437,83 @@ def TileBufConfigAttr : AttrDef { }]; } +//===----------------------------------------------------------------------===// +// Target Config (VF/VOPS) +//===----------------------------------------------------------------------===// + +// #pto.target_config + +def TargetConfigAttr : AttrDef { + let mnemonic = "target_config"; + let summary = "Target configuration for VF/VOPS emission."; + + let parameters = (ins + "mlir::StringAttr":$arch, // required: "a3" | "a5" + "mlir::StringAttr":$isa, // optional + "mlir::StringAttr":$variant, // optional + "mlir::IntegerAttr":$repeatBytes, // optional + "mlir::IntegerAttr":$blockBytes, // optional + "mlir::DictionaryAttr":$caps // optional, default empty dict + ); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + static ::mlir::Attribute parse(::mlir::AsmParser &parser, ::mlir::Type) { + if (failed(parser.parseLess())) return {}; + auto ctx = parser.getContext(); + + ::mlir::StringAttr arch; + ::mlir::StringAttr isa; + ::mlir::StringAttr variant; + ::mlir::IntegerAttr repeatBytes; + ::mlir::IntegerAttr blockBytes; + ::mlir::DictionaryAttr caps; + + while (true) { + llvm::StringRef key; + if (failed(parser.parseKeyword(&key))) return {}; + if (failed(parser.parseEqual())) return {}; + + if (key == "arch") { + llvm::StringRef av; + if (failed(parser.parseKeyword(&av))) return {}; + if (av != "a3" && av != "a5") return {}; + arch = ::mlir::StringAttr::get(ctx, av); + } else if (key == "isa") { + if (failed(parser.parseAttribute(isa))) return {}; + } else if (key == "variant") { + if (failed(parser.parseAttribute(variant))) return {}; + } else if (key == "repeat_bytes") { + if (failed(parser.parseAttribute(repeatBytes))) return {}; + } else if (key == "block_bytes") { + if (failed(parser.parseAttribute(blockBytes))) return {}; + } else if (key == "caps") { + if (failed(parser.parseAttribute(caps))) return {}; + } else { + return {}; + } + + if (succeeded(parser.parseOptionalGreater())) break; + if (failed(parser.parseComma())) return {}; + } + + if (!arch) return {}; + if (!caps) caps = ::mlir::DictionaryAttr::get(ctx); + return Base::get(ctx, arch, isa, variant, repeatBytes, blockBytes, caps); + } + + void print(::mlir::AsmPrinter &printer) const { + printer << "<"; + printer << "arch=" << getArch().getValue(); + if (getIsa()) printer << ", isa=" << getIsa(); + if (getVariant()) printer << ", variant=" << getVariant(); + if (getRepeatBytes()) printer << ", repeat_bytes=" << getRepeatBytes(); + if (getBlockBytes()) printer << ", block_bytes=" << getBlockBytes(); + if (getCaps() && !getCaps().empty()) printer << ", caps=" << getCaps(); + printer << ">"; + } + }]; +} + #endif // MLIR_DIALECT_PTO_IR_PTOATTRS diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index c967a75e..17d5418e 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -3645,4 +3645,177 @@ def TPrintOp: PTO_TOp<"tprint", [ }]; } +//===----------------------------------------------------------------------===// +// VF / VOPS (vector-tile ops) +//===----------------------------------------------------------------------===// + +// pto.vf.scope { ... } +def VFScopeOp : PTO_Op<"vf.scope", [IsolatedFromAbove, NoRegionArguments, SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "VF scope for explicit V-pipe ops."; + let regions = (region AnyRegion:$body); + let assemblyFormat = "attr-dict $body"; +} + +// yield terminator for vf.scope +def YieldOp : PTO_Op<"yield", [Terminator, HasParent<"VFScopeOp">]> { + let summary = "Terminator for pto.vf.scope"; + let arguments = (ins); + let results = (outs); + let assemblyFormat = "attr-dict"; +} + +// pto.vpred.all : !pto.preg +def VPredAllOp : PTO_Op<"vpred.all", [Pure]> { + let summary = "Create an all-true predicate."; + let results = (outs PregType:$pred); + let assemblyFormat = "attr-dict `:` type($pred)"; +} + +// pto.vpred.tail %count : !pto.preg +def VPredTailOp : PTO_Op<"vpred.tail", [Pure]> { + let summary = "Create a tail predicate for a given element count."; + let arguments = (ins Index:$count); + let results = (outs PregType:$pred); + let assemblyFormat = "$count attr-dict `:` type($pred)"; +} + +// pto.uload_row %tile, %row : !pto.uscalar +def ULoadRowOp : PTO_Op<"uload_row", [Pure]> { + let summary = "Uniform scalar load for RowExpand-like patterns."; + let arguments = (ins TileBufType:$tile, Index:$row); + let results = (outs UScalarType:$value); + let hasVerifier = 1; + let assemblyFormat = "$tile `,` $row attr-dict `:` type($value)"; +} + +// pto.vdup %u, %pred : !pto.vtile<...> +def VDupOp : PTO_Op<"vdup", [Pure]> { + let summary = "Duplicate a uniform scalar into a vtile under predicate."; + let arguments = (ins UScalarType:$src, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$src `,` $pred attr-dict `:` type($dst)"; +} + +// pto.vload %tile, %row, %col, %pred : !pto.vtile<...> +def VLoadOp : PTO_Op<"vload", [DeclareOpInterfaceMethods]> { + let summary = "Vector load from tile buffer at (row,col) under predicate."; + let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, PregType:$pred); + let results = (outs VTileType:$value); + let hasVerifier = 1; + let assemblyFormat = "$tile `,` $row `,` $col `,` $pred attr-dict `:` type($value)"; +} + +// pto.vstore %tile, %row, %col, %value, %pred +def VStoreOp : PTO_Op<"vstore", [DeclareOpInterfaceMethods]> { + let summary = "Vector store to tile buffer at (row,col) under predicate."; + let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, VTileType:$value, PregType:$pred); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$tile `,` $row `,` $col `,` $value `,` $pred attr-dict `:` type($value)"; +} + +// pto.vload_tail %tile, %row, %col, %count : !pto.vtile<...> +def VLoadTailOp : PTO_Op<"vload_tail", [DeclareOpInterfaceMethods]> { + let summary = "Tail-safe vector load with explicit count."; + let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, Index:$count); + let results = (outs VTileType:$value); + let hasVerifier = 1; + let assemblyFormat = "$tile `,` $row `,` $col `,` $count attr-dict `:` type($value)"; +} + +// pto.vstore_tail %tile, %row, %col, %count, %value +def VStoreTailOp : PTO_Op<"vstore_tail", [DeclareOpInterfaceMethods]> { + let summary = "Tail-safe vector store with explicit count."; + let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, Index:$count, VTileType:$value); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$tile `,` $row `,` $col `,` $count `,` $value attr-dict `:` type($value)"; +} + +// pto.vload_block %tile, %row : !pto.vtile<...> +def VLoadBlockOp : PTO_Op<"vload_block", [DeclareOpInterfaceMethods]> { + let summary = "Block load used by RowExpand block-broadcast patterns."; + let arguments = (ins TileBufType:$tile, Index:$row); + let results = (outs VTileType:$value); + let hasVerifier = 1; + let assemblyFormat = "$tile `,` $row attr-dict `:` type($value)"; +} + +// pto.vlane_adapt %blk : !pto.vtile<...> +def VLaneAdaptOp : PTO_Op<"vlane_adapt", [Pure]> { + let summary = "Adapt lanes from a block vtile to a full vtile."; + let arguments = (ins VTileType:$src); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$src attr-dict `:` type($dst)"; +} + +// Binops: (vtile, vtile, preg) -> vtile + +def VAddOp : PTO_Op<"vadd", [Pure]> { + let summary = "Vector add."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VSubOp : PTO_Op<"vsub", [Pure]> { + let summary = "Vector sub."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VMulOp : PTO_Op<"vmul", [Pure]> { + let summary = "Vector mul."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VMinOp : PTO_Op<"vmin", [Pure]> { + let summary = "Vector min."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VMaxOp : PTO_Op<"vmax", [Pure]> { + let summary = "Vector max."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VAndOp : PTO_Op<"vand", [Pure]> { + let summary = "Vector and."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VOrOp : PTO_Op<"vor", [Pure]> { + let summary = "Vector or."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + +def VXorOp : PTO_Op<"vxor", [Pure]> { + let summary = "Vector xor."; + let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred); + let results = (outs VTileType:$dst); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)"; +} + + #endif // MLIR_DIALECT_PTO_IR_PTOOPS diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index d4a79746..89a8aa8c 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -184,3 +184,72 @@ def TileBufType : TypeDef { int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min }]; } + +// ---- !pto.preg ---- +// Predicate register used by vops. +def PregType : TypeDef { + let mnemonic = "preg"; + let summary = "Predicate register type used by VOPS."; +} + +// ---- !pto.uscalar ---- +// Uniform scalar (per-thread uniform) used by vops patterns. +def UScalarType : TypeDef { + let mnemonic = "uscalar"; + let summary = "Uniform scalar value used for scalar+SIMD patterns inside pto.vf.scope."; + let parameters = (ins "mlir::Type":$elementType); + + // Print/parse as: !pto.uscalar + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + static ::mlir::Type parse(::mlir::AsmParser &parser) { + if (failed(parser.parseLess())) return {}; + ::mlir::Type elem; + if (failed(parser.parseType(elem))) return {}; + if (failed(parser.parseGreater())) return {}; + return Base::get(parser.getContext(), elem); + } + + void print(::mlir::AsmPrinter &printer) const { + printer << "<"; + printer.printType(getElementType()); + printer << ">"; + } + }]; +} + +// ---- !pto.vtile ---- +// Vector tile value used in vops. Lanes is typically elements-per-repeat (EPR). +def VTileType : TypeDef { + let mnemonic = "vtile"; + let summary = "Vector tile value used inside pto.vf.scope (maps to vreg on A5, subtile view on A3)."; + let parameters = (ins + "mlir::Type":$elementType, + "int64_t":$lanes + ); + + // Print/parse as: !pto.vtile + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + static ::mlir::Type parse(::mlir::AsmParser &parser) { + if (failed(parser.parseLess())) return {}; + int64_t lanes = 0; + if (failed(parser.parseInteger(lanes))) return {}; + if (failed(parser.parseKeyword("x"))) return {}; + ::mlir::Type elem; + if (failed(parser.parseType(elem))) return {}; + if (failed(parser.parseGreater())) return {}; + return Base::get(parser.getContext(), elem, lanes); + } + + void print(::mlir::AsmPrinter &printer) const { + printer << "<" << getLanes() << "x"; + printer.printType(getElementType()); + printer << ">"; + } + + int64_t lanes() const { return getLanes(); } + }]; +} diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index ab7f29df..a841b2a0 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -116,4 +116,16 @@ def PTOLoweringSyncToPipe : Pass<"pto-lowering-sync-to-pipe", "func::FuncOp"> { ]; } + + +def PTOCanonicalizeVops : Pass<"pto-canonicalize-vops", "func::FuncOp"> { + let summary = "Canonicalize VOPS patterns (tail ops, pred propagation, hoisting)."; + let constructor = "mlir::pto::createPTOCanonicalizeVopsPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect" + ]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index f220e09a..3117b560 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4434,6 +4434,28 @@ void TMatmulMxBiasOp::getEffects(SmallVectorImpl()) + return lt.getCount(); + return Value(); +} + +static Value getOrCreatePredAll(Location loc, PatternRewriter &rewriter) { + for (auto &op : *rewriter.getInsertionBlock()) { + if (auto all = dyn_cast(&op)) + return all.getPred(); + } + return rewriter.create(loc).getPred(); +} + +// vload(tile,row,col, pred=tail(count)) -> vload_tail(tile,row,col,count) +struct VLoadTailFromPred final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::pto::VLoadOp op, + PatternRewriter &rewriter) const override { + auto tail = op.getPred().getDefiningOp(); + if (!tail) + return failure(); + + auto vt = rewriter.create(op.getLoc(), op.getValue().getType(), + op.getTile(), op.getRow(), + op.getCol(), tail.getCount()); + rewriter.replaceOp(op, vt.getValue()); + return success(); + } +}; + +// vstore(tile,row,col,val, pred=tail(count)) -> vstore_tail(tile,row,col,count,val) +struct VStoreTailFromPred final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::pto::VStoreOp op, + PatternRewriter &rewriter) const override { + auto tail = op.getPred().getDefiningOp(); + if (!tail) + return failure(); + + rewriter.create(op.getLoc(), op.getTile(), op.getRow(), + op.getCol(), tail.getCount(), + op.getValue()); + rewriter.eraseOp(op); + return success(); + } +}; + +// Binop pred propagation: all -> tail(count) if any operand comes from vload_tail(count). +template +struct BinOpUseTailPred final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(BinOp op, + PatternRewriter &rewriter) const override { + auto all = op.getPred().template getDefiningOp(); + if (!all) + return failure(); + + Value c0 = getCountFromVLoadTail(op.getLhs()); + Value c1 = getCountFromVLoadTail(op.getRhs()); + if (!c0 && !c1) + return failure(); + + Value count = c0 ? c0 : c1; + if (c0 && c1 && c0 != c1) + return failure(); + + auto tail = rewriter.create(op.getLoc(), count); + auto repl = rewriter.create(op.getLoc(), op.getDst().getType(), op.getLhs(), + op.getRhs(), tail.getPred()); + rewriter.replaceOp(op, repl.getDst()); + return success(); + } +}; + +// Store tail-ization: vstore(all) -> vstore_tail if value depends on vload_tail. +struct VStoreAllToTail final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::pto::VStoreOp op, + PatternRewriter &rewriter) const override { + auto all = op.getPred().getDefiningOp(); + if (!all) + return failure(); + + Value count = getCountFromVLoadTail(op.getValue()); + if (!count) + return failure(); + + rewriter.create(op.getLoc(), op.getTile(), op.getRow(), + op.getCol(), count, op.getValue()); + rewriter.eraseOp(op); + return success(); + } +}; + +// count==lanes : vload_tail -> vload(pred.all) +struct VLoadTailToVLoad final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::pto::VLoadTailOp op, + PatternRewriter &rewriter) const override { + auto cst = op.getCount().getDefiningOp(); + if (!cst) + return failure(); + auto lanes = cast(op.getValue().getType()).getLanes(); + if ((int64_t)cst.value() != lanes) + return failure(); + + Value pAll = getOrCreatePredAll(op.getLoc(), rewriter); + auto nl = rewriter.create(op.getLoc(), op.getValue().getType(), + op.getTile(), op.getRow(), + op.getCol(), pAll); + rewriter.replaceOp(op, nl.getValue()); + return success(); + } +}; + +// count==lanes : vstore_tail -> vstore(pred.all) +struct VStoreTailToVStore final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::pto::VStoreTailOp op, + PatternRewriter &rewriter) const override { + auto cst = op.getCount().getDefiningOp(); + if (!cst) + return failure(); + auto lanes = cast(op.getValue().getType()).getLanes(); + if ((int64_t)cst.value() != lanes) + return failure(); + + Value pAll = getOrCreatePredAll(op.getLoc(), rewriter); + rewriter.create(op.getLoc(), op.getTile(), op.getRow(), + op.getCol(), op.getValue(), pAll); + rewriter.eraseOp(op); + return success(); + } +}; + +// Conservative loop-invariant hoisting: hoist a single pure pto op producing vtile/uscalar/preg. +struct HoistPureVtileOpFromFor final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + if (forOp.getNumIterOperands() != 0) + return failure(); + + Value iv = forOp.getInductionVar(); + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.getDialect() || op.getDialect()->getNamespace() != "pto") + continue; + if (!MemoryEffectOpInterface::hasNoEffect(&op)) + continue; + if (op.getNumResults() != 1) + continue; + + Type ty = op.getResult(0).getType(); + if (!isa(ty)) + continue; + + bool usesIV = llvm::any_of(op.getOperands(), [&](Value v) { return v == iv; }); + if (usesIV) + continue; + + bool dependsOnLoop = false; + for (Value v : op.getOperands()) { + if (auto *def = v.getDefiningOp()) { + if (def->getParentOp() == forOp) { + dependsOnLoop = true; + break; + } + } + } + if (dependsOnLoop) + continue; + + rewriter.setInsertionPoint(forOp); + Operation *cloned = rewriter.clone(op); + op.replaceAllUsesWith(cloned->getResults()); + rewriter.eraseOp(&op); + return success(); + } + + return failure(); + } +}; + +struct PTOCanonicalizeVopsPass + : public mlir::pto::impl::PTOCanonicalizeVopsBase { + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + + patterns.add(ctx); + patterns.add(ctx); + + patterns.add(ctx); + + patterns.add, BinOpUseTailPred, + BinOpUseTailPred, BinOpUseTailPred, + BinOpUseTailPred, BinOpUseTailPred, + BinOpUseTailPred, BinOpUseTailPred>(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace mlir::pto { +std::unique_ptr createPTOCanonicalizeVopsPass() { + return std::make_unique(); +} +} // namespace mlir::pto